nnuyi / non-local_nets-tensorflow
An implementation of Non-local Neural Networks in TensorFlow.
I compared the code in this repository with Facebook's version, and I find that this version is not the non-local net published at CVPR (and on arXiv).
This is Facebook's source code (Caffe2):
def spacetime_nonlocal(
        model, blob_in, dim_in, dim_out, batch_size, prefix, dim_inner,
        is_test, max_pool_stride=2):
    # ---------------------
    cur = blob_in
    # we do projection to convert each spacetime location to a feature
    # theta original size
    # e.g., (8, 1024, 4, 14, 14) => (8, 1024, 4, 14, 14)
    theta = model.ConvNd(
        cur, prefix + '_theta',
        dim_in,
        dim_inner,
        [1, 1, 1],
        strides=[1, 1, 1],
        pads=[0, 0, 0] * 2,
        weight_init=('GaussianFill', {'std': cfg.NONLOCAL.CONV_INIT_STD}),
        bias_init=('ConstantFill', {'value': 0.}), no_bias=cfg.NONLOCAL.NO_BIAS)
    # phi and g: half spatial size
    # e.g., (8, 1024, 4, 14, 14) => (8, 1024, 4, 7, 7)
    if cfg.NONLOCAL.USE_MAXPOOL is True:
        max_pool = model.MaxPool(
            cur, prefix + '_pool',
            kernels=[1, max_pool_stride, max_pool_stride],
            strides=[1, max_pool_stride, max_pool_stride],
            pads=[0, 0, 0] * 2,
        )
    else:
        max_pool = cur
    phi = model.ConvNd(
        max_pool, prefix + '_phi',
        dim_in,
        dim_inner,
        [1, 1, 1],
        strides=[1, 1, 1],
        pads=[0, 0, 0] * 2,
        weight_init=('GaussianFill', {'std': cfg.NONLOCAL.CONV_INIT_STD}),
        bias_init=('ConstantFill', {'value': 0.}), no_bias=cfg.NONLOCAL.NO_BIAS)
    g = model.ConvNd(
        max_pool, prefix + '_g',
        dim_in,
        dim_inner,
        [1, 1, 1],
        strides=[1, 1, 1],
        pads=[0, 0, 0] * 2,
        weight_init=('GaussianFill', {'std': cfg.NONLOCAL.CONV_INIT_STD}),
        bias_init=('ConstantFill', {'value': 0.}), no_bias=cfg.NONLOCAL.NO_BIAS)
    # we have to use explicit batch size (to support arbitrary spacetime size)
    # e.g., (8, 1024, 4, 14, 14) => (8, 1024, 784)
    theta, theta_shape_5d = model.Reshape(
        theta, [theta + '_re' if not cfg.MODEL.ALLOW_INPLACE_RESHAPE else theta,
                theta + '_shape5d'],
        shape=(batch_size, dim_inner, -1))
    phi, phi_shape_5d = model.Reshape(
        phi, [phi + '_re' if not cfg.MODEL.ALLOW_INPLACE_RESHAPE else phi,
              phi + '_shape5d'],
        shape=(batch_size, dim_inner, -1))
    g, g_shape_5d = model.Reshape(
        g, [g + '_re' if not cfg.MODEL.ALLOW_INPLACE_RESHAPE else g,
            g + '_shape5d'],
        shape=(batch_size, dim_inner, -1))
    # e.g., (8, 1024, 784) * (8, 1024, 784) => (8, 784, 784)
    theta_phi = model.net.BatchMatMul([theta, phi], prefix + '_affinity', trans_a=1)
    if cfg.NONLOCAL.USE_SOFTMAX is True:
        if cfg.NONLOCAL.USE_SCALE is True:
            theta_phi_sc = model.Scale(theta_phi, theta_phi, scale=dim_inner**-.5)
        else:
            theta_phi_sc = theta_phi
        # softmax
        # sum(p[i, j, :]) == 1, for any i, j
        p = model.Softmax(theta_phi_sc, theta_phi + '_prob', engine='CUDNN', axis=2)
    else:
        ones = model.net.ConstantFill([theta_phi], [theta_phi + '_ones'], value=1.)
        ones = model.net.ReduceBackSum([ones], [theta_phi + '_const'])
        zeros = model.net.ConstantFill([theta_phi], [theta_phi + '_zeros'], value=0.)
        denom = model.net.Add(
            [zeros, ones], [theta_phi + '_denom'], broadcast=1, axis=0)
        model.StopGradient(denom, denom)
        p = model.net.Div([theta_phi, denom], [theta_phi + '_sc'])
    # note: g's axis[2] corresponds to p's axis[2]
    # e.g., g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1)
    t = model.net.BatchMatMul([g, p], prefix + '_y', trans_b=1)
    # reshape back:
    # e.g., (8, 1024, 784) => (8, 1024, 4, 14, 14)
    t_re, t_shape = model.Reshape(
        [t, theta_shape_5d],
        [t + '_re' if not cfg.MODEL.ALLOW_INPLACE_RESHAPE else t,
         t + '_shape3d'])
    blob_out = t_re
    blob_out = model.ConvNd(
        blob_out, prefix + '_out',
        dim_inner,
        dim_out,
        [1, 1, 1],
        strides=[1, 1, 1],
        pads=[0, 0, 0] * 2,
        weight_init=('GaussianFill', {'std': cfg.NONLOCAL.CONV_INIT_STD})
        if not cfg.NONLOCAL.USE_ZERO_INIT_CONV else
        ('ConstantFill', {'value': 0.}),
        bias_init=('ConstantFill', {'value': 0.}), no_bias=cfg.NONLOCAL.NO_BIAS)
    if cfg.NONLOCAL.USE_BN is True:
        blob_out = model.SpatialBN(
            blob_out, prefix + "_bn", dim_out,
            epsilon=cfg.NONLOCAL.BN_EPSILON, momentum=cfg.NONLOCAL.BN_MOMENTUM,
            is_test=is_test
        )
        model.param_init_net.ConstantFill(
            [prefix + "_bn_s"], prefix + "_bn_s", value=cfg.NONLOCAL.BN_INIT_GAMMA)
    if cfg.NONLOCAL.USE_AFFINE is True:
        blob_out = model.AffineNd(blob_out, prefix + "_bn", dim_out)
    return blob_out
In fact, it computes the pairwise affinity with BatchMatMul ops rather than a conv op.
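To make the matmul-based core of the Caffe2 code above concrete, here is a minimal NumPy sketch of the embedded-Gaussian branch (USE_SOFTMAX and USE_SCALE both on). The shapes and variable names mirror the comments in the code; the toy sizes (batch 2, dim_inner 4, N = T*H*W = 6) are illustrative assumptions, not values from the repository.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax over the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

B, C, N = 2, 4, 6  # batch, dim_inner, flattened spacetime positions
rng = np.random.default_rng(0)
theta = rng.standard_normal((B, C, N))  # after Reshape: (batch, dim_inner, -1)
phi   = rng.standard_normal((B, C, N))
g     = rng.standard_normal((B, C, N))

# BatchMatMul([theta, phi], trans_a=1): theta^T @ phi -> (B, N, N) affinity
affinity = np.einsum('bcn,bcm->bnm', theta, phi)

# Scale by dim_inner**-0.5, then softmax over axis=2, as in the code
p = softmax(affinity / np.sqrt(C), axis=2)

# BatchMatMul([g, p], trans_b=1): g @ p^T -> (B, dim_inner, N)
y = np.einsum('bcm,bnm->bcn', g, p)
print(affinity.shape, y.shape)  # (2, 6, 6) (2, 4, 6)
```

Note that the only learned parts are the 1x1x1 projections producing theta, phi, and g; the affinity itself is a pure matrix product over all pairs of positions, which is why convolutions alone cannot reproduce it.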
Dear nnUyi:
Thank you for your code, which provides a clear and concise implementation of the Non-Local model.
However, I have trained the network with and without the Non-Local blocks for comparison. For the network with the blocks, I used your code; for the network without them, I simply disabled "nonlocal_block1 = NonLocalBlock(cnv1_pool, 32, scope='nonlocal_block1')" and "nonlocal_block2 = NonLocalBlock(cnv2_pool, 64, scope='nonlocal_block2')". The results show that both training and testing accuracy are higher WITHOUT the Non-Local blocks.
Have you experienced this strange behavior?
Thanks again!
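For context on why the blocks can be toggled on and off so cleanly: the non-local operation is wrapped as a residual, z = W_z(y) + x, and the Caffe2 code above supports zero-initializing W_z (USE_ZERO_INIT_CONV), which makes a freshly inserted block an exact identity mapping. The NumPy sketch below illustrates that property with hypothetical shapes; `nonlocal_residual` and `w_z` are names invented here for illustration.

```python
import numpy as np

def nonlocal_residual(x, w_z, y):
    # z = W_z(y) + x over flattened positions: (C_out, C) x (B, C, N) + (B, C, N).
    # With w_z all zeros, the block reduces to the identity on x.
    return np.einsum('dc,bcn->bdn', w_z, y) + x

B, C, N = 2, 4, 6
x = np.ones((B, C, N))                        # backbone features
y = np.random.default_rng(1).standard_normal((B, C, N))  # non-local output
w_z = np.zeros((C, C))                        # zero init, as with USE_ZERO_INIT_CONV
z = nonlocal_residual(x, w_z, y)
print(np.allclose(z, x))  # True: the block starts as an identity
```

So if accuracy drops when the blocks are enabled, it is worth checking initialization and whether the residual path is present, since a properly initialized block should at worst start out as a no-op.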