
batchrenormalization's People

Contributors

brucedai003, titu1994


batchrenormalization's Issues

Performance looks similar

Thanks for your implementation of batch renormalization. I noticed that the reported performance of batch renorm and batch norm is similar. Did you check the performance with a simple network, as in the paper? The paper reports a large gap between batch renorm and batch norm on a simple network.
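
For context, the difference between the two methods lies in the correction factors r and d from the Batch Renormalization paper (Ioffe, 2017); with r = 1 and d = 0 the transform reduces to ordinary batch normalization. Below is a minimal NumPy sketch of the training-time transform for a single mini-batch (the function and variable names are illustrative, not this repository's API):

```python
import numpy as np

def batch_renorm_train_step(x, moving_mean, moving_std, gamma, beta,
                            r_max=3.0, d_max=5.0, eps=1e-5):
    # Per-feature batch statistics.
    batch_mean = x.mean(axis=0)
    batch_std = np.sqrt(x.var(axis=0) + eps)

    # Correction factors from the paper; in the Keras layer they are wrapped
    # in K.stop_gradient, so no gradient flows through them.
    r = np.clip(batch_std / moving_std, 1.0 / r_max, r_max)
    d = np.clip((batch_mean - moving_mean) / moving_std, -d_max, d_max)

    # With r == 1 and d == 0 this is exactly batch normalization.
    x_hat = (x - batch_mean) / batch_std * r + d
    return gamma * x_hat + beta
```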

Bug in batch_renorm

I ran this command inside a conda environment with keras 2.2 installed:
python cifar10_brn.py

Error message:
/root/devansh/FFL/BatchRenormalization/batch_renorm.py:97: UserWarning: This implementation of BatchRenormalization is inconsistent with the original paper and therefore results may not be similar ! For discussion on the inconsistency of this implementation, refer here : keras-team/keras-contrib#17
warnings.warn('This implementation of BatchRenormalization is inconsistent with the '
Traceback (most recent call last):
File "cifar10_brn.py", line 27, in
model = create_wide_residual_network(input_dim=init_shape, nb_classes=10, N=2, k=4)
File "/root/devansh/FFL/BatchRenormalization/wrn_renorm.py", line 118, in create_wide_residual_network
x = initial_conv(ip)
File "/root/devansh/FFL/BatchRenormalization/wrn_renorm.py", line 14, in initial_conv
x = BatchRenormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform')(x)
File "/root/anaconda3/envs/ffl_dev/lib/python3.6/site-packages/keras/engine/base_layer.py", line 457, in call
output = self.call(inputs, **kwargs)
File "/root/devansh/FFL/BatchRenormalization/batch_renorm.py", line 192, in call
r = K.stop_gradient(K.clip(r, 1 / self.r_max, self.r_max))
File "/root/anaconda3/envs/ffl_dev/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 1597, in clip
if max_value is not None and max_value < min_value:
File "/root/anaconda3/envs/ffl_dev/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 653, in bool
raise TypeError("Using a tf.Tensor as a Python bool is not allowed. "
TypeError: Using a tf.Tensor as a Python bool is not allowed. Use if t is not None: instead of if t: to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
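
The failure is inside K.clip: with Keras 2.2 and the TensorFlow backend it compares max_value < min_value in Python, which raises the TypeError above when self.r_max is a backend variable rather than a plain float. One possible workaround (a sketch, assuming the TensorFlow backend; the variable names below stand in for the layer's attributes) is to call tf.clip_by_value directly, since it accepts tensor bounds:

```python
import tensorflow as tf
from keras import backend as K

r_max = K.variable(3.0)                    # stands in for self.r_max
sigma_ratio = K.variable([1.2, 0.8, 2.5])  # stands in for batch_std / running_std

# tf.clip_by_value accepts tensor bounds, so no Python-level comparison of
# min/max is performed and the TypeError is avoided.
r = K.stop_gradient(tf.clip_by_value(sigma_ratio, 1.0 / r_max, r_max))
```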

Bug regarding cifar10_brn.py

I ran your cifar10_brn.py file and I got this error:

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 25s 0us/step

InvalidArgumentError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
1627 try:
-> 1628 c_op = c_api.TF_FinishOperation(op_desc)
1629 except errors.InvalidArgumentError as e:

InvalidArgumentError: Shape must be rank 1 but is rank 0 for 'batch_renormalization_4/Reshape_10' (op: 'Reshape') with input shapes: [1,1,1,16], [].

During handling of the above exception, another exception occurred:

ValueError Traceback (most recent call last)
in ()
51
52
---> 53 model = create_wide_residual_network(input_dim=init_shape, nb_classes=10, N=2, k=4)
54
55

/content/drive/vguNuke/Jupyter/BatchRenormalization/wrn_renorm.py in create_wide_residual_network(input_dim, nb_classes, N, k, dropout, verbose)
116 ip = Input(shape=input_dim)
117
--> 118 x = initial_conv(ip)
119 nb_conv = 4
120

/content/drive/vguNuke/Jupyter/BatchRenormalization/wrn_renorm.py in initial_conv(input)
12 channel_axis = 1 if K.image_data_format() == "channels_first" else -1
13
---> 14 x = BatchRenormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_init='uniform')(x)
15 x = Activation('relu')(x)
16 return x

/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py in __call__(self, inputs, **kwargs)
455 # Actually call the layer,
456 # collecting output(s), mask(s), and shape(s).
--> 457 output = self.call(inputs, **kwargs)
458 output_mask = self.compute_mask(inputs, previous_mask)
459

/content/drive/vguNuke/Jupyter/BatchRenormalization/batch_renorm.py in call(self, x, mask)
195 x, broadcast_running_mean, broadcast_running_std,
196 broadcast_beta, broadcast_gamma,
--> 197 epsilon=self.epsilon)
198
199 # pick the normalized form of x corresponding to the training phase

/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in batch_normalization(x, mean, var, beta, gamma, axis, epsilon)
1906 # so it may have extra axes with 1, it is not needed and should be removed
1907 if ndim(mean) > 1:
-> 1908 mean = tf.reshape(mean, (-1))
1909 if ndim(var) > 1:
1910 var = tf.reshape(var, (-1))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py in reshape(tensor, shape, name)
6480 if _ctx is None or not _ctx._eager_context.is_eager:
6481 _, _, _op = _op_def_lib._apply_op_helper(
-> 6482 "Reshape", tensor=tensor, shape=shape, name=name)
6483 _result = _op.outputs[:]
6484 _inputs_flat = _op.inputs

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
785 op = g.create_op(op_type_name, inputs, output_types, name=scope,
786 input_types=input_types, attrs=attr_protos,
--> 787 op_def=op_def)
788 return output_structure, op_def.is_stateful, op
789

/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
486 'in a future version' if date is None else ('after %s' % date),
487 instructions)
--> 488 return func(*args, **kwargs)
489 return tf_decorator.make_decorator(func, new_func, 'deprecated',
490 _add_deprecated_arg_notice_to_docstring(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in create_op(failed resolving arguments)
3272 input_types=input_types,
3273 original_op=self._default_original_op,
-> 3274 op_def=op_def)
3275 self._create_op_helper(ret, compute_device=compute_device)
3276 return ret

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
1790 op_def, inputs, node_def.attr)
1791 self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1792 control_input_ops)
1793
1794 # Initialize self._outputs.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
1629 except errors.InvalidArgumentError as e:
1630 # Convert to ValueError for backwards compatibility.
-> 1631 raise ValueError(str(e))
1632
1633 return c_op

ValueError: Shape must be rank 1 but is rank 0 for 'batch_renormalization_4/Reshape_10' (op: 'Reshape') with input shapes: [1,1,1,16], [].
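
The crash comes from the backend call shown mid-traceback: tf.reshape(mean, (-1)) passes the plain integer -1, which TensorFlow treats as a rank-0 shape, hence "Shape must be rank 1 but is rank 0". The sketch below (plain TensorFlow, with an illustrative shape matching the error message) shows the difference; patching the backend call to use a rank-1 shape, or using a Keras version without this reshape, avoids the error:

```python
import tensorflow as tf

mean = tf.zeros((1, 1, 1, 16))  # broadcast running mean, as in the error message

# (-1) is just the integer -1, so Reshape receives a rank-0 shape and fails
# exactly as above:
#   tf.reshape(mean, (-1))   # Shape must be rank 1 but is rank 0 ...

# A rank-1 shape, (-1,) or [-1], flattens the statistics as intended.
flat_mean = tf.reshape(mean, (-1,))
print(flat_mean.shape)  # (16,)
```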
