hsiaoyetgun / esim
TensorFlow implementation of the ESIM model (Enhanced LSTM for natural language inference)
Hi there,
Thanks for sharing the code. For the attention part in Model.py, your code is:
attentionSoft_b = tf.nn.softmax(tf.transpose(attentionWeights))
attentionSoft_b = tf.transpose(attentionSoft_b)
whereas I think it should be:
attentionSoft_b = tf.nn.softmax(attentionWeights, axis=1)
or else you should pass an explicit "perm" to the transpose function.
Please correct me if I'm wrong, thanks!
https://github.com/HsiaoYetGun/ESIM/blob/master/Model.py#L169
attentionSoft_b = tf.nn.softmax(tf.transpose(attentionWeights))
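Here, after attentionWeights is transposed, the resulting tensor has shape (seq_length, seq_length, batch_size).
Softmax is then applied to that result, and tf.nn.softmax operates on the last dimension by default,
so isn't the softmax being taken over the batch dimension? Corrections are welcome.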
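The concern raised in the two reports above can be checked on a small array. The sketch below uses NumPy as a stand-in for TensorFlow, on the assumption that only the axis handling matters here: like tf.transpose, np.transpose reverses all axes when no permutation is given. The shapes, the softmax helper, and the variable names are illustrative and are not taken from Model.py.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Hypothetical small shapes: (batch, len_a, len_b)
batch, len_a, len_b = 2, 3, 4
rng = np.random.default_rng(0)
weights = rng.normal(size=(batch, len_a, len_b))

# Transposing without a permutation reverses every axis,
# so the batch axis ends up last: (len_b, len_a, batch).
reversed_axes = np.transpose(weights)

# Softmax over the last axis therefore normalises across the batch.
buggy = np.transpose(softmax(reversed_axes))

# Proposed fix from the report: normalise over axis 1 directly.
fixed = softmax(weights, axis=1)

# Equivalent fix with an explicit permutation that keeps batch first.
fixed_perm = np.transpose(
    softmax(np.transpose(weights, (0, 2, 1))), (0, 2, 1)
)
```

With these definitions, buggy sums to 1 along axis 0 (the batch axis), while fixed and fixed_perm sum to 1 along axis 1, which is consistent with what both reports describe.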
ub16c9@ub16c9-gpu:/media/ub16c9/fcd84300-9270-4bbd-896a-5e04e79203b7/ub16_prj/ESIM-tf$ python3.6 Train.py
Using TensorFlow backend.
CMD : python3 Train.py --num_epochs 300 --batch_size 32 --dropout_keep_prob 0.5 --clip_value 10 --learning_rate 0.0004 --l2 0.0 --seq_length 100 --optimizer adam --early_stop_step 5000000 --threshold 0 --embedding_size 300 --embedding_normalize 1 --hidden_size 300 --attention_size 300 --eval_batch 1000 --vocab_path data/vocab.txt --embedding_path data/embeddings.pkl --trainset_path data/train.txt --devset_path data/dev.txt --testset_path data/test.txt --save_path ./model/checkpoint --best_path ./model/bestval --log_path ./config/log/log --config_path ./config/config.yaml
Training with following options :
------------- HYPER PARAMETERS -------------
attention_size: 300
batch_size: 32
best_path: ./model/bestval
clip_value: 10.0
config_path: ./config/config.yaml
devset_path: data/dev.txt
dropout_keep_prob: 0.5
early_stop_step: 5000000
embedding_normalize: 1
embedding_path: data/embeddings.pkl
embedding_size: 300
eval_batch: 1000
hidden_size: 300
l2: 0.0
learning_rate: 0.0004
log_path: config/log/log.2019_06_21_08_26_24
n_classes: 3
n_vocab: 47955
num_epochs: 300
optimizer: adam
save_path: ./model/checkpoint
seq_length: 100
testset_path: data/test.txt
threshold: 0
trainset_path: data/train.txt
vocab_dict_size: 47955
vocab_path: data/vocab.txt
embeded_left : (?, 100, 300)
embeded_right : (?, 100, 300)
a_bar : (?, 100, 600)
b_bar : (?, 100, 600)
att_wei : (?, 100, 100)
att_soft_a : (?, 100, 100)
att_soft_b : (?, 100, 100)
a_hat : (?, 100, 600)
b_hat : (?, 100, 600)
a_diff : (?, 100, 600)
a_mul : (?, 100, 600)
m_a : (?, 100, 2400)
m_b : (?, 100, 2400)
v_a : (?, 100, 600)
v_b : (?, 100, 600)
v_a_avg : (?, 600)
v_a_max : (?, 600)
v : (?, 2400)
Loading training and validation data ...
Traceback (most recent call last):
  File "Train.py", line 164, in <module>
    train()
  File "Train.py", line 49, in train
    premise_train, premise_mask_train, hypothesis_train, hypothesis_mask_train, y_train = sentence2Index(arg.trainset_path, vocab_dict)
  File "/media/ub16c9/fcd84300-9270-4bbd-896a-5e04e79203b7/ub16_prj/ESIM-tf/Utils.py", line 189, in sentence2Index
    labelList = enc._fit_transform(labelList)
AttributeError: 'OneHotEncoder' object has no attribute '_fit_transform'
ub16c9@ub16c9-gpu:/media/ub16c9/fcd84300-9270-4bbd-896a-5e04e79203b7/ub16_prj/ESIM-tf$
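The AttributeError above comes from calling a private scikit-learn method, which is probably not present in the installed version. A possible workaround is to call the public fit_transform method instead. The sketch below assumes the labels are integer class ids arranged as a single column; the sample labels are made up, and the exact shape of labelList in Utils.py is not shown in the report.

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

# Hypothetical integer labels for three classes.
# The encoder expects a 2-D array, so reshape to one column.
labels = np.array([0, 2, 1, 0]).reshape(-1, 1)

enc = OneHotEncoder()
# Public API: fit_transform returns a sparse matrix by default,
# so convert it to a dense array before feeding it to the model.
one_hot = enc.fit_transform(labels).toarray()
```

Calling .toarray() avoids relying on the constructor argument that controls sparse output, whose name differs between scikit-learn releases.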
Traceback (most recent call last):
  File "Train.py", line 165, in <module>
    train()
  File "Train.py", line 94, in train
    _, batch_loss, batch_acc = sess.run([model.train, model.loss, model.acc], feed_dict=feed_dict)
  File "/home/prf/anaconda3/envs/prfenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/home/prf/anaconda3/envs/prfenv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1128, in _run
    str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (32, 100) for Tensor 'premise_actual_length:0', which has shape '(?,)'
Have you implemented Tree-LSTM with TensorFlow?
Thank you for sharing. When I run this code, I get the following error:
ValueError: Cannot feed value of shape (32, 100) for Tensor 'premise_actual_length:0', which has shape '(?,)'
I know it's a simple problem and the shapes just don't match, but how do I solve it?
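One plausible reading of this error is that a (batch_size, seq_length) mask is being fed where the placeholder expects one length per example. The sketch below assumes the fed array is a 0/1 padding mask (1 for real tokens, 0 for padding), which the report does not confirm; under that assumption, summing along the sequence axis gives a vector of shape (batch_size,) that matches '(?,)'. All names and values are illustrative.

```python
import numpy as np

batch_size, seq_length = 32, 100

# Hypothetical true lengths, used only to build an example mask.
true_lengths = np.arange(1, batch_size + 1)

# Assumed 0/1 padding mask of shape (batch_size, seq_length).
mask = np.zeros((batch_size, seq_length), dtype=np.int32)
for row, n in enumerate(true_lengths):
    mask[row, :n] = 1

# Collapse the sequence axis: one length per example, shape (batch_size,).
actual_length = mask.sum(axis=1)
```

If the fed array turns out to be something other than a mask, the same shape check still applies: whatever is fed to premise_actual_length needs to be one-dimensional with one entry per example.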
The LSTM input should be shaped [seq_len, batch_size, emb_size], so the code needs a transpose.
I just want to know the accuracy of your ESIM model. Have you reached 88%?