Documentation

In the following, we describe how the text classification model is built and its major components.

Project Overview

In this project, our team took part in the text classification competition. The specific task is sarcasm detection: given a Twitter response, we try to tell whether the response is sarcastic. Each data point in the given dataset consists of the response tweet and the context of the response. The goal of this project is to develop a robust model that performs well at telling whether a given Twitter response is sarcastic or not.

DEMO Link

Classification Model

We experimented with many different models. In the end, we reached the competition benchmark score by fine-tuning a pre-trained BERT, specifically DistilBERT.

Data Preprocessing

For data cleaning, we remove punctuation and any other special characters in the Twitter response. We also expand abbreviations such as can't and won't into can not and will not, respectively, and remove the heading of each response. After cleaning the data, we use a pre-trained DistilBertTokenizer to tokenize the cleaned responses. The tokenized responses are then used as the input to our model.
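A minimal sketch of this step, assuming an illustrative contraction map, a leading-mention heading pattern, and a `max_length` of 128 (none of these exact values come from our code):

```python
import re
from transformers import DistilBertTokenizer

# Illustrative contraction map; the real code may cover more abbreviations.
CONTRACTIONS = {"can't": "can not", "won't": "will not"}

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

def clean_response(text):
    text = text.lower()
    for short, long in CONTRACTIONS.items():
        text = text.replace(short, long)
    # Remove the response heading (assumed here to be a leading "@user" mention).
    text = re.sub(r"^@\w+\s*", "", text)
    # Strip punctuation and other special characters.
    text = re.sub(r"[^a-z0-9\s]", " ", text)
    return re.sub(r"\s+", " ", text).strip()

def tokenize_responses(responses, max_length=128):
    cleaned = [clean_response(r) for r in responses]
    return tokenizer(cleaned, padding=True, truncation=True,
                     max_length=max_length, return_tensors="pt")
```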

Model Architecture

The general idea of our model is to fine-tune the pre-trained DistilBERT for text classification. We achieve this by adding two extra fully connected linear layers on top of the BERT output while fixing the parameters of the BERT model. The general pipeline is as follows: given tokenized responses as input, we first pass them through the DistilBERT base model to get high-dimensional representations of the responses. These representations are then fed to the two linear layers to produce the final prediction of whether the response is sarcastic. Between the linear layers, we use ReLU as the activation. We also apply dropout to both the output of the base model and the output of the linear layers.
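A sketch of this architecture, using the layer sizes and dropout ratio reported under Previous Attempts; pooling on the first-token representation is an assumption, not necessarily what our code does:

```python
import torch
import torch.nn as nn
from transformers import DistilBertModel

class SarcasmClassifier(nn.Module):
    """Frozen DistilBERT encoder with two linear layers on top."""

    def __init__(self, dropout=0.5):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        for p in self.bert.parameters():   # fix the pre-trained parameters
            p.requires_grad = False
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(768, 256)
        self.fc2 = nn.Linear(256, 2)

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]          # first-token representation
        x = torch.relu(self.fc1(self.dropout(cls)))
        # Log-probabilities so the output can be fed directly to NLLLoss.
        return torch.log_softmax(self.fc2(self.dropout(x)), dim=-1)
```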

Model Training

Given a raw Twitter response, we first preprocess it following the data preprocessing steps to obtain the tokenized response. The tokenized responses are then used as the input to the BERT base model, which produces high-dimensional representations of the responses. Those representations are passed through the linear layers to generate the final prediction. The loss function we use is the negative log-likelihood (NLL) loss. After obtaining the training predictions, we feed them, together with the ground-truth labels indicating whether each response is sarcastic, into NLLLoss and perform backpropagation on the computed loss.
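A minimal training-loop sketch matching this description; the optimizer choice, batch size, and epoch count are illustrative assumptions (the learning rate was tuned in the range given under Previous Attempts):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def train(model, encodings, labels, epochs=3, lr=1e-4, batch_size=32):
    dataset = TensorDataset(encodings["input_ids"], encodings["attention_mask"],
                            torch.tensor(labels))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = torch.nn.NLLLoss()
    # Only the added linear layers are updated; the BERT parameters stay frozen.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=lr)

    model.train()
    for _ in range(epochs):
        for input_ids, attention_mask, y in loader:
            optimizer.zero_grad()
            log_probs = model(input_ids, attention_mask)  # log-probabilities
            loss = criterion(log_probs, y)                # NLL loss vs. ground truth
            loss.backward()                               # backpropagation
            optimizer.step()
```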

Evaluation

During training, we split the training dataset into two subsets: one for actual training and one for validation. The validation set contains 20% of the data points in the original training dataset. In each epoch, we evaluate the F1 score of the model on the validation set and save the model with the best F1 score. For the actual prediction task on the test set, we use the model saved during training.
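A sketch of the split and checkpointing logic, reusing the `tokenize_responses` and `train` helpers sketched above; the number of epochs and checkpoint filename are assumptions:

```python
import torch
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

@torch.no_grad()
def validate(model, encodings, labels):
    """Return the F1 score of the model on a validation split."""
    model.eval()
    log_probs = model(encodings["input_ids"], encodings["attention_mask"])
    preds = log_probs.argmax(dim=-1).tolist()
    return f1_score(labels, preds)

def fit_with_checkpointing(model, texts, labels, num_epochs=5):
    # Hold out 20% of the original training data for validation.
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=0.2)
    train_enc = tokenize_responses(train_texts)
    val_enc = tokenize_responses(val_texts)
    best_f1 = 0.0
    for _ in range(num_epochs):
        train(model, train_enc, train_labels, epochs=1)   # one epoch of training
        f1 = validate(model, val_enc, val_labels)
        if f1 > best_f1:                                  # keep the best checkpoint
            best_f1 = f1
            torch.save(model.state_dict(), "best_model.pt")
    return best_f1
```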

Previous Attempts

We have come a long way to reach the model we have now. We first tried models based on CNNs and RNNs, but after implementing them, those models did not give results good enough to beat the competition benchmark. Apart from DistilBERT, we also experimented with the full BERT, which gives decent results but tends to overfit and takes much more time to run. We also varied the number and dimensions of the linear layers used to fine-tune the model: we tried adding 3 or 4 linear layers and many other combinations of dimensions, but settled on 2 linear layers of sizes (768, 256) and (256, 2). In terms of activation functions, we tried Tanh, PReLU, and LeakyReLU; although they all give very similar results, we chose ReLU in the end. We also experimented with dropout ratios in the range [0, 0.5] and observed that the model reaches its best performance on the test set with a dropout of 0.5. Finally, we tuned the learning rate in the range [0.00001, 0.01].

For data preprocessing, we found that removing stopwords and stemming the words negatively affected the performance of our model. Expanding abbreviations seems to improve the performance of the model by reducing overfitting. Removing punctuation and special characters generally gives cleaner data to the tokenizer, so it helps with both model training and model testing.

Dependencies

  • Python
  • json
  • PyTorch
  • scikit-learn
  • Transformers

To install the dependencies, you can use the included environment.yml in the code directory to create a virtual environment with Anaconda. An installation reference can be found here. A detailed tutorial is also included in the DEMO.

Contributions

  • Junting Wang: Team Leader. Implemented the model and wrote the code documentation.
  • Tianwei Zhang: Helped with experiments, model testing, and the project DEMO.
