
spatial-transformer-network's Introduction

Spatial Transformer Networks

This is a TensorFlow implementation of Spatial Transformer Networks by Max Jaderberg, Karen Simonyan, Andrew Zisserman and Koray Kavukcuoglu, accompanied by a two-part blog tutorial series.

A Spatial Transformer Network (STN) is a differentiable module that can be inserted anywhere in a ConvNet architecture to increase its geometric invariance. It gives the network the ability to spatially transform feature maps at no extra data or supervision cost.

Installation

Install the stn package using:

pip3 install stn

Then, you can call the STN layer as follows:

from stn import spatial_transformer_network as transformer

out = transformer(input_feature_map, theta, out_dims)

Parameters

  • input_feature_map: the output of the layer preceding the localization network. If the STN layer is the first layer of the network, then this corresponds to the input images. Shape should be (B, H, W, C).
  • theta: this is the output of the localization network. Shape should be (B, 6).
  • out_dims: desired (H, W) of the output feature map. Useful for upsampling or downsampling. If not specified, then output dimensions will be equal to input_feature_map dimensions.
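A small NumPy sketch of the theta layout: a batch of 2x3 affine matrices, each flattened row-major into 6 values to give shape (B, 6). That the layer reshapes theta back to (B, 2, 3) in row-major order is an assumption here, and B = 4 is arbitrary.

```python
import numpy as np

B = 4  # arbitrary batch size for illustration

# One 2x3 affine matrix per example; here every example uses the identity transform.
identity = np.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0]], dtype=np.float32)

# Flatten each matrix row-major to get the (B, 6) layout expected for theta.
theta = np.tile(identity.flatten(), (B, 1))
print(theta.shape)  # (4, 6)

# Assumed inverse view: each row read back as a 2x3 matrix.
theta_mats = theta.reshape(B, 2, 3)
```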

Background Information

The STN is composed of 3 elements.

  • localization network: takes the feature map as input and outputs the parameters of the affine transformation that should be applied to that feature map.

  • grid generator: uses the parameters of the affine transformation to generate a grid of (x, y) coordinates, i.e. the set of points where the input feature map should be sampled to produce the transformed output feature map.

  • bilinear sampler: takes the input feature map and the grid produced by the grid generator, and computes the output feature map using bilinear interpolation.
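The grid generator and bilinear sampler can be sketched in NumPy as below. This is a framework-free illustration, not the package's TensorFlow code: `affine_grid` and `bilinear_sample` are hypothetical names, and the sketch assumes theta is read row-major as a 2x3 matrix and that normalized coordinates run from -1 (first pixel) to +1 (last pixel). Because the interpolation weights are computed before the corner indices are clipped, points outside the image take the value of the nearest border pixel in this sketch, rather than zero.

```python
import numpy as np


def affine_grid(theta, height, width):
    """Grid generator: map a regular output grid to input sampling points.

    theta: (B, 6) flattened 2x3 affine matrices.
    Returns x_s, y_s of shape (B, height, width) in normalized [-1, 1] coordinates.
    """
    B = theta.shape[0]
    A = theta.reshape(B, 2, 3)
    # Normalized target coordinates: -1 is the first pixel, +1 the last.
    x_t, y_t = np.meshgrid(np.linspace(-1.0, 1.0, width),
                           np.linspace(-1.0, 1.0, height))
    # Homogeneous coordinates (x_t, y_t, 1), shape (3, height*width).
    grid = np.stack([x_t.ravel(), y_t.ravel(), np.ones(x_t.size)])
    # (B, 2, 3) @ (3, N) -> (B, 2, N): source coordinates for every target pixel.
    src = A @ grid
    return (src[:, 0].reshape(B, height, width),
            src[:, 1].reshape(B, height, width))


def bilinear_sample(img, x, y):
    """Bilinear sampler: read img (B, H, W, C) at normalized points x, y (B, h, w)."""
    B, H, W, C = img.shape
    # Rescale [-1, 1] to pixel coordinates [0, W-1] and [0, H-1].
    x = 0.5 * (x + 1.0) * (W - 1)
    y = 0.5 * (y + 1.0) * (H - 1)
    # Integer corners of the cell containing each point.
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    x1, y1 = x0 + 1, y0 + 1
    # Interpolation weights, computed before the corners are clipped.
    wa = (x1 - x) * (y1 - y)
    wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y)
    wd = (x - x0) * (y - y0)
    # Clip corners to valid pixel indices.
    x0, x1 = np.clip(x0, 0, W - 1), np.clip(x1, 0, W - 1)
    y0, y1 = np.clip(y0, 0, H - 1), np.clip(y1, 0, H - 1)
    # Index a (B, H, W, C) array as [batch, row (y), column (x)].
    b = np.arange(B)[:, None, None]
    Ia, Ib = img[b, y0, x0], img[b, y1, x0]
    Ic, Id = img[b, y0, x1], img[b, y1, x1]
    # Weighted sum of the four corners; a trailing axis broadcasts over channels.
    return (wa[..., None] * Ia + wb[..., None] * Ib
            + wc[..., None] * Ic + wd[..., None] * Id)


# Usage: the identity transform should reproduce the input.
img = np.random.rand(2, 5, 6, 3)
theta = np.tile([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], (2, 1))
x_s, y_s = affine_grid(theta, 5, 6)
out = bilinear_sample(img, x_s, y_s)
print(np.allclose(out, img))  # True
```

Indexing the image as `img[b, y, x]` follows the (B, H, W, C) layout, where the row (y) axis comes before the column (x) axis.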

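The affine transformation is specified through the transformation matrix A:

$$A = \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix}$$

It can be constrained to the attention case by writing it in the form

$$A = \begin{bmatrix} s & 0 & t_x \\ 0 & s & t_y \end{bmatrix}$$

where the parameters s, t_x and t_y can be regressed to allow cropping, translation, and isotropic scaling.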
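A minimal NumPy sketch of the constrained (attention) form, assuming the same row-major flattening of the 2x3 matrix into 6 parameters; `attention_theta` is a hypothetical helper. Under the grid mapping x_s = s * x_t + t_x, a target range of [-1, 1] reads the source window [t_x - s, t_x + s], so s < 1 crops and zooms while t_x and t_y move the window.

```python
import numpy as np


def attention_theta(s, t_x, t_y):
    """Flatten [[s, 0, t_x], [0, s, t_y]] for a batch; inputs have shape (B,)."""
    zeros = np.zeros_like(s)
    # Row-major order of the 2x3 matrix: s, 0, t_x, 0, s, t_y.
    return np.stack([s, zeros, t_x, zeros, s, t_y], axis=1)


# Example: a window half the input size, shifted right by 0.2.
theta = attention_theta(np.array([0.5]), np.array([0.2]), np.array([0.0]))
print(theta.reshape(1, 2, 3)[0])
```

In a TensorFlow model, the same stacking can be written with `tf.stack` on the localization network's outputs; stacking is differentiable, so gradients flow back to s, t_x and t_y.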

For a more in-depth explanation of STNs, read the two-part blog post: part 1 and part 2.

Explore

Run the Sanity Check to get a feel for how the spatial transformer can be plugged into any existing code. For example, here's the result of a 45-degree rotation:

(Two example images omitted.)
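One plausible way to build the theta for such a rotation, assuming the (B, 6) row-major layout and normalized coordinates centred on the image; `rotation_theta` is a hypothetical helper, and this is not necessarily how the Sanity Check builds its matrix. Because theta maps output grid points to input sampling points, and image rows increase downward, the apparent direction of rotation may be the opposite of what the matrix suggests; flip the sign of the angle if needed. In normalized coordinates this is a rigid rotation only for square images.

```python
import numpy as np


def rotation_theta(degrees):
    """Flattened 2x3 affine matrix for a rotation about the image centre."""
    a = np.deg2rad(degrees)
    # Normalized coordinates put the image centre at (0, 0), so no translation is needed.
    theta = np.array([[np.cos(a), -np.sin(a), 0.0],
                      [np.sin(a),  np.cos(a), 0.0]], dtype=np.float32)
    return theta.flatten()[None, :]  # shape (1, 6): a batch of one


theta = rotation_theta(45)
```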

Usage Note

You must define a localization network right before using this layer. The localization network is usually a ConvNet or a fully connected network (FC-net) with 6 output nodes (the 6 parameters of the affine transformation).

It is good practice to initialize the localization network to the identity transform before starting the training process. Here is a small code sample for illustration purposes.

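import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder); in TF 2.x use tf.compat.v1
from stn import spatial_transformer_network as transformer

# params
n_fc = 6
B, H, W, C = (2, 200, 200, 3)

# identity transform
initial = np.array([[1., 0, 0], [0, 1., 0]])
initial = initial.astype('float32').flatten()

# input placeholder
x = tf.placeholder(tf.float32, [B, H, W, C])

# localization network: one FC layer on the flattened input.
# Zero weights + identity bias => theta starts as the identity for any input.
W_fc1 = tf.Variable(tf.zeros([H*W*C, n_fc]), name='W_fc1')
b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
h_fc1 = tf.matmul(tf.reshape(x, [B, H*W*C]), W_fc1) + b_fc1

# spatial transformer layer
h_trans = transformer(x, h_fc1)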
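A NumPy sketch (no TensorFlow required) of why this initialization starts from the identity: with zero weights, the fully connected layer outputs only its bias, whatever the input. Sizes are shrunk for illustration, and `g` stands in for a hypothetical upstream gradient.

```python
import numpy as np

B, H, W, C = 2, 8, 8, 3  # small sizes for illustration
n_fc = 6

# Zero weights and an identity bias, the same initialization as the sample code.
W_fc1 = np.zeros((H * W * C, n_fc), dtype=np.float32)
b_fc1 = np.array([[1., 0, 0], [0, 1., 0]], dtype=np.float32).flatten()

# Any input produces the identity theta at initialization...
x = np.random.rand(B, H, W, C).astype(np.float32)
theta = x.reshape(B, -1) @ W_fc1 + b_fc1

# ...but the gradient w.r.t. W_fc1 for an upstream gradient g is x^T g, which is
# generally non-zero, so training can move theta away from the identity.
g = np.ones((B, n_fc), dtype=np.float32)
grad_W = x.reshape(B, -1).T @ g
```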

Attribution

spatial-transformer-network's People

Contributors

kevinzakka, robotrory


spatial-transformer-network's Issues

Theta

How could theta be generated?

Padding mode support?

Thanks for the work. I notice that it's by default zero-padded. I wonder if there's any way to achieve a different padding mode? For example, the torch.nn.grid_sample() supports "border" padding.

Bug in bilinear sampler of transformer

Hi,
I tried the spatial transformer with a simple toy image and identity transform to verify that it works correctly, but I think that it has a bug, unless I'm doing something wrong.

This is the input image I used:
[image: original_toy_example]

When I transform it with theta=[[1,0,0,0,1,0]], I get this as an output, which is clearly not identical to the input:
[image: identity_transform]

I fixed it by adjusting the grid range in the bilinear sampler by -1 and shifting the calculation of the deltas before the clipping operation. Here is the fixed code:

def bilinear_sampler(img, x, y):
        """
        Performs bilinear sampling of the input images according to the 
        normalized coordinates provided by the sampling grid. Note that 
        the sampling is done identically for each channel of the input.
        To test if the function works properly, output image should be
        identical to input image when theta is initialized to identity
        transform.
        Input
        -----
        - img: batch of images in (B, H, W, C) layout.
        - grid: x, y which is the output of affine_grid_generator.
        Returns
        -------
        - interpolated images according to grids. Same size as grid.
        """
        # prepare useful params
        B = tf.shape(img)[0]
        H = tf.shape(img)[1]
        W = tf.shape(img)[2]
        C = tf.shape(img)[3]

        max_y = tf.cast(H - 1, 'int32')
        max_x = tf.cast(W - 1, 'int32')
        zero = tf.zeros([], dtype='int32')

        # cast indices as float32 (for rescaling)
        x = tf.cast(x, 'float32')
        y = tf.cast(y, 'float32')

        # rescale x and y to [0, W-1/H-1]
        x = 0.5 * ((x + 1.0) * tf.cast(max_x, 'float32'))
        y = 0.5 * ((y + 1.0) * tf.cast(max_y, 'float32'))

        # grab 4 nearest corner points for each (x_i, y_i)
        # i.e. we need a rectangle around the point of interest
        x0 = tf.floor(x)
        x1 = x0 + 1
        y0 = tf.floor(y)
        y1 = y0 + 1

        # calculate deltas
        wa = (x1-x) * (y1-y)
        wb = (x1-x) * (y-y0)
        wc = (x-x0) * (y1-y)
        wd = (x-x0) * (y-y0)

        #recast as int for index calculation
        x0 = tf.cast(x0, 'int32')
        x1 = tf.cast(x1, 'int32')
        y0 = tf.cast(y0, 'int32')
        y1 = tf.cast(y1, 'int32')

        # clip to range [0, W-1] / [0, H-1] to not violate img boundaries
        x0 = tf.clip_by_value(x0, zero, max_x)
        x1 = tf.clip_by_value(x1, zero, max_x)
        y0 = tf.clip_by_value(y0, zero, max_y)
        y1 = tf.clip_by_value(y1, zero, max_y)

        # get pixel value at corner coords
        Ia = get_pixel_value(img, x0, y0)
        Ib = get_pixel_value(img, x0, y1)
        Ic = get_pixel_value(img, x1, y0)
        Id = get_pixel_value(img, x1, y1)

        # add dimension for addition
        wa = tf.expand_dims(wa, axis=3)
        wb = tf.expand_dims(wb, axis=3)
        wc = tf.expand_dims(wc, axis=3)
        wd = tf.expand_dims(wd, axis=3)

        # compute output
        out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])

        return out

With this code I recover the original image by transformation with the identity matrix.

Clipping issue

Hey, man! Thanks for your implementation. I think there is something worth noticing. Don't you need to clip x in your bilinear function, too?

Localization network

Hello,
TensorFlow newbie here. I want to compute the affine transformation parameters for a custom dataset. Any input?
Thanks

Question about Affine matrix structure

Hi @kevinzakka
First of all, thank you for the detailed explanation and implementation. But I have a question about the theta matrix.
I want to use a spatial transformer network to estimate the affine transformation from one image to another and then use that affine matrix with the OpenCV function cv2.warpAffine().
I understand that translation is normalized to (-1, 1), but how should I deal with the scale and rotation parameters?

I created a function to convert the matrix to the appropriate format, but it doesn't work:

def tf_to_cv(theta, input_img):
    rows, cols, ch = input_img.shape
    # print(theta.reshape(2,3))
    M = theta.copy().reshape(2,3)
    M[0][0] = 1/M[0][0]
    M[1][1] = 1/M[1][1]
    M[0][1] = 1/M[0][1]
    M[1][0] = 1/M[1][0]
    M[0][1] = M[0][1] * -1
    M[1][0] = M[1][0] * -1
    M[0][2] = -cols * M[0][2]/4
    M[1][2] = -rows * M[1][2]/4
    # print(M[0][2], M[1][2])
    M[0][2] = M[0][2] + ((1 - M[0][0]) * cols/2 - M[0][1]*rows/2)
    M[1][2] = M[1][2] + (-M[1][0] * cols/2 + (1-M[1][1])*rows/2)
    return M

How to achieve interpolation?

How do I get interpolation? When I set initial = np.array([[1, 0, 0.1], [0, 1, 0.1]]), the output is not interpolated: it has black parts.

does not work with batch size 1

These lines should be modified as follows (avoid using squeeze):
# extract x and y coordinates
x_s = batch_grids[:, 0, :, :]
y_s = batch_grids[:, 1, :, :]

Can you give me your data?

I can't find your MNIST data. Can you share your MNIST data for my work?
It is important to me!
Thank you

Conceptual bug!!

Hi,

In the script main.py, it's cool that you feed the output of the localization network to the spatial transformer to obtain the transformed image h_trans. But after that, you provide the initial (original, untransformed) image to the convolutional network. Shouldn't it be the transformed image (h_trans) that goes through the CNN?

I trained the model on the complete dataset (using the script as it is), but the learned transformations are not what I expected. They seem to be random, which makes sense in the current setup, since the localization network is not part of the optimization!

STN for only Attention Mechanism (Isotropic Scaling)

I was searching for STN implementations on GitHub and came across yours. I have a few questions about using an STN only as an attention mechanism, with a fixed isotropic scaling of, say, 0.5, where the localization network predicts only the translation parameters (tx and ty).

Queries:

  1. Will the Spatial Transformer module you wrote work for this?
  2. Should I use the localization network to predict only two parameters (tx and ty)?
  3. I bring theta from shape (2,) to (2, 3) with the steps below:
    a. [tx, ty] * [0, 0, 1] --> [[0, 0, tx], [0, 0, ty]]
    b. [[0, 0, tx], [0, 0, ty]] + [[0.5, 0, 0], [0, 0.5, 0]] --> [[0.5, 0, tx], [0, 0.5, ty]]
    Will it still be differentiable?
    This is followed by the spatial transformer network. Will this work?

QQ

  1. How do I add it when it comes after a convolution?
  2. If my input images have different sizes, can it still be used?

Request to add a license?

Hello there,

I've gone through your blog posts on spatial transformer networks, and they're very informative! I would like to use these networks myself without reinventing the wheel. I would simply use your implementation, but your repository doesn't list a license, so I hesitate to do so. Therefore, I kindly request that you add a LICENSE.txt.

If you are inclined to do so, please see this website for some informal guidance on common open-source licenses.

About Theta Problem

Thank you for sharing. I have a problem after reading your code.
Is theta set by hand in the localisation net rather than learned through training?

3d generalization

Does anyone know of a 3D generalization of spatial transformer networks? I would like to apply these models to 3D numpy arrays.

A Running Error Using Tensorflow 1.8.0

My tensorflow version is 1.8.0
I cloned the code to my computer and downloaded the mnist_cluttered_60x60_6distortions.npz dataset into ./data/.
When I run the main.py, I get the following ERROR:

Loading the data...
Building ConvNet...
Traceback (most recent call last):
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 517, in make_tensor_proto
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 517, in <listcomp>
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/compat.py", line 67, in as_bytes
    (bytes_or_text,))
TypeError: Expected binary or unicode string, got -1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 422, in <module>
    main()
  File "main.py", line 275, in main
    pool3_flat, pool3_size = Flatten(pool3)
  File "/home/LiChenyang/spatial-transformer-network/utils/layer_utils.py", line 85, in Flatten
    layer_flat = tf.reshape(layer, [-1, num_features])
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 6113, in reshape
    "Reshape", tensor=tensor, shape=shape, name=name)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 513, in _apply_op_helper
    raise err
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 510, in _apply_op_helper
    preferred_dtype=default_dtype)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1104, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 235, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 214, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 521, in make_tensor_proto
    "supported type." % (type(values), values))
TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [-1, None]. Consider casting elements to a supported type.
LiChenyang@ubuntu:~/spatial-transformer-network$

Any idea to fix this problem?

How can it do backpropagation?

After reading your code, it seems that the code doesn't mention anything about backpropagation. Does TensorFlow do it automatically?
Thanks!

STN implementation in tensorflow

Has anyone noticed that tf.gather and tf.gather_nd, which are used in the STN, are not differentiable with respect to the coordinates? In addition, TF officially implements an op, tf.contrib.resampler.resampler, that is fully differentiable.

location

Hello, I want to know how to get the network's attention area in the resource.

A question

Hey, if I want to embed the STN in a new network structure, how should I define the loss? Or should I do that at all? @kevinzakka

Something wrong with get_pixel_value?

@kevinzakka @robotrory
In get_pixel_value, the description states that x and y are flattened tensors of shape (BxHxW,). However, in the first lines of the function, it is assumed that the input tensors are actually of the size (B,H,W,C) as shown here:

    shape = tf.shape(x)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]

This creates an error as shape[1] or higher is out of bound for a flattened tensor.

What am i missing?

The training does not converge

Hi, thanks a lot for your great work, your code is very clear and easy to understand.

However, when I train with your main.py, the network does not converge. The only change is that I set "SAMPLE" from True to False, since if it is True, only 500 samples are used for training. The training loss always stays around 2.3, and training is terminated because there is no improvement after 1000 steps. Could you tell me the best accuracy you achieved with your code? Thanks a lot.

Regarding get_pixel_value in STN method

Hi @kevinzakka and @robotrory,
Can you please tell me,
def get_pixel_value(img, x, y):
    """
    Utility function to get pixel value for coordinate
    vectors x and y from a 4D tensor image.
    Input
    -----
    - img: tensor of shape (B, H, W, C)
    - x: flattened tensor of shape (BHW,)
    - y: flattened tensor of shape (BHW,)
    Returns
    -------
    - output: tensor of shape (B, H, W, C)
    """
    shape = tf.shape(x)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]

    batch_idx = tf.range(0, batch_size)
    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
    b = tf.tile(batch_idx, (1, height, width))

    indices = tf.stack([b, y, x], 3)

    return tf.gather_nd(img, indices)

Why do you write indices = tf.stack([b, y, x], 3) in this reverse order? Shouldn't it be indices = tf.stack([b, x, y], 3)? I am extending your method to 3D and was wondering about the reason for the reverse order.
Would indices = tf.stack([b, z, y, x], 3) be correct then?
