Comments (6)

StevenShi-23 commented on June 14, 2024

Hi Marc,

Thanks for bringing this up! This is indeed a bug, and we are fixing it.

StevenShi-23 commented on June 14, 2024

Hi Marc,

Upon checking, this is not a bug. When BatchNorm is applied on the default axis (the last dim), it reduces to LayerNorm, and since the size of gamma/beta depends on the shape of the input tensor, the original implementation is still correct.

However, for clarity, we updated the example (ref PR #816).

Thanks for the comment!

zippeurfou commented on June 14, 2024

I am not sure I am following; see this screenshot:
[Screenshot 2023-04-18 at 11:41:50 AM]
What am I missing?

Duyi-Wang commented on June 14, 2024

Because your code isn't running in training mode.

`tf.layers.batch_normalization()` calls into the class `BatchNormalizationBase`:

```python
class BatchNormalizationBase(Layer):
```

`tf.keras.layers.LayerNormalization()` calls into the class `LayerNormalization`:

```python
class LayerNormalization(Layer):
```

In `LayerNormalization`, the mean and variance are computed with `nn.moments`:

```python
mean, variance = nn.moments(inputs, self.axis, keep_dims=True)
```

It then calls `nn.batch_normalization` to get the result:

```python
outputs = nn.batch_normalization(
    inputs,
    mean,
    variance,
    offset=offset,
    scale=scale,
    variance_epsilon=self.epsilon)
```
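The two steps above can be sketched in pure Python for a 2-D input normalized over its last axis (the `LayerNormalization(axis=-1)` case). The helpers `moments` and `batch_normalization` are hypothetical stand-ins that only mimic the math of `tf.nn.moments` (biased variance) and `tf.nn.batch_normalization`; they are not the TensorFlow API. As a simplifying assumption, `gamma`/`beta` are scalars rather than per-feature vectors, and `epsilon` uses Keras's default of `1e-3`.

```python
import math


def moments(values):
    """Mean and biased (population) variance, as tf.nn.moments computes them."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, variance


def batch_normalization(x, mean, variance, offset, scale, variance_epsilon):
    """Elementwise scale * (x - mean) / sqrt(variance + eps) + offset."""
    inv = 1.0 / math.sqrt(variance + variance_epsilon)
    return [scale * (v - mean) * inv + offset for v in x]


def layer_norm(batch, gamma=1.0, beta=0.0, epsilon=1e-3):
    """Normalize each row (one example) over its last axis."""
    out = []
    for row in batch:
        mean, variance = moments(row)  # statistics of this single example
        out.append(batch_normalization(row, mean, variance, beta, gamma, epsilon))
    return out


x = [[1.0, 2.0, 3.0],
     [10.0, 20.0, 30.0]]
for row in layer_norm(x):
    print([round(v, 3) for v in row])
```

Each row ends up centred on zero on its own, independent of the other rows in the batch.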

Stripped of its other features, BN takes the same path (moments, then `nn.batch_normalization`):

```python
def _moments(self, inputs, reduction_axes, keep_dims):
  mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
  # TODO(b/129279393): Support zero batch input in non DistributionStrategy
  # code as well.
  if self._support_zero_size_input():
    inputs_size = array_ops.size(inputs)
    mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
    variance = array_ops.where(inputs_size > 0, variance,
                               K.zeros_like(variance))
  return mean, variance
```

```python
mean, variance = self._moments(
    math_ops.cast(inputs, self._param_dtype),
    reduction_axes,
    keep_dims=keep_dims)

outputs = nn.batch_normalization(inputs,
                                 _broadcast(mean),
                                 _broadcast(variance),
                                 offset,
                                 scale,
                                 self.epsilon)
```
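For comparison, a pure-Python sketch of that training-mode path for a 2-D input of shape `[batch, features]` with the feature axis at `-1`. In Keras `BatchNormalization`, `axis` names the axis that is kept, so `reduction_axes` is every other axis, here the batch axis. The function `batch_norm_train` is hypothetical, and `gamma`/`beta` are simplified to scalars.

```python
import math


def moments(values):
    """Mean and biased (population) variance of a list of numbers."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, variance


def batch_norm_train(batch, gamma=1.0, beta=0.0, epsilon=1e-3):
    """Training-mode BN for a [batch, features] list of rows with axis=-1.

    Statistics are computed per feature column, i.e. reduced over the batch
    axis. Returns the normalized batch plus per-feature means and variances.
    """
    num_features = len(batch[0])
    columns = [[row[j] for row in batch] for j in range(num_features)]
    stats = [moments(col) for col in columns]  # one (mean, var) per feature
    out = [
        [gamma * (v - m) / math.sqrt(var + epsilon) + beta
         for v, (m, var) in zip(row, stats)]
        for row in batch
    ]
    means = [m for m, _ in stats]
    variances = [var for _, var in stats]
    return out, means, variances


x = [[1.0, 10.0],
     [3.0, 30.0]]
out, means, variances = batch_norm_train(x)
print(means)      # [2.0, 20.0]
print(variances)  # [1.0, 100.0]
```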

The difference is that when you are not in training mode, BN replaces the batch mean and variance with its moving mean and moving variance:

```python
mean = tf_utils.smart_cond(training,
                           lambda: mean,
                           lambda: ops.convert_to_tensor(moving_mean))
variance = tf_utils.smart_cond(
    training,
    lambda: variance,
    lambda: ops.convert_to_tensor(moving_variance))
```
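A pure-Python sketch of that switch for a single feature, with a plain `if` standing in for `tf_utils.smart_cond`. The function `batch_norm_1d` is hypothetical; its defaults (`moving_mean=0`, `moving_variance=1`, `gamma=1`, `beta=0`, `epsilon=1e-3`) follow the default initializers of `tf.layers.batch_normalization`. With those defaults and `training=False`, the layer returns `x / sqrt(1 + epsilon)`, which is almost its input, so untrained inference output does not look normalized at all.

```python
import math


def moments(values):
    """Mean and biased (population) variance of a list of numbers."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, variance


def batch_norm_1d(values, training, moving_mean=0.0, moving_variance=1.0,
                  gamma=1.0, beta=0.0, epsilon=1e-3):
    """Batch norm for one feature across a batch of scalar values.

    training=True  -> use the statistics of the current batch.
    training=False -> use the stored moving statistics instead.
    """
    if training:
        mean, variance = moments(values)
    else:
        mean, variance = moving_mean, moving_variance
    inv = 1.0 / math.sqrt(variance + epsilon)
    return [gamma * (v - mean) * inv + beta for v in values]


batch = [1.0, 2.0, 3.0, 4.0]
print(batch_norm_1d(batch, training=True))   # centred on 0 using batch stats
print(batch_norm_1d(batch, training=False))  # batch / sqrt(1.001), nearly unchanged
```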

Duyi-Wang commented on June 14, 2024

You can pass the parameter `moving_mean_initializer='ones'` (it defaults to `'zeros'`) and you will see that the output changes.
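Why that parameter changes the output can be seen from the inference-mode formula `gamma * (x - moving_mean) / sqrt(moving_variance + epsilon) + beta`. In this minimal sketch the helper `bn_inference` is hypothetical; `moving_mean=1.0` plays the role of `moving_mean_initializer='ones'` for a layer that has never been trained, and the remaining values are the usual defaults.

```python
import math


def bn_inference(x, moving_mean, moving_variance=1.0, gamma=1.0, beta=0.0,
                 epsilon=1e-3):
    """Inference-mode batch norm: uses only the stored moving statistics."""
    inv = 1.0 / math.sqrt(moving_variance + epsilon)
    return [gamma * (v - moving_mean) * inv + beta for v in x]


x = [1.0, 2.0, 3.0]
zeros_init = bn_inference(x, moving_mean=0.0)  # default 'zeros'
ones_init = bn_inference(x, moving_mean=1.0)   # 'ones'

# Every output shifts down by the same amount, 1 / sqrt(1 + epsilon).
print([round(a - b, 4) for a, b in zip(zeros_init, ones_init)])  # [0.9995, 0.9995, 0.9995]
```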

zippeurfou commented on June 14, 2024

Thanks @Duyi-Wang, that makes sense. I was confused by it too, but the docs clearly state it. Thanks for pointing out the code.
Adding a screenshot for posterity:
[Screenshot 2023-04-19 at 10:45:40 AM]
Feel free to close this one.