Comments (10)

tatsuhiko-inoue commented on August 28, 2024

I also experienced a similar error.
I avoided the error using the following modification.

diff --git a/model.py b/model.py
index b918ab0..68cb3fe 100644
--- a/model.py
+++ b/model.py
@@ -373,7 +373,7 @@ def revnet2d_step(name, z, logdet, hps, reverse):
                 h = f("f1", z1, hps.width, n_z)
                 shift = h[:, :, :, 0::2]
                 # scale = tf.exp(h[:, :, :, 1::2])
-                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
+                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.) + 1e-10
                 z2 += shift
                 z2 *= scale
                 logdet += tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
@@ -393,7 +393,7 @@ def revnet2d_step(name, z, logdet, hps, reverse):
                 h = f("f1", z1, hps.width, n_z)
                 shift = h[:, :, :, 0::2]
                 # scale = tf.exp(h[:, :, :, 1::2])
-                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
+                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.) + 1e-10
                 z2 /= scale
                 z2 -= shift
                 logdet -= tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
diff --git a/tfops.py b/tfops.py
index d978419..2e7c556 100644
--- a/tfops.py
+++ b/tfops.py
@@ -449,9 +449,9 @@ def gaussian_diag(mean, logsd):
     o.sample = mean + tf.exp(logsd) * o.eps
     o.sample2 = lambda eps: mean + tf.exp(logsd) * eps
     o.logps = lambda x: -0.5 * \
-        (np.log(2 * np.pi) + 2. * logsd + (x - mean) ** 2 / tf.exp(2. * logsd))
+        (np.log(2 * np.pi) + 2. * logsd + (x - mean) ** 2 / (tf.exp(2. * logsd) + 1e-10))
     o.logp = lambda x: flatten_sum(o.logps(x))
-    o.get_eps = lambda x: (x - mean) / tf.exp(logsd)
+    o.get_eps = lambda x: (x - mean) / (tf.exp(logsd) + 1e-10)
     return o

paulchou0309 commented on August 28, 2024

I'm hitting the same issue. How did you solve it? @tatsuhiko-inoue @nuges01 @arunpatro

nuges01 commented on August 28, 2024

I'm having this issue as well. I've been able to train on a custom dataset without conditioning on class labels. However, if I set --ycond and --weight_y 0.01, the gradients start out decreasing but eventually explode to inf, causing the error.

arunpatro commented on August 28, 2024

I first ran the experiment on an AWS p2.xlarge EC2 instance and got these errors. I later re-ran it on an NVIDIA Titan X, where it ran smoothly: it converged on 64x64 images but not on 256x256, with the exact same hyperparameters.

I doubt this can be solved. The model is prone to bad random initialisations that lead it to inf and NaN errors in the search space.

omidsakhi commented on August 28, 2024

How about adding "+ tf.eye(shape[3]) * 10e-4" to this line:

https://github.com/openai/glow/blob/master/model.py#L451

Does that make any difference?
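For context, a rough sketch of what that suggestion amounts to, assuming the referenced line deals with the 1x1-conv kernel and its log-determinant as in invertible_1x1_conv (the channel count below is a hypothetical stand-in for shape[3]); this is illustrative only, not a verified fix:

import numpy as np
import tensorflow as tf

# Hypothetical stand-ins: `channels` plays the role of shape[3] and `w` is the
# [channels, channels] kernel of the invertible 1x1 convolution.
channels = 48
w = tf.get_variable("W", dtype=tf.float32,
                    initializer=np.eye(channels, dtype='float32'))

# The suggestion: add a small multiple of the identity so the kernel stays
# away from singular matrices before its determinant (and inverse) is taken.
w_reg = w + tf.eye(channels) * 10e-4
dlogdet = tf.cast(tf.log(abs(tf.matrix_determinant(tf.cast(w_reg, 'float64')))),
                  'float32')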

nuges01 commented on August 28, 2024

@tatsuhiko-inoue, thanks for the suggestion. It didn't work for me, though. Those modifications sit under the elif hps.flow_coupling == 1 branch (affine coupling), while I'm following the set of parameters for the Conditional qualitative results, for which flow_coupling is 0 (additive). For example:
python train.py --problem cifar10 --image_size 32 --n_level 3 --depth 32 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5 --ycond --weight_y=0.01
I just re-ran it after changing --flow_coupling to 1, but it still hits the same problem.

nuges01 commented on August 28, 2024

@omidsakhi, that didn't solve it for me either, I'm afraid. Thanks.

tatsuhiko-inoue commented on August 28, 2024

When I run glow, the gradient of "logsd" in gaussian_diag() can become NaN: once "logsd" reaches 45.0 or more, the gradient is NaN.

I was able to avoid the NaN gradient by computing the gradient of "x / exp(y)" as a single operation, as follows, but the loss then became unstable.

@tf.custom_gradient
def div_by_exp(x, y):
    exp_y = tf.exp(y) + 1e-10
    ret = x / exp_y
    def _grad(dy):
        # Gradients of x / exp(y): d/dx = 1 / exp(y), d/dy = -x / exp(y) = -ret
        return dy / exp_y, dy * -ret
    return ret, _grad

def gaussian_diag(mean, logsd):
        :
    o.logps = lambda x: -0.5 * (np.log(2 * np.pi) + 2. * logsd + div_by_exp((x - mean) ** 2, 2 * logsd))
        :
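For reference, a minimal hypothetical TF 1.x check of the behaviour described above; the value 90.0 corresponds to 2 * logsd with logsd = 45, where the report above sees NaN:

import tensorflow as tf

x = tf.constant(1.0)
y = tf.constant(90.0)  # 2 * logsd for logsd = 45, where exp(y) overflows float32

naive = x / tf.exp(y)      # differentiating through tf.exp here can hit 0 * inf
custom = div_by_exp(x, y)  # the custom gradient never differentiates tf.exp

grad_naive = tf.gradients(naive, y)[0]
grad_custom = tf.gradients(custom, y)[0]

with tf.Session() as sess:
    # Expected, per the report above: the first gradient is nan, the second finite.
    print(sess.run([grad_naive, grad_custom]))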

naturomics commented on August 28, 2024

Hello guys, I found a solution for this 'not invertible' problem. During training, the weights of the invertible 1x1 conv keep growing to balance the log-determinant terms produced by the invertible 1x1 conv against those of the affine coupling layer/actnorm. This can be solved by adding a regularization term on the weights of the invertible 1x1 conv only; in practice I use l2 regularization. It's worth mentioning that, after adding the regularization term, slightly more epochs are needed to converge to the same NLL.

I discussed it in our recent publication "Generative Model with Dynamic Linear Flow", which improves the performance of flow-based methods significantly and converges faster than Glow. Our code is here.
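A minimal sketch of that idea in the TF 1.x style of this repo; the scope-name filter and the coefficient are illustrative assumptions, not values from the publication:

import tensorflow as tf

def invconv_l2_penalty(coeff=1e-5):
    # Collect only the invertible 1x1 conv kernels. The "invconv"/"W:0" name
    # filter is a guess; match whatever variable scope your kernels live under.
    kernels = [v for v in tf.trainable_variables()
               if "invconv" in v.name and v.name.endswith("W:0")]
    return coeff * tf.add_n([tf.reduce_sum(tf.square(v)) for v in kernels])

# Hypothetical usage, after the model's NLL / bits-per-dim loss is built:
# total_loss = nll + invconv_l2_penalty(1e-5)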

AnastasisKratsios commented on August 28, 2024
  • tf.eye(shape[3]) * 10e-4

I think you mean "+ tf.eye(3) * 10e-4"

shape[3] is not defined.
