
Comments (12)

pmorerio commented on September 25, 2024

self.d_loss_real = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits_real, labels=tf.one_hot(self.labels, self.no_classes + 1)))
This is the loss on 'real' inputs, i.e. depth inputs, for the discriminator; in fact, self.logits_real = self.D(self.depth_features, reuse=False). This means that the aim of the discriminator is to classify depth features with the correct class they belong to. The reuse=False is redundant here. What is important is the reuse=True on the line below, which means that the same discriminator with the same weights must be used (not a new one, about which TensorFlow would complain).

labels=tf.one_hot(self.labels, self.no_classes + 1)
This simply means that you have to transform the label index, which is an integer, into its one-hot encoding. The +1 is not a trick: an input can come from any of the 19 classes of NYUD or be a hallucinated input, so the discriminator faces a classification problem with 19+1 classes.
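For concreteness, a minimal TF1-style sketch of this loss with shapes made explicit (no_classes, the placeholder shapes, and the hallucinated-branch name are assumptions, not the repo's exact code):

```python
import tensorflow as tf

# Minimal sketch (TF1-style graph) of the 'real' branch of the discriminator loss.
# no_classes and the placeholder shapes are assumptions taken from the discussion.
no_classes = 19                             # NYUD classes; index 19 is the extra "fake" class
labels = tf.placeholder(tf.int64, [None])   # integer labels, shape [batch_size]

# logits_real = D(depth_features, reuse=False)        # first call: D's variables are created
# logits_fake = D(hallucinated_features, reuse=True)  # same D, same weights, reused
logits_real = tf.placeholder(tf.float32, [None, no_classes + 1])

# one_hot turns each integer label into a (no_classes + 1)-dim vector,
# so the discriminator solves a (19 + 1)-class classification problem.
d_loss_real = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(
        logits=logits_real,
        labels=tf.one_hot(labels, no_classes + 1)))
```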


pmorerio commented on September 25, 2024

Hi,
it looks like everything is correct except the very last activation function.
It must be sigmoid for train_hallucination and None (no activation) for train_hallucination_p2.
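Purely as an illustration, a hypothetical sketch of a final layer whose only difference between the two phases is the activation (hallucination_head and out_dim are made-up names, not from the repo):

```python
import tensorflow as tf

# Hypothetical sketch: only the final activation differs between the two phases.
def hallucination_head(net, out_dim, phase):
    # phase 1 (train_hallucination): sigmoid on the last layer
    # phase 2 (train_hallucination_p2): no activation, raw outputs
    activation = tf.nn.sigmoid if phase == 1 else None
    return tf.layers.dense(net, out_dim, activation=activation)
```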


pmorerio commented on September 25, 2024

Basically I need a tensor with the same shape as self.labels. This is a workaround which exploits the element-wise (broadcasting) operations that TensorFlow performs: self.labels has shape [batch_size], and if you add the scalar self.no_classes, the operation is done for every element of self.labels. You then have to subtract self.labels in order to get a tensor of shape [batch_size] which contains the value self.no_classes in each entry. In principle, if you knew the batch size you could simply create a vector of that length, but the MultiModal class only instantiates the model and is agnostic of the batch size.
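A minimal sketch of the workaround, assuming a TF1-style placeholder with unknown batch size:

```python
import tensorflow as tf

# Names follow the discussion; shapes are assumptions.
labels = tf.placeholder(tf.int64, [None])   # shape [batch_size], unknown at graph-build time
no_classes = 19

# Broadcasting: adding the scalar no_classes acts element-wise; subtracting labels
# leaves a [batch_size] tensor filled with the value no_classes (the "fake" class index).
fake_labels = labels + no_classes - labels

# An equivalent, arguably clearer, alternative:
# fake_labels = no_classes * tf.ones_like(labels)
```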


Scienceseb commented on September 25, 2024

Ok thanks!


Scienceseb commented on September 25, 2024

When you do self.d_loss_real = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits_real, labels=tf.one_hot(self.labels, self.no_classes + 1))), I just want to be sure of what is done here. logits=self.logits_real, so logits contains the output of self.D(self.depth_features, reuse=False) (why reuse=False here?). After that, labels=tf.one_hot(self.labels, self.no_classes + 1) is trickier and I think I got lost here... you get a (19+1)-dimensional one-hot vector with the one at the index given by labels, I think. The trick with the +1, I am not sure I understand it perfectly.


Scienceseb commented on September 25, 2024

Also, I don't understand how you can still have a 1024 fc for layer 3 (I have 2048) while you give it the concatenation of the output of layer 2 and the output of layer 1: are both 1024, or are they 512? In PyTorch, to do a skip connection we do the following: net = F.relu(torch.cat([res, net], dim=1)) # skip connection 1. Thank you


pmorerio commented on September 25, 2024

while you give it the concatenation of the output of layer 2 and the output of layer 1

I am giving the sum net = tf.nn.relu(res + net).

In PyTorch to do a skip connection we do the following: net = F.relu(torch.cat([res,net], dim=1))

This is not the way you should implement a skip connection. The residual must be summed, not concatenated.
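For comparison, a small PyTorch sketch of the two variants (the 1024 width and the batch size are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only; the 1024-dim width is an assumption from the discussion.
res = torch.randn(8, 1024)   # output of the earlier layer (the shortcut)
net = torch.randn(8, 1024)   # output of the current layer

# Residual / skip connection as in the TensorFlow code: element-wise sum,
# so the feature dimension stays 1024 and the next fc layer can stay 1024-wide.
summed = F.relu(res + net)                            # shape [8, 1024]

# Concatenation doubles the feature dimension to 2048, which is why the next
# layer would need 2048 inputs; this is not a residual connection.
concatenated = F.relu(torch.cat([res, net], dim=1))   # shape [8, 2048]
```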


Scienceseb commented on September 25, 2024

This is not the way you should implement a skip connection. The residual must be summed, not concatenated.

You are totally right, my mistake.


Scienceseb commented on September 25, 2024

Just to check, here is my PyTorch discriminator D for NYU; does it respect your implementation? I'm really not sure about the last ReLU at the end... but I think your implementation does that.


Scienceseb commented on September 25, 2024

Perfect, that's what I thought.


Scienceseb commented on September 25, 2024

self.d_loss_fake = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits_fake, labels=tf.one_hot(fake_labels, self.no_classes + 1))) vs. self.g_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits_fake, labels=tf.one_hot(self.labels, self.no_classes + 1))). First, for the discriminator, you want logits_fake to match fake_labels (so 19 everywhere), but then you want logits_fake to match labels (so no 19). After how many epochs did your generator win against the discriminator? Do you sometimes get failures?
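For reference, a TF1-style sketch of the two objectives side by side (shapes and the fake_labels construction are assumptions consistent with the discussion above):

```python
import tensorflow as tf

# Sketch of the adversarial objectives discussed above (TF1-style; shapes assumed).
no_classes = 19
labels = tf.placeholder(tf.int64, [None])                          # true class indices
logits_fake = tf.placeholder(tf.float32, [None, no_classes + 1])   # D(hallucinated features)

# All-"fake" targets: a [batch_size] tensor filled with the extra class index 19.
fake_labels = no_classes * tf.ones_like(labels)

# Discriminator: push hallucinated features towards the fake class (19).
d_loss_fake = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=logits_fake, labels=tf.one_hot(fake_labels, no_classes + 1)))

# Generator (hallucination network): push the same logits towards the true class,
# i.e. fool the discriminator into treating hallucinated features as real ones.
g_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=logits_fake, labels=tf.one_hot(labels, no_classes + 1)))
```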


Scienceseb commented on September 25, 2024

Right now I don't get your results at all... and I don't know why. My loss criterion is either NLLLoss or CrossEntropyLoss... I get negative loss values with NLLLoss but reasonable values with CrossEntropyLoss; why does it fail like that with NLLLoss?
[screenshot of loss values]

Here is my code, identical to your TensorFlow code:
[screenshot of code]

It seems that my loss_criterion(logits_fake, labels) explodes... it keeps getting more and more negative at each batch... same thing with loss_d += d_loss_real + d_loss_fake. D_model calls the discriminator D shown above.
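One likely cause, assuming raw logits are fed straight to NLLLoss: in PyTorch, nn.NLLLoss expects log-probabilities (the output of log_softmax), whereas nn.CrossEntropyLoss applies log_softmax internally and takes raw logits. A minimal sketch of the equivalence:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative shapes; 20 = 19 NYUD classes + 1 "fake" class.
logits = torch.randn(8, 20)            # raw discriminator outputs
labels = torch.randint(0, 20, (8,))    # integer class indices

# CrossEntropyLoss takes raw logits (it applies log_softmax internally).
ce = nn.CrossEntropyLoss()(logits, labels)

# NLLLoss expects log-probabilities; feeding raw logits can yield negative,
# diverging loss values. The correct pairing is log_softmax + NLLLoss:
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), labels)

assert torch.allclose(ce, nll)   # the two formulations are equivalent
```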

