batchrenormalization's Issues

Bug in batch_renorm

I ran this command inside a conda environment with Keras 2.2 installed:
python cifar10_brn.py

Error message:
/root/devansh/FFL/BatchRenormalization/batch_renorm.py:97: UserWarning: This implementation of BatchRenormalization is inconsistent with the original paper and therefore results may not be similar ! For discussion on the inconsistency of this implementation, refer here : keras-team/keras-contrib#17
warnings.warn('This implementation of BatchRenormalization is inconsistent with the '
Traceback (most recent call last):
  File "cifar10_brn.py", line 27, in <module>
    model = create_wide_residual_network(input_dim=init_shape, nb_classes=10, N=2, k=4)
  File "/root/devansh/FFL/BatchRenormalization/wrn_renorm.py", line 118, in create_wide_residual_network
    x = initial_conv(ip)
  File "/root/devansh/FFL/BatchRenormalization/wrn_renorm.py", line 14, in initial_conv
    x = BatchRenormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform')(x)
  File "/root/anaconda3/envs/ffl_dev/lib/python3.6/site-packages/keras/engine/base_layer.py", line 457, in __call__
    output = self.call(inputs, **kwargs)
  File "/root/devansh/FFL/BatchRenormalization/batch_renorm.py", line 192, in call
    r = K.stop_gradient(K.clip(r, 1 / self.r_max, self.r_max))
  File "/root/anaconda3/envs/ffl_dev/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 1597, in clip
    if max_value is not None and max_value < min_value:
  File "/root/anaconda3/envs/ffl_dev/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 653, in __bool__
    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
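
Reading the last frames of the traceback: `K.clip` compares its bounds with a Python `if max_value < min_value`, and at least one of the bounds passed from `batch_renorm.py` (`1 / self.r_max` and `self.r_max`) is a tensor, so the comparison yields a tensor that cannot be turned into a Python bool. The sketch below reproduces that failure mode in pure Python. `SymbolicTensor` and `clip_with_python_check` are hypothetical stand-ins written for illustration; they are not TensorFlow or Keras code.

```python
class SymbolicTensor:
    """Hypothetical stand-in for a graph-mode tf.Tensor (not a real TensorFlow class)."""

    def __init__(self, name):
        self.name = name

    def _compare(self, other):
        # A comparison builds a new symbolic node instead of returning True/False.
        return SymbolicTensor("Compare(%s, %s)" % (self.name, getattr(other, "name", other)))

    __lt__ = _compare
    __gt__ = _compare

    def __rtruediv__(self, other):
        # Supports expressions such as `1 / r_max`.
        return SymbolicTensor("Div(%s, %s)" % (other, self.name))

    def __bool__(self):
        # Same refusal as the last frame of the traceback.
        raise TypeError("Using a tf.Tensor as a Python bool is not allowed.")


def clip_with_python_check(x, min_value, max_value):
    """Mirrors the bounds check visible in the traceback (illustrative only)."""
    if max_value is not None and max_value < min_value:
        max_value = min_value
    return ("clip", x, min_value, max_value)


# Plain floats pass the Python-level check.
print(clip_with_python_check(2.0, 0.5, 3.0))

# Symbolic bounds fail exactly where the traceback fails.
r_max = SymbolicTensor("r_max")
try:
    clip_with_python_check(SymbolicTensor("r"), 1 / r_max, r_max)
except TypeError as err:
    print("TypeError:", err)
```

One possible workaround, which I have not tested against this Keras version, is to avoid the Python-level comparison by calling `tf.clip_by_value(r, 1 / self.r_max, self.r_max)` directly, since that op accepts tensor bounds.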

Bug when running cifar10_brn.py

I ran your cifar10_brn.py file and I got this error:

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 25s 0us/step

InvalidArgumentError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
1627 try:
-> 1628 c_op = c_api.TF_FinishOperation(op_desc)
1629 except errors.InvalidArgumentError as e:

InvalidArgumentError: Shape must be rank 1 but is rank 0 for 'batch_renormalization_4/Reshape_10' (op: 'Reshape') with input shapes: [1,1,1,16], [].

During handling of the above exception, another exception occurred:

ValueError Traceback (most recent call last)
in ()
51
52
---> 53 model = create_wide_residual_network(input_dim=init_shape, nb_classes=10, N=2, k=4)
54
55

/content/drive/vguNuke/Jupyter/BatchRenormalization/wrn_renorm.py in create_wide_residual_network(input_dim, nb_classes, N, k, dropout, verbose)
116 ip = Input(shape=input_dim)
117
--> 118 x = initial_conv(ip)
119 nb_conv = 4
120

/content/drive/vguNuke/Jupyter/BatchRenormalization/wrn_renorm.py in initial_conv(input)
12 channel_axis = 1 if K.image_data_format() == "channels_first" else -1
13
---> 14 x = BatchRenormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_init='uniform')(x)
15 x = Activation('relu')(x)
16 return x

/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py in __call__(self, inputs, **kwargs)
455 # Actually call the layer,
456 # collecting output(s), mask(s), and shape(s).
--> 457 output = self.call(inputs, **kwargs)
458 output_mask = self.compute_mask(inputs, previous_mask)
459

/content/drive/vguNuke/Jupyter/BatchRenormalization/batch_renorm.py in call(self, x, mask)
195 x, broadcast_running_mean, broadcast_running_std,
196 broadcast_beta, broadcast_gamma,
--> 197 epsilon=self.epsilon)
198
199 # pick the normalized form of x corresponding to the training phase

/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in batch_normalization(x, mean, var, beta, gamma, axis, epsilon)
1906 # so it may have extra axes with 1, it is not needed and should be removed
1907 if ndim(mean) > 1:
-> 1908 mean = tf.reshape(mean, (-1))
1909 if ndim(var) > 1:
1910 var = tf.reshape(var, (-1))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py in reshape(tensor, shape, name)
6480 if _ctx is None or not _ctx._eager_context.is_eager:
6481 _, _, _op = _op_def_lib._apply_op_helper(
-> 6482 "Reshape", tensor=tensor, shape=shape, name=name)
6483 _result = _op.outputs[:]
6484 _inputs_flat = _op.inputs

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
785 op = g.create_op(op_type_name, inputs, output_types, name=scope,
786 input_types=input_types, attrs=attr_protos,
--> 787 op_def=op_def)
788 return output_structure, op_def.is_stateful, op
789

/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
486 'in a future version' if date is None else ('after %s' % date),
487 instructions)
--> 488 return func(*args, **kwargs)
489 return tf_decorator.make_decorator(func, new_func, 'deprecated',
490 _add_deprecated_arg_notice_to_docstring(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in create_op(failed resolving arguments)
3272 input_types=input_types,
3273 original_op=self._default_original_op,
-> 3274 op_def=op_def)
3275 self._create_op_helper(ret, compute_device=compute_device)
3276 return ret

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
1790 op_def, inputs, node_def.attr)
1791 self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1792 control_input_ops)
1793
1794 # Initialize self._outputs.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
1629 except errors.InvalidArgumentError as e:
1630 # Convert to ValueError for backwards compatibility.
-> 1631 raise ValueError(str(e))
1632
1633 return c_op

ValueError: Shape must be rank 1 but is rank 0 for 'batch_renormalization_4/Reshape_10' (op: 'Reshape') with input shapes: [1,1,1,16], [].
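
The message says the `shape` input to `Reshape` has rank 0 (the `[]` in the input shapes), which matches the `tf.reshape(mean, (-1))` line in the Keras backend frame of the traceback: `(-1)` is a plain int, not a one-element tuple. The pure-Python sketch below illustrates the difference. `shape_arg_rank` and `flatten` are hypothetical helpers written for this illustration, not Keras or TensorFlow functions.

```python
def shape_arg_rank(shape):
    """Rank of a `shape` argument as a reshape op would see it (hypothetical helper).

    A bare int becomes a rank-0 (scalar) value; a list or tuple becomes rank 1.
    """
    return 1 if isinstance(shape, (list, tuple)) else 0


def flatten(nested):
    """Flatten arbitrarily nested lists: a pure-Python analogue of reshaping to [-1]."""
    if not isinstance(nested, list):
        return [nested]
    flat = []
    for item in nested:
        flat.extend(flatten(item))
    return flat


# Parentheses alone do not make a tuple: (-1) is just the int -1.
print(type((-1)).__name__, shape_arg_rank((-1)))    # int 0
print(type((-1,)).__name__, shape_arg_rank((-1,)))  # tuple 1
print(shape_arg_rank([-1]))                         # 1

# Broadcast statistics with shape [1, 1, 1, 16], as in the error message.
broadcast_mean = [[[[0.0] * 16]]]
print(len(flatten(broadcast_mean)))                 # 16
```

Possible workarounds, neither tested here: patch those backend lines to pass `[-1]` as the shape, or pass per-channel (non-broadcast) statistics from `batch_renorm.py` so that the `ndim(mean) > 1` branch is skipped.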

Performance looks similar

Thanks for your implementation of batch renormalization. I saw that the performance of batch renorm and batch norm is similar in your results. Did you check the performance with a simple network, as in the paper? The paper shows a simple network with a large gain for batch renorm over batch norm.
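
For context on this question, here is a minimal pure-Python sketch of the batch renorm forward pass for a single feature, following the correction terms r and d as the paper describes them. It is not the code in `batch_renorm.py`: the function names are hypothetical, gradients (and the `stop_gradient` on r and d) are not modeled, and gamma and beta are omitted. With `r_max = 1` and `d_max = 0` the transform reduces to plain batch norm, so similar results are expected whenever r and d stay close to those values.

```python
import math


def clip(value, low, high):
    """Clamp `value` to the closed interval [low, high]."""
    return max(low, min(high, value))


def batch_renorm_1d(batch, running_mean, running_std, r_max, d_max, eps=1e-5):
    """Batch renorm forward pass for one feature (illustrative sketch, no gamma/beta)."""
    n = len(batch)
    mu_b = sum(batch) / n
    var_b = sum((v - mu_b) ** 2 for v in batch) / n
    sigma_b = math.sqrt(var_b + eps)
    # Correction terms; treated as constants during backpropagation in the paper.
    r = clip(sigma_b / running_std, 1.0 / r_max, r_max)
    d = clip((mu_b - running_mean) / running_std, -d_max, d_max)
    return [((v - mu_b) / sigma_b) * r + d for v in batch]


def batch_norm_1d(batch, eps=1e-5):
    """Plain batch norm, expressed as the r_max = 1, d_max = 0 special case."""
    return batch_renorm_1d(batch, 0.0, 1.0, r_max=1.0, d_max=0.0, eps=eps)


batch = [1.0, 2.0, 3.0, 4.0]
print(batch_norm_1d(batch))
# When r and d are not clipped, the output equals normalization by the running statistics.
print(batch_renorm_1d(batch, running_mean=2.0, running_std=1.0, r_max=3.0, d_max=5.0))
```

The two transforms only differ when the minibatch statistics drift away from the running statistics, which is why the choice of network, batch size, and data sampling may matter for the comparison.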
