
som-dst's People

Contributors

dsksd

som-dst's Issues

The joint goal accuracy (JGA) is only 0.27

First, thanks for your great work!
I used your original code and trained with these hyperparameters:
attention_probs_dropout_prob=0.1, batch_size=16, bert_ckpt_path='assets/bert-base-uncased-pytorch_model.bin', bert_config_path='assets/bert_config_base_uncased.json', data_root='data/mwz2.0', dec_lr=0.0001, dec_warmup=0.1, decoder_teacher_forcing=0.5, dev_data='dev_dials.json', dev_data_path='data/mwz2.0/dev_dials.json', dropout=0.1, enc_lr=4e-05, enc_warmup=0.1, eval_epoch=1, exclude_domain=False, hidden_dropout_prob=0.1, max_seq_length=256, msg=None, n_epochs=30, n_history=1, not_shuffle_state=False, num_workers=4, ontology_data='data/mwz2.0/ontology.json', op_code='4', random_seed=42, save_dir='outputs', shuffle_p=0.5, shuffle_state=True, slot_token='[SLOT]', test_data='test_dials.json', test_data_path='data/mwz2.0/test_dials.json', train_data='train_dials.json', train_data_path='data/mwz2.0/train_dials.json', vocab_path='assets/vocab.txt', word_dropout=0.1
Actually, I only changed batch_size; everything else is the same as yours.
However, the results were off the mark:
op_code: 4, is_gt_op: False, is_gt_p_state: False, is_gt_gen: False
Epoch 20 joint accuracy : 0.2733441910966341
Epoch 20 slot turn accuracy : 0.9547819399203203
Epoch 20 slot turn F1: 0.8490504095046196
Epoch 20 op accuracy : 0.9708785740137087
Epoch 20 op F1 : {'delete': 0.0886426592797784, 'update': 0.7394006048941435, 'dontcare': 0.03618649965205289, 'carryover': 0.984787856819641}
Epoch 20 op hit count : {'delete': 16, 'update': 6723, 'dontcare': 26, 'carryover': 207838}
Epoch 20 op all count : {'delete': 339, 'update': 11265, 'dontcare': 1401, 'carryover': 208035}
Final Joint Accuracy : 0.08108108108108109
Final slot turn F1 : 0.8280395869977584
Latency Per Prediction : 21.975781 ms

I don't know what's wrong.
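The hit and all counts in the Epoch 20 log can be turned into per-operation recall. This is a minimal sketch. It assumes that "op hit count" is the number of correctly predicted gold operations and "op all count" is the number of gold occurrences of each operation. That is an interpretation, not something the issue states. Under that reading, the pooled ratio reproduces the logged op accuracy of 0.9708785740137087.

```python
# Counts copied from the Epoch 20 log above.
hit = {"delete": 16, "update": 6723, "dontcare": 26, "carryover": 207838}
total = {"delete": 339, "update": 11265, "dontcare": 1401, "carryover": 208035}


def per_op_recall(hit, total):
    """Fraction of gold occurrences of each operation that were predicted correctly."""
    return {op: hit[op] / total[op] for op in total}


def overall_op_accuracy(hit, total):
    """Micro-averaged accuracy over all (turn, slot) operation decisions."""
    return sum(hit.values()) / sum(total.values())


if __name__ == "__main__":
    for op, recall in per_op_recall(hit, total).items():
        print(f"{op:>9}: {recall:.4f}")
    print(f"overall  : {overall_op_accuracy(hit, total):.6f}")
```

Under the same reading:

- delete and dontcare are recovered less than 5% of the time.
- The high op accuracy mostly reflects carryover, which accounts for 208035 of the 221040 decisions (about 94%).

A turn only counts toward joint accuracy when every slot is correct, so errors on these rare operations could be enough to pull the JGA down.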

Error while installing the required PyTorch version

ERROR: Could not find a version that satisfies the requirement torch==1.3.0a0+24ae9b5 (from versions: 0.1.2, 0.1.2.post1, 0.1.2.post2, 0.4.1, 0.4.1.post2, 1.0.0, 1.0.1, 1.0.1.post2, 1.1.0, 1.2.0, 1.3.0, 1.3.1, 1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0)
ERROR: No matching distribution found for torch==1.3.0a0+24ae9b5
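The pinned version `1.3.0a0+24ae9b5` is not a PyPI release. Under PEP 440:

- `a0` marks an alpha pre-release.
- `+24ae9b5` is a local version label, typically a source build tagged with a git commit.

Local versions are not published on PyPI, so pip cannot resolve this pin. The sketch below strips the suffixes to find the matching public release. It assumes that 1.3.0, which does appear in pip's list, is an acceptable substitute. That is untested: the pre-release may differ from the final 1.3.0, and the helper names are hypothetical.

```python
import re

# Versions pip reported as available (copied from the error message above).
AVAILABLE = ["0.1.2", "0.1.2.post1", "0.1.2.post2", "0.4.1", "0.4.1.post2",
             "1.0.0", "1.0.1", "1.0.1.post2", "1.1.0", "1.2.0", "1.3.0",
             "1.3.1", "1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1",
             "1.8.0", "1.8.1", "1.9.0"]


def release_part(version):
    """Return the leading release segment (e.g. '1.3.0'), dropping any
    pre-release ('a0'), post-release ('.post1') or local ('+...') suffix."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"not a PEP 440-style version: {version!r}")
    return match.group(0)


def public_release_for(version, available=AVAILABLE):
    """Hypothetical helper: the published release sharing the same release
    segment, or None if pip does not list it."""
    base = release_part(version)
    return base if base in available else None


if __name__ == "__main__":
    print(public_release_for("1.3.0a0+24ae9b5"))
```

If that substitution is acceptable, the pin would become `torch==1.3.0`.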

ValueError

Running train.py raises 'ValueError: max() arg is an empty sequence' in collate_fn (File "/home/fzuirdata/yaozhen/som-dst-master/utils/data_utils.py", line 345).
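Python's built-in `max()` raises this ValueError whenever it receives an empty iterable. This is a minimal sketch. It assumes, without confirmation from the issue, that line 345 takes a maximum over the values to be generated in the batch, so a batch in which no slot is updated produces an empty list. `collate_lengths` and the `gen_ids` layout are hypothetical. The `default` keyword of `max()` (Python 3.4+) avoids the exception.

```python
def flatten(nested):
    """Concatenate a list of lists into one list."""
    return [item for sublist in nested for item in sublist]


def collate_lengths(gen_ids):
    """gen_ids: per example, a list of token-id lists, one per updated slot
    (hypothetical layout). Returns (max updates per example, max value length)."""
    max_update = max((len(slots) for slots in gen_ids), default=0)
    # Without default=0, max() raises ValueError (the error in this issue)
    # whenever no example in the batch has any slot value to generate.
    max_value = max((len(v) for v in flatten(gen_ids)), default=0)
    return max_update, max_value


if __name__ == "__main__":
    print(collate_lengths([[[5, 6, 7]], []]))
    print(collate_lengths([[], []]))
```

Returning 0 only moves the problem downstream if later code allocates tensors sized by these values. In that case, skipping the decoder for such a batch may be the safer choice.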

The joint accuracy of the pretrained model is 0.43; I can't reach 0.53

Hi! Thanks for sharing the pretrained model. I downloaded your pretrained SOM-DST model, but the evaluation result is:

op_code: 4, is_gt_op: False, is_gt_p_state: False, is_gt_gen: False
Epoch 0 joint accuracy : 0.43078175895765475
Epoch 0 slot turn accuracy : 0.9664268910603907
Epoch 0 slot turn F1: 0.8821281530381658
Epoch 0 op accuracy : 0.9676212450234745
Epoch 0 op F1 : {'delete': 0.01583635763774332, 'update': 0.7806280915679281, 'dontcare': 0.13675213675213674, 'carryover': 0.9831123351416803}
Epoch 0 op hit count : {'delete': 24, 'update': 6786, 'dontcare': 32, 'carryover': 207041}
Epoch 0 op all count : {'delete': 2995, 'update': 9351, 'dontcare': 280, 'carryover': 208414}
Final Joint Accuracy : 0.2862862862862863
Final slot turn F1 : 0.8881324359127896
Latency Per Prediction : 26.158629 ms

As you know, the joint accuracy reported in the paper is 0.53. Is the model you shared the best model?
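The logged scores turn out to be exact fractions, which helps interpret them:

- 0.43078175895765475 equals 3174/7368.
- The op all counts sum to 221040, which is 7368 × 30.
- The Final Joint Accuracy of 0.2862862862862863 equals 286/999.

These denominators are inferred from the numbers, not stated in the issue. They are consistent with joint accuracy being computed over 7368 turns with 30 slots each. They also fit "Final Joint Accuracy" being measured only on the last turn of each of 999 dialogues, which would explain why it is lower. The check below verifies the arithmetic.

```python
# All three constants are inferred from the logged numbers, not stated in the issue.
N_TURNS = 7368
N_SLOTS = 30
N_DIALOGUES = 999


def matches(score, numerator, denominator, tol=1e-12):
    """True if a logged float score equals numerator/denominator up to rounding."""
    return abs(score - numerator / denominator) < tol


# Gold operation counts copied from the Epoch 0 log above.
op_all = {"delete": 2995, "update": 9351, "dontcare": 280, "carryover": 208414}

checks = {
    "turn-level JGA = 3174 / turns": matches(0.43078175895765475, 3174, N_TURNS),
    "op decisions = turns x slots": sum(op_all.values()) == N_TURNS * N_SLOTS,
    "final-turn JGA = 286 / dialogues": matches(0.2862862862862863, 286, N_DIALOGUES),
}

if __name__ == "__main__":
    for name, ok in checks.items():
        print(f"{name}: {ok}")
```

The gold operation counts here differ sharply from those in the 0.27 JGA issue (delete 2995 vs 339, dontcare 280 vs 1401), even though both total 221040. Gold operations depend only on the data, so this may indicate that the two evaluations used differently preprocessed test files. The 0.27 run used data/mwz2.0.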

Code Error

At line 115 of model.py, the decoder's GRU hidden state is not reset to the encoder's pooled output for each slot. Instead, it is carried over as the hidden state for subsequent slots, which differs from what the paper describes.
But it still works. Why?
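The difference the issue describes is a question of where each slot's initial decoder hidden state comes from. This is a toy sketch, not the SOM-DST decoder. It assumes a scalar GRU with made-up weights and no biases, and a fixed input per slot, whereas the real decoder's input changes at every step.

- With a per-slot reset, a slot's result depends only on the pooled output and its own inputs.
- With carry-over, it also depends on every slot decoded before it.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gru_cell(x, h, w):
    """Scalar GRU step using the standard gate equations (biases omitted)."""
    z = sigmoid(w["xz"] * x + w["hz"] * h)            # update gate
    r = sigmoid(w["xr"] * x + w["hr"] * h)            # reset gate
    n = math.tanh(w["xn"] * x + r * (w["hn"] * h))    # candidate state
    return (1.0 - z) * n + z * h


# Made-up weights for illustration only.
W = {"xz": 0.5, "hz": -0.3, "xr": 0.8, "hr": 0.2, "xn": 1.1, "hn": 0.7}


def decode_slots(slot_inputs, pooled, carry_hidden, steps=3):
    """Run a few GRU steps per slot and return each slot's final hidden state.

    carry_hidden=False: every slot starts from the encoder's pooled output.
    carry_hidden=True:  each slot starts from the previous slot's final state.
    """
    finals = []
    h = pooled
    for x in slot_inputs:
        if not carry_hidden:
            h = pooled  # reset to the pooled output for each slot
        for _ in range(steps):
            h = gru_cell(x, h, W)
        finals.append(h)
    return finals


if __name__ == "__main__":
    print(decode_slots([0.2, 0.9], pooled=0.5, carry_hidden=False))
    print(decode_slots([0.2, 0.9], pooled=0.5, carry_hidden=True))
```

The first slot is identical in both modes. Only later slots inherit state from earlier ones under carry-over. The sketch makes that dependency explicit but does not explain why the carried-over variant still performs well.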

More details about hyper-parameter settings?

Hi,
Thanks for sharing the code! Could you also share the hyper-parameter settings to reproduce 53.09% joint accuracy on MultiWOZ 2.1?

I cannot reproduce this result by running the unmodified code in parallel on two P100 GPUs. I get 52.16% at epoch 29 (the best epoch on the dev set).

I also found that a paper ("Efficient Context and Schema Fusion Networks for Multi-Domain Dialogue State Tracking") reports that your model achieves only 52.57%, and that is the result now shown on Papers with Code.

How do you handle multiple domains in the domain classification loss?

I am wondering how you handle multiple domains in the domain classification loss. In the paper, you describe using a one-hot vector to represent the label. So if the target spans multiple domains, e.g. restaurant + hotel, does the one-hot vector contain multiple ones, e.g. [1, 1, 0, 0, 0]?
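The question contrasts two ways of labelling a turn that touches several domains:

- A one-hot target with softmax cross-entropy treats the domains as mutually exclusive.
- A multi-hot target such as [1, 1, 0, 0, 0] is normally paired with a per-domain sigmoid and binary cross-entropy, so each domain becomes an independent yes/no decision.

This is a minimal sketch of both losses. It assumes five domains in a hypothetical order. Which loss the paper actually uses is exactly what the issue asks and is not confirmed here.

```python
import math

DOMAINS = ["restaurant", "hotel", "attraction", "taxi", "train"]  # hypothetical order


def softmax_cross_entropy(logits, one_hot):
    """Single-label loss: the target has exactly one 1."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))  # stable log-sum-exp
    return -sum(t * (l - log_z) for l, t in zip(logits, one_hot))


def binary_cross_entropy_with_logits(logits, multi_hot):
    """Multi-label loss: one independent sigmoid per domain, averaged over domains."""
    total = 0.0
    for l, t in zip(logits, multi_hot):
        # Numerically stable form of -[t*log(sigmoid(l)) + (1-t)*log(1-sigmoid(l))]
        total += max(l, 0.0) - l * t + math.log1p(math.exp(-abs(l)))
    return total / len(logits)


if __name__ == "__main__":
    logits = [2.0, 1.5, -1.0, -2.0, -0.5]
    print(softmax_cross_entropy(logits, [1, 0, 0, 0, 0]))             # restaurant only
    print(binary_cross_entropy_with_logits(logits, [1, 1, 0, 0, 0]))  # restaurant + hotel
```

Feeding an unnormalised multi-hot vector into the softmax loss would instead add the two log-probabilities together, which is a different objective from either of the above.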

The value of max_update may be zero

PyTorch version: 0.4.1
When evaluating the model, the eval batch size is 1. In the line
if max_update is None: max_update = op_ids.eq(self.update_id).sum(-1).max().item()
the value of max_update can therefore be zero whenever that single turn has no slot to update. Running the following line:
v = torch.zeros(1, max_update, self.hidden_size, device=input_ids.device)
may then give the error: RuntimeError: sizes must be non-negative
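With a single-turn evaluation batch, a turn with no "update" operations gives max_update = 0. This is a minimal sketch of an early-exit guard. It assumes the caller can treat "no updates" as "nothing to generate". `max_update` mirrors the quoted tensor expression in plain Python, and `plan_generation` is a hypothetical name.

```python
def max_update(op_ids, update_id):
    """Plain-Python analogue of op_ids.eq(update_id).sum(-1).max().item():
    the largest number of 'update' operations in any example of the batch."""
    return max((sum(1 for op in row if op == update_id) for row in op_ids), default=0)


def plan_generation(op_ids, update_id):
    """Hypothetical guard: return the number of values to generate, or None
    when the batch has no updates and the decoder call can be skipped."""
    n = max_update(op_ids, update_id)
    if n == 0:
        # Skip decoding rather than allocating a buffer with a zero-length dimension.
        return None
    return n


if __name__ == "__main__":
    UPDATE_ID = 1  # hypothetical id of the "update" operation
    print(plan_generation([[0, 0, 3]], UPDATE_ID))
    print(plan_generation([[1, 0, 1], [0, 1, 0]], UPDATE_ID))
```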
