TRADE Multi-Domain and Unseen-Domain Dialogue State Tracking

License: MIT

This is the PyTorch implementation of the paper: Transferable Multi-Domain State Generator for Task-Oriented Dialogue Systems. Chien-Sheng Wu, Andrea Madotto, Ehsan Hosseini-Asl, Caiming Xiong, Richard Socher and Pascale Fung. ACL 2019. [PDF]

This code was written using PyTorch >= 1.0. If you use any source code or datasets included in this toolkit in your work, please cite the following paper. The BibTeX entry is listed below:

@InProceedings{WuTradeDST2019,
  author    = "Wu, Chien-Sheng and Madotto, Andrea and Hosseini-Asl, Ehsan and Xiong, Caiming and Socher, Richard and Fung, Pascale",
  title     = "Transferable Multi-Domain State Generator for Task-Oriented Dialogue Systems",
  booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
  year      = "2019",
  publisher = "Association for Computational Linguistics"
}

Abstract

Over-dependence on domain ontology and lack of knowledge sharing across domains are two practical and yet less studied problems of dialogue state tracking. Existing approaches generally fall short in tracking unknown slot values during inference and often have difficulties in adapting to new domains. In this paper, we propose a Transferable Dialogue State Generator (TRADE) that generates dialogue states from utterances using a copy mechanism, facilitating knowledge transfer when predicting (domain, slot, value) triplets not encountered during training. Our model is composed of an utterance encoder, a slot gate, and a state generator, which are shared across domains. Empirical results demonstrate that TRADE achieves state-of-the-art joint goal accuracy of 48.62% for the five domains of MultiWOZ, a human-human dialogue dataset. In addition, we show its transferring ability by simulating zero-shot and few-shot dialogue state tracking for unseen domains. TRADE achieves 60.58% joint goal accuracy in one of the zero-shot domains, and is able to adapt to few-shot cases without forgetting already trained domains.

Model Architecture

The proposed TRADE model consists of (a) an utterance encoder, (b) a state generator, and (c) a slot gate, all of which are shared among domains. The state generator decodes J times independently, once for each possible (domain, slot) pair. At the first decoding step, the state generator takes the j-th (domain, slot) embedding as input to generate the corresponding slot value and slot gate. The slot gate predicts whether the j-th (domain, slot) pair is triggered by the dialogue.
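
The per-pair decoding loop described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the repository's code: `decode_value` and `gate_for` are hypothetical stand-ins for the state generator and the slot gate, and the three gate labels (`ptr`, `dontcare`, `none`) follow the paper's description rather than anything stated in this README.

```python
from typing import Callable, Dict, List, Tuple

Pair = Tuple[str, str]  # a (domain, slot) pair


def build_state(
    pairs: List[Pair],
    decode_value: Callable[[Pair], str],  # hypothetical stand-in for the state generator
    gate_for: Callable[[Pair], str],      # hypothetical stand-in for the slot gate
) -> Dict[Pair, str]:
    """Decode once per (domain, slot) pair and keep only the triggered pairs."""
    state: Dict[Pair, str] = {}
    for pair in pairs:  # J independent decoding passes
        gate = gate_for(pair)
        if gate == "ptr":
            # The pair is triggered: use the generated value.
            state[pair] = decode_value(pair)
        elif gate == "dontcare":
            state[pair] = "dontcare"
        # gate == "none": the pair is not triggered, so it is left out.
    return state


# Toy stand-ins (made-up values) so the sketch runs end to end.
toy_values = {("hotel", "area"): "centre", ("train", "day"): "friday"}
toy_gates = {("hotel", "area"): "ptr", ("hotel", "parking"): "dontcare"}

pairs = [("hotel", "area"), ("hotel", "parking"), ("train", "day")]
state = build_state(
    pairs,
    decode_value=lambda p: toy_values[p],
    gate_for=lambda p: toy_gates.get(p, "none"),
)
print(state)
```

Because every pair is decoded independently, supporting a new domain in this sketch only means adding its (domain, slot) pairs to `pairs`.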

Data

Download the MultiWOZ dataset and the processed DST version:

❱❱❱ python3 create_data.py

An example of multi-domain dialogue state tracking in a conversation. The solid arrows on the left show single-turn mappings, and the dotted arrows on the right show multi-turn mappings. The state tracker needs to track the slot values mentioned by the user for all slots in all domains.
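
To make the multi-turn mapping concrete, here is a minimal sketch of how a dialogue state could accumulate across turns. The function name `track` and the toy slot values are illustrative assumptions and are not taken from the repository or from MultiWOZ.

```python
from typing import Dict, List, Tuple

Slot = Tuple[str, str]  # a (domain, slot) pair


def track(turn_updates: List[Dict[Slot, str]]) -> List[Dict[Slot, str]]:
    """Return the cumulative dialogue state after each turn."""
    state: Dict[Slot, str] = {}
    history: List[Dict[Slot, str]] = []
    for update in turn_updates:
        # Build a new dict each turn; later mentions overwrite earlier ones.
        state = {**state, **update}
        history.append(state)
    return history


# Toy conversation: the user first talks about a hotel, then about a train,
# and changes the hotel area in the last turn.
turn_updates = [
    {("hotel", "area"): "centre"},
    {("hotel", "stars"): "4"},
    {("train", "day"): "friday", ("hotel", "area"): "north"},
]
history = track(turn_updates)
for turn, state_after_turn in enumerate(history, start=1):
    print(turn, state_after_turn)
```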

Dependency

Check the required packages, or simply run:

❱❱❱ pip install -r requirements.txt

Multi-Domain DST

Training

❱❱❱ python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1

Testing

❱❱❱ python3 myTest.py -path=${save_path}
  • -bsz: batch size
  • -dr: dropout ratio
  • -lr: learning rate
  • -le: load pretrained embeddings
  • -path: path to the saved model
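
The abstract reports results as joint goal accuracy. A minimal sketch of how this metric is commonly computed is shown below, assuming each turn's state is represented as a dictionary of (domain, slot) pairs to values. It is not the repository's evaluation code, and `joint_goal_accuracy` is a hypothetical helper name.

```python
from typing import Dict, List, Tuple

State = Dict[Tuple[str, str], str]  # (domain, slot) -> value


def joint_goal_accuracy(predictions: List[State], references: List[State]) -> float:
    """Fraction of turns whose predicted state matches the reference exactly."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    # A turn counts as correct only if every slot value matches.
    correct = sum(1 for pred, ref in zip(predictions, references) if pred == ref)
    return correct / len(references)


# Toy data (made up): the second turn has one wrong slot value.
references = [
    {("hotel", "area"): "centre"},
    {("hotel", "area"): "centre", ("hotel", "stars"): "4"},
    {("taxi", "destination"): "airport"},
]
predictions = [
    {("hotel", "area"): "centre"},
    {("hotel", "area"): "centre", ("hotel", "stars"): "3"},
    {("taxi", "destination"): "airport"},
]
accuracy = joint_goal_accuracy(predictions, references)
print(f"joint goal accuracy: {accuracy:.4f}")
```

The metric is strict: a single wrong or missing slot value makes the whole turn count as incorrect.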

Unseen Domain DST

Zero-Shot DST

Training

❱❱❱ python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1 -exceptd=${domain}

Testing

❱❱❱ python3 myTest.py -path=${save_path} -exceptd=${domain}
  • -exceptd: the domain to exclude; choose one of {hotel, train, attraction, restaurant, taxi}.
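
To run the zero-shot setting for every held-out domain, the training command above can be generated in a loop. The sketch below only builds and prints the command lines; the helper name `zero_shot_train_command` is a hypothetical convenience, while the flags and domain names are copied from this README.

```python
from typing import List

# The five domains listed for -exceptd.
DOMAINS = ["hotel", "train", "attraction", "restaurant", "taxi"]


def zero_shot_train_command(domain: str) -> List[str]:
    """Build the training command that holds out one domain."""
    if domain not in DOMAINS:
        raise ValueError(f"unknown domain: {domain}")
    return [
        "python3", "myTrain.py",
        "-dec=TRADE", "-bsz=32", "-dr=0.2", "-lr=0.001", "-le=1",
        f"-exceptd={domain}",
    ]


commands = [zero_shot_train_command(domain) for domain in DOMAINS]
for command in commands:
    # None of the arguments contain spaces, so a plain join is safe here.
    print(" ".join(command))
```

Each list can be passed to `subprocess.run` as-is if the runs should be launched from Python instead of the shell.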

Few-Shot DST with CL

Training

Naive

❱❱❱ python3 fine_tune.py -bsz=8 -dr=0.2 -lr=0.001 -path=${save_path_except_domain} -exceptd=${except_domain}

EWC

❱❱❱ python3 EWC_train.py -bsz=8 -dr=0.2 -lr=0.001 -path=${save_path_except_domain} -exceptd=${except_domain} -fisher_sample=10000 -l_ewc=${lambda}

GEM

❱❱❱ python3 GEM_train.py -bsz=8 -dr=0.2 -lr=0.001 -path=${save_path_except_domain} -exceptd=${except_domain}
  • -l_ewc: lambda value in EWC training (used with EWC_train.py)
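
For readers unfamiliar with the role of lambda, the sketch below shows the standard EWC penalty, (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2, added to the loss on the new domain. It works on plain lists of floats and is an assumption about the general technique, not the code in EWC_train.py; the repository may scale the term differently. The `-fisher_sample` flag presumably sets how many samples are used to estimate the Fisher values, but that is not stated in this README.

```python
from typing import Sequence


def ewc_penalty(
    params: Sequence[float],         # current parameters (theta)
    anchor_params: Sequence[float],  # parameters after training on the old domains (theta*)
    fisher: Sequence[float],         # diagonal Fisher information estimates (F)
    lam: float,                      # the lambda weight, as passed via -l_ewc
) -> float:
    """Standard EWC penalty: (lam / 2) * sum_i F_i * (theta_i - theta*_i)^2."""
    if not (len(params) == len(anchor_params) == len(fisher)):
        raise ValueError("all inputs must have the same length")
    return 0.5 * lam * sum(
        f * (p - a) ** 2 for p, a, f in zip(params, anchor_params, fisher)
    )


# Toy numbers (made up): only the first parameter moved away from its anchor.
penalty = ewc_penalty(
    params=[1.0, 2.0],
    anchor_params=[0.0, 2.0],
    fisher=[0.5, 4.0],
    lam=2.0,
)
new_domain_loss = 1.25  # placeholder value for the loss on the new domain
total_loss = new_domain_loss + penalty
print(penalty, total_loss)
```

A larger lambda keeps the fine-tuned parameters closer to the ones learned on the original domains, at the cost of slower adaptation to the new domain.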

Bug Report

Feel free to create an issue or send an email to [email protected]
