Comments (24)
Hey @xintao-xiang, I think by "using Equation (6)" they mean the architecture remains the same. As you can see in the MLPDEncoder, they use the same architecture and just apply one-hot encoding to the input and softmax to the output. So I am trying to get the discrete version of the code to work as well. However, so far I have managed only to establish, based on the paper, that I need to use the MLPDEncoder.
I have no idea whether to use the original Decoder or the DiscreteDecoder. I am also not sure whether I have to change the loss functions. Let me know how you get on with your code. Maybe we can help each other out?
from dag-gnn.
Hi @ItsyPetkov ,
I can tell you how I tried to implement it, for your reference, though I have no proof that it is what we want.
I assume the encoder is trying to learn a latent space, so I still use the MLPEncoder for all the variables. The shape of the weight depends on the MAX dimension of the input variables (usually the one-hot representation of the discrete variable with the most values). For variables with smaller dimensions, I just insert 0s and do not use them in the calculation.
For the decoder, I modify the MLPDiscreteDecoder, and, similar to the encoder, all the output variables are assumed to have the same dimension. Different softmax layers are used (somewhat hard-coded) for the outputs of the different discrete variables, just like the image shows (the first one has 6 discrete values while the other two have 2 discrete values each).
For the loss function, I use nll_catogorical for the discrete variables and keep everything else the same as in the code provided by the author.
Hope that helps, and I'd really appreciate it if you have any ideas to share.
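For reference, a categorical negative log-likelihood over one-hot targets could look like the sketch below. `nll_categorical` here is a hypothetical stand-in, not the repo's `nll_catogorical` from utils.py, whose exact signature may differ.

```python
import torch
import torch.nn.functional as F

def nll_categorical(preds, target, eps=1e-8):
    """Negative log-likelihood for one-hot targets.

    preds:  (N, d, k) probabilities after softmax
    target: (N, d, k) one-hot encoded observations
    Hypothetical stand-in for the loss in the repo's utils.py.
    """
    return -(target * torch.log(preds + eps)).sum(dim=-1).mean()

# toy check: perfect predictions give (near) zero loss
target = F.one_hot(torch.tensor([[0], [2]]), num_classes=3).float()
loss = nll_categorical(target, target)
```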
Hi @xintao-xiang, if you are using the MLPEncoder, then how do you one-hot encode, or do you do something else?
@ItsyPetkov I just one-hot encode all the discrete variables and forward them to the encoder. Say we have X1 (2 values) and X2 (3 values); I one-hot encode them and insert a column of 0s into X1, so we end up with an Nx2x3 data matrix. Then I just take the input dimension as 3 and the hidden dimension as whatever we want.
And again, I don't know whether it is correct, but it looks reasonable, as the latent space is just some representation that does not need a softmax...
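The padding scheme described above (one-hot encode each variable, then zero-pad to the largest cardinality) can be sketched as follows; `one_hot_pad` is a hypothetical helper, not code from the repo, reproducing the Nx2x3 example with X1 (2 values) and X2 (3 values).

```python
import torch
import torch.nn.functional as F

def one_hot_pad(data, cardinalities):
    """One-hot encode each discrete column and zero-pad to the
    largest cardinality, giving an (N, d, k_max) tensor.

    data: (N, d) integer tensor; cardinalities: values per column.
    """
    k_max = max(cardinalities)
    cols = []
    for j, k in enumerate(cardinalities):
        oh = F.one_hot(data[:, j], num_classes=k).float()
        # pad the unused categories with zeros up to k_max
        cols.append(F.pad(oh, (0, k_max - k)))
    return torch.stack(cols, dim=1)

data = torch.tensor([[0, 2], [1, 1]])  # X1 in {0,1}, X2 in {0,1,2}
x = one_hot_pad(data, [2, 3])          # shape (2, 2, 3)
```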
@xintao-xiang Yeah, alright, that makes sense. I do the same thing, but with the MLPDEncoder on benchmark data that has a fixed cardinality of 4, meaning that every variable in the dataset has only 4 possible categories. So the output of my version of the encoder has shape XxYx4. However, what is the output shape of your decoder?
@ItsyPetkov The output shape of the decoder is just the same as the input of the encoder. So, following the example, the output shape is Nx2x3, but with a softmax over two dimensions for X1 and a softmax over three dimensions for X2. Then I just ignore the redundant dimensions and calculate the loss with the meaningful ones.
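The per-variable softmax described above (softmax only over each variable's meaningful dimensions, leaving the padded ones at zero) could be sketched like this; `per_variable_softmax` is an illustrative helper, not the repo's decoder code.

```python
import torch
import torch.nn.functional as F

def per_variable_softmax(logits, cardinalities):
    """Apply softmax only over each variable's meaningful
    dimensions; padded dimensions stay at zero so they can be
    ignored in the loss.

    logits: (N, d, k_max) decoder output.
    """
    out = torch.zeros_like(logits)
    for j, k in enumerate(cardinalities):
        out[:, j, :k] = F.softmax(logits[:, j, :k], dim=-1)
    return out

logits = torch.randn(4, 2, 3)
probs = per_variable_softmax(logits, [2, 3])  # X1: 2 values, X2: 3
```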
@xintao-xiang Alright, yeah, that makes sense. I did the same thing. The only difference I see so far is that your KL-divergence term is calculated using the same function the authors provided. However, they have also provided two such functions for categorical data. Maybe try using them? They are in utils.py. I haven't tried them yet, so I do not know what will happen, but it is worth a shot.
@ItsyPetkov Yeah, it is worth a try, but I did not see any mathematical reason to use those two. Do you have any idea?
@xintao-xiang Well, I have tried both of them and they do not improve the result at all. But I think my version of the model is wrong, because I use softmax in the encoder, so I cannot say whether it is good to use them or not.
Hey @xintao-xiang, have you checked the torch.matmul() line in the forward function of the DiscreteDecoder? Broadcasting happens there, which might be causing the result to be wrong.
Hi @ItsyPetkov, matmul should only broadcast matrix A, which should be correct.
@xintao-xiang Hmm, well, if that is the case, I literally have no idea where a potential mistake might be. What do you do in your forward function for the Decoder? I assume you go through the identity function, then the matmul matrix multiplication, and then the result goes through your subsequent layers. Is that assumption correct, or do you have more added in there?
@ItsyPetkov Yes, that's correct. Did any problem arise when using this setting on your side?
@xintao-xiang No, that is the problem. I cannot prove that what I am doing is right at this point. :(
@ItsyPetkov Well, I'm not sure if that's correct either. But I guess you could try creating a synthetic dataset with some really simple relationships and see if it works as expected. And please tell me if you do, because I'm also curious about the result, haha :)
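A minimal sketch of such a synthetic dataset, assuming two discrete variables with one obvious edge X1 -> X2 (X2 copies X1 most of the time); the sizes and copy probability are arbitrary choices for illustration:

```python
import numpy as np

# Toy synthetic dataset with one clear edge X1 -> X2:
# X2 copies X1 with probability 0.9, otherwise it is uniform noise.
# Useful as a sanity check before running on real benchmarks.
rng = np.random.default_rng(0)
n = 1000
x1 = rng.integers(0, 3, size=n)              # 3 categories
copy = rng.random(n) < 0.9                   # copy mask
x2 = np.where(copy, x1, rng.integers(0, 3, size=n))
data = np.stack([x1, x2], axis=1)            # (n, 2) integer matrix
```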
@xintao-xiang I managed to get my hands on one of the benchmark datasets, so I am testing with that, but the true positive rate is 43% and the false discovery rate is about 66%. We are on the right track, but it is not completely right at the moment.
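For anyone reproducing these numbers, the true positive rate and false discovery rate over directed edges can be computed from the estimated and ground-truth adjacency matrices; `edge_metrics` below is a hypothetical helper, not part of the repo:

```python
import numpy as np

def edge_metrics(A_true, A_est):
    """True positive rate and false discovery rate over directed
    edges (nonzero entries of the adjacency matrices)."""
    true_e = A_true != 0
    est_e = A_est != 0
    tp = np.sum(true_e & est_e)
    fp = np.sum(~true_e & est_e)
    fn = np.sum(true_e & ~est_e)
    tpr = tp / max(tp + fn, 1)  # recall over the true edges
    fdr = fp / max(tp + fp, 1)  # fraction of predicted edges that are wrong
    return tpr, fdr

A_true = np.array([[0, 1], [0, 0]])  # one true edge 0 -> 1
A_est = np.array([[0, 1], [1, 0]])   # finds it, plus a spurious edge
tpr, fdr = edge_metrics(A_true, A_est)
```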
@ItsyPetkov Have you tried tuning the parameters? Such models can sometimes be sensitive to hyperparameters.
@xintao-xiang Not really, that is a good idea though. I'll try and see what I find. Thank you!
@xintao-xiang What are you using for the one-hot encoding of the data prior to feeding it into the encoder? Are you using nn.Embedding?
@ItsyPetkov Sorry for the late reply. I use plain one-hot encoding, but I guess in theory nn.Embedding should also work.
@xintao-xiang I think there is a fundamental problem with the model: as written, it is an AE, not a VAE. You need to add a reparameterization step, and you need to fix the KLD term, as it is wrong.
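The missing pieces would be the standard VAE reparameterization and the Gaussian KL term; the sketch below shows both, assuming the encoder is changed to output `(mu, logvar)` for a diagonal Gaussian posterior (which the current code does not do):

```python
import torch

def reparameterize(mu, logvar):
    """Standard VAE reparameterization: z = mu + sigma * eps,
    with eps ~ N(0, I), keeping the sampling differentiable."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

def kld_gaussian(mu, logvar):
    """KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

mu = torch.zeros(4, 3)
logvar = torch.zeros(4, 3)
z = reparameterize(mu, logvar)
kl = kld_gaussian(mu, logvar)  # zero when q already matches the prior
```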
@ItsyPetkov Yes, it looks like an AE, not a VAE. But does fixing that give better results? In fact, I noticed this and modified the code, but it produced some strange results and could not even manage to reconstruct the input samples.
@xintao-xiang In theory it should. I haven't managed to do it yet, though; I have only managed to match the previous result, and I did that by tweaking hyperparameters.