GMN Chatbot

Introduction

This repository contains a chatbot based on a Graph Matching Network (GMN). It is our final project for the course New Frontier Artificial Intelligence II at U-Tokyo.

Dataset

We used the following English dialogues between patients and doctors to train our chatbot:

https://drive.google.com/drive/folders/1g29ssimdZ6JzTST6Y8g6h-ogUNReBtJD

Dataloader

  • Prepare the data by downloading the text files from the Google Drive link above. The data directory should contain the following files:
"data/healthcaremagic_splitted_idname_1.csv",
"data/healthcaremagic_splitted_idname_2.csv",
"data/healthcaremagic_splitted_idname_3.csv",
"data/healthcaremagic_splitted_idname_4.csv",
"data/icliniq_splitted_idname.csv"
  • To convert the text files to CSV and generate the tokenized dataset files, run preprocessing/reformat_text_data.py:
python preprocessing/reformat_text_data.py
  • To get the (PyTorch) datasets, import them from the following module. Tokenization is done with the Hugging Face bert-base-uncased tokenizer, and the default maximum length is 256:
from preprocessing.get_en_dataloader import get_training_dev_test_dataset

train_dataset, dev_dataset, test_dataset = get_training_dev_test_dataset(debugging=False, max_length=256)

print(train_dataset[0])

The returned datasets can be used directly with the Hugging Face Trainer. For manual training, wrap them in a PyTorch DataLoader:

from torch.utils.data import DataLoader

# For train loader
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

Default Data split

By default, the data is split as follows:

Train:
    - data/healthcaremagic_splitted_idname_1.csv
    - data/healthcaremagic_splitted_idname_2.csv
    - data/healthcaremagic_splitted_idname_3.csv

Dev:
    - data/healthcaremagic_splitted_idname_4.csv

Test:
    - data/icliniq_splitted_idname.csv
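
The split above can also be written down as a small mapping, which makes it easy to check that every expected file is in place. This is a minimal sketch and not part of the repository: the names `DEFAULT_SPLIT` and `find_missing_files` are hypothetical, and only the file paths come from the list above.

```python
from pathlib import Path

# File paths as listed in this README; the grouping mirrors the default split.
DEFAULT_SPLIT = {
    "train": [
        "data/healthcaremagic_splitted_idname_1.csv",
        "data/healthcaremagic_splitted_idname_2.csv",
        "data/healthcaremagic_splitted_idname_3.csv",
    ],
    "dev": ["data/healthcaremagic_splitted_idname_4.csv"],
    "test": ["data/icliniq_splitted_idname.csv"],
}


def find_missing_files(split, root="."):
    """Return the listed paths that do not exist under `root` (hypothetical helper)."""
    root = Path(root)
    return [
        path
        for paths in split.values()
        for path in paths
        if not (root / path).is_file()
    ]


if __name__ == "__main__":
    missing = find_missing_files(DEFAULT_SPLIT)
    if missing:
        print("Missing data files:")
        for path in missing:
            print("  " + path)
    else:
        print("All data files found.")
```

Run from the repository root, the script prints any file that still needs to be downloaded or generated.
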
  • The dataloader returns the following dictionary for each index:
Dict{
    'input_ids': tensor(2, max_length),
    'token_type_ids': same shape as input_ids,
    'attention_mask': same shape as input_ids,
    'doctor_input_ids': tensor(#negative_sample + 1, max_length),
    'doctor_token_type_ids': same shape as doctor_input_ids,
    'doctor_attention_mask': same shape as doctor_input_ids,
}

There are two things to note here:

  • Every sample contains a description, a patient response, and a doctor response, so each dialogue has three turns in total. We treat the description and the patient response as the first two turns and ask the model to output the probability of the third turn (the doctor response).
  • Negative samples are drawn at random from the responses in other conversations. The correct response is always at the first index (0), followed by #negative_sample wrong responses.
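
To make the two notes above concrete, here is a minimal sketch of how a candidate list with the correct response at index 0 could be built, and how per-candidate scores could be turned into a probability for the true third turn. It is an illustration under assumptions, not the repository's code: `build_candidates` and `selection_loss` are hypothetical names, the conversations are made-up placeholders, the scores stand in for whatever a model would output, and softmax cross-entropy is only one plausible choice of objective.

```python
import math
import random


def build_candidates(conversations, index, num_negatives, rng):
    """Return [correct response] followed by `num_negatives` responses
    sampled from the other conversations."""
    correct = conversations[index]["doctor"]
    # Pool of doctor responses taken from every other conversation.
    pool = [c["doctor"] for i, c in enumerate(conversations) if i != index]
    negatives = rng.sample(pool, num_negatives)
    return [correct] + negatives


def selection_loss(scores):
    """Softmax cross-entropy with the target fixed at index 0.

    Returns (probabilities, loss), where loss = -log(probabilities[0]).
    """
    shift = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    return probs, -math.log(probs[0])


# Made-up placeholder dialogues: description + patient form the context,
# and the doctor response is the turn to be selected.
conversations = [
    {"description": "d0", "patient": "p0", "doctor": "r0"},
    {"description": "d1", "patient": "p1", "doctor": "r1"},
    {"description": "d2", "patient": "p2", "doctor": "r2"},
    {"description": "d3", "patient": "p3", "doctor": "r3"},
]

rng = random.Random(0)
candidates = build_candidates(conversations, index=1, num_negatives=2, rng=rng)
# candidates[0] is the correct response "r1"; the rest are negatives.

scores = [2.0, 0.5, -1.0]  # hypothetical model scores, one per candidate
probs, loss = selection_loss(scores)
print(candidates, probs, loss)
```

Because the correct response always sits at index 0, the training target is the same for every sample, so no separate label needs to be stored in the dataset.
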

BERT Baseline

