bcn-cnn-text-classification's Introduction

Text Classification

Data Preprocessing and EDA

  • The target distribution is not uniform, i.e. the classes are imbalanced.
  • num_sentence and reflection_period do not correlate with the targets and are unnecessary for training.
  • The text in hm_train.csv and hm_test.csv was tokenized and the tokens were joined with single spaces, so that all tokens (including punctuation) are space-separated. The cleaned data was saved in hm_train_cleaned.csv and hm_test_cleaned.csv respectively.
  • The training data was split into train and validation sets, which were saved in train.csv and test.csv respectively.

The code for preprocessing and EDA is present in preprocessing.ipynb.
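
The space-separation step described above can be sketched as follows. This is a minimal illustration that assumes a simple regex tokenizer; the tokenizer actually used in the notebook is not stated here, and the function name space_separate is hypothetical.

```python
import re

# Assumption: a token is either a run of word characters or a single
# punctuation character. The notebook may use a different tokenizer.
TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]")


def space_separate(text):
    """Tokenize text and join the tokens with single spaces."""
    tokens = TOKEN_PATTERN.findall(text)
    return " ".join(tokens)


# Punctuation ends up as its own space-separated token.
cleaned = space_separate("I was happy, really!")
print(cleaned)
```

With this rule, contractions such as "don't" are split into three tokens, which may or may not match the behaviour of the original preprocessing.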

Approaches Explored

  • TextCNN
  • Bi-attentive Classification Network

TextCNN

The TextCNN consisted of a 3-layer 1D ConvNet followed by two Dense layers. 300-dimensional GloVe word embeddings were used for word vectorization. Dropout was used to reduce overfitting; even so, validation accuracy plateaued while training accuracy kept increasing. Weighted cross-entropy loss was used to tackle the class imbalance. The code for the experiment is present in the notebook TextCNN/TextCNN.ipynb.

Score on submission: 0.8645
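
One common way to obtain the weights for a weighted cross-entropy loss is inverse class frequency. The sketch below is an assumption about how such weights could be derived, not the scheme used in the notebook; balanced_class_weights is a hypothetical helper and the labels are made up for illustration.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Return a weight per class, inversely proportional to its frequency.

    weight(c) = n_samples / (n_classes * count(c)), so a class that is
    rarer than average gets a weight above 1.
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {c: n_samples / (n_classes * k) for c, k in counts.items()}


# Illustrative labels only: class "a" is three times as frequent as "b".
weights = balanced_class_weights(["a", "a", "a", "b"])
print(weights)
```

After mapping class names to integer indices, a dictionary like this could be passed to the class_weight argument of Keras' fit method.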

Bi-attentive Classification Network

The approach and the model architecture were introduced in the Contextualized Word Vectors paper. A brief overview follows:

  1. The input text sequence S = [w1, w2, w3, ..., wn] is vectorized using an embedding of choice. Let the vectors be denoted by

    E = [e1, e2, e3, ..., en]

    where n is the number of words in the sequence.

  2. The vectorized sequence is duplicated to give two identical copies, E1 and E2.

  3. A feedforward network with ReLU activation is used on each word vector. This is the pre-encoding step used for dimensionality reduction of the embeddings.

    E1' = FC_Layer(E1)

    E2' = FC_Layer(E2)

  4. The output is then passed through a Bi-directional LSTM to obtain context-aware representations of the words in the input sequence.

    X = Bi-LSTM(E1')

    Y = Bi-LSTM(E2')

  5. Compute the n × n matrix A = X · Y^T. It contains the dot product of each xi with every yj. Intuitively, it captures how each word is related to every other word in the input sequence.

  6. Compute:

    Ax = softmax(A)

    Ay = softmax(A^T)

    where softmax() denotes a column-wise softmax. Ax and Ay are n × n matrices in which each column sums to 1. This gives normalized weights that can be used for attention.

  7. Apply the attention weights to the corresponding matrices X and Y.

    Cx = Ax^T · X

    Cy = Ay^T · Y

  8. The conditioned matrices Cx and Cy along with X and Y are passed through another Bi-LSTM to integrate the obtained information.

    X|Y = Bi-LSTM(X, Cy)

    Y|X = Bi-LSTM(Y, Cx)

    These outputs are then min-pooled along the column.

  9. The pooled representations Xpool and Ypool are then concatenated and passed through a final Dense layer with softmax at the end to obtain the probability distribution over the classes.
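
Steps 5 to 7 above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the AllenNLP implementation: X and Y are random stand-ins for the Bi-LSTM outputs, and the function names column_softmax and biattention are hypothetical.

```python
import numpy as np


def column_softmax(m):
    """Softmax over axis 0, so that every column of the result sums to 1."""
    shifted = m - m.max(axis=0, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=0, keepdims=True)


def biattention(X, Y):
    """Compute the attention weights and conditioned matrices of steps 5-7.

    X and Y have shape (n, d): n words, d features per word.
    """
    A = X @ Y.T                # step 5: (n, n) matrix of dot products
    Ax = column_softmax(A)     # step 6: column-wise softmax of A
    Ay = column_softmax(A.T)   #         and of its transpose
    Cx = Ax.T @ X              # step 7: attention applied to X
    Cy = Ay.T @ Y              #         and to Y
    return Ax, Ay, Cx, Cy


# Assumption: random data standing in for Bi-LSTM outputs (n=5, d=8).
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))
Y = rng.normal(size=(5, 8))
Ax, Ay, Cx, Cy = biattention(X, Y)
print(Ax.shape, Cx.shape)
```

Cx and Cy have the same shape as X and Y, which is what allows them to be combined with X and Y in the integration step.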

Figure: Bi-attentive Classification Network

Implementation Details

TextCNN

Libraries used:

  • Keras: Used for training
  • Scikit-Learn: Used for preprocessing and metrics

For more details, refer to TextCNN/TextCNN.ipynb. The notebook contains code for loading data, building and training the model and predicting on the test set.

Score: 0.8645

Bi-attentive Classification Network

AllenNLP, a high-level NLP research library, was used for this model. The library had to be extended to support reading from and predicting on differently formatted data; the extension can be found in the BCN/mylibrary folder. A customizable implementation of the Bi-attentive Classification Network is available in the library. Several experiments were performed with different choices of word embeddings, model size, learning rate and learning rate scheduler:

  • 300 Dimensional GloVe embeddings + Medium Sized model + lr=0.001 + 20 epochs

    (See BCN/bcn.jsonnet)

    Score: 0.8766

  • 300 Dimensional GloVe embeddings + Medium Sized model + lr=0.0005 + LR scheduler (lr/2 on plateau for 3 epochs) + 15 epochs

    (See BCN/bcn_lrscheduler.jsonnet)

    The model seemed to overfit in the later epochs; the best validation accuracy was reached in the 12th epoch. Increasing dropout might help.

    Score: 0.8827

  • 300 Dimensional GloVe + 512 Dimensional ELMo embeddings + Large Sized Model + lr=0.0005

    (See BCN/bcn_glove_elmo.jsonnet)

    This configuration caused a memory error even with the batch size reduced from 100 to 32.

  • 300 Dimensional GloVe + 128 Dimensional ELMo embeddings + Small Sized Model + lr=0.001

    (See BCN/bcn_elmo_small.jsonnet)

    This configuration also caused a memory error, even with a small batch size.

  • 50 Dimensional GloVe + 128 Dimensional ELMo embeddings + Small Sized Model + lr=0.001

    (See BCN/bcn_glove_small_elmo_small.jsonnet)

    Training was slow (1.5 hours per epoch), and the validation metrics after the 1st and 2nd epochs were not promising. Training was cancelled and no submission was made.

  • 768 Dimensional BERT Embeddings + Medium Sized Model + lr=0.0005 + 10 epochs

    (See BCN/bcn_bert.jsonnet)

    The accuracy increased rapidly in the early epochs but saturated after the 8th epoch.

    Learning rate scheduling might help.

    Score: 0.8738
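
The "lr/2 on plateau for 3 epochs" schedule mentioned in the second experiment can be illustrated with a small sketch. It assumes that "plateau" means three consecutive epochs without a new best validation accuracy; the exact semantics of the scheduler configured in BCN/bcn_lrscheduler.jsonnet may differ. The class HalveOnPlateau and the accuracy values are hypothetical.

```python
class HalveOnPlateau:
    """Halve the learning rate after `patience` epochs without improvement."""

    def __init__(self, lr, patience=3, factor=0.5):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_accuracy):
        """Record one epoch's validation accuracy and return the new lr."""
        if val_accuracy > self.best:
            self.best = val_accuracy
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0  # start counting again after a reduction
        return self.lr


# Made-up validation accuracies, for illustration only.
scheduler = HalveOnPlateau(lr=0.0005)
history = [0.80, 0.85, 0.85, 0.84, 0.85, 0.86]
rates = [scheduler.step(acc) for acc in history]
print(rates)
```

In this made-up run the rate is halved once, at the fifth epoch, after three epochs in a row fail to beat the best accuracy seen so far.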

Running the code

The BCN/*.jsonnet files are configuration files that AllenNLP uses to build the model architecture and train models. The paths to training and validation data are present in these config files.

The training and prediction scripts are provided in the BCN/ folder.

For training:

python3 train.py /path/to/config.jsonnet /path/to/model/folder

For predictions:

python3 predict.py /path/to/model.tar.gz /path/to/test.csv /path/to/submission.csv

Directory Structure

.
├── BCN
│   ├── mylibrary
│   │   ├── data
│   │   │   ├── dataset_readers
│   │   │   │   ├── __init__.py
│   │   │   │   └── smiledb.py
│   │   │   └── __init__.py
│   │   ├── predictors
│   │   │   ├── __init__.py
│   │   │   └── smiledb.py
│   │   └── __init__.py
│   ├── bcn_bert.jsonnet
│   ├── bcn_elmo_small.jsonnet
│   ├── bcn_glove_elmo.jsonnet
│   ├── bcn_glove_small_elmo_small.jsonnet
│   ├── bcn.jsonnet
│   ├── bcn_lrscheduler.jsonnet
│   ├── predict.py
│   └── train.py
├── TextCNN
│   └── TextCNN.ipynb
├── hm_test_cleaned.csv
├── hm_test.csv
├── hm_train_cleaned.csv
├── hm_train.csv
├── install_requirements.sh
├── preprocessing.ipynb
├── README.md
├── requirements.txt
├── test.csv
└── train.csv
