
vdcnn's Introduction

VDCNN

TensorFlow implementation of Very Deep Convolutional Networks for Text Classification, proposed by Conneau et al.

The VDCNN architecture is now correctly re-implemented with TensorFlow 2 and tf.keras support. A simple training interface is implemented following the TensorFlow 2 expert tutorial. Feel free to contribute additional utilities like TensorBoard support.

Side note, if you are a newcomer to NLP text classification:

  • Please check out newer SOTA NLP methods like Transformers or BERT.

  • Check out PyTorch for much better dynamic graphing and dataset object support.

    • The current VDCNN implementation is also very easy to port to PyTorch.

Prerequisites

  • Python 3
  • TensorFlow >= 2.0
  • tensorflow-datasets
  • numpy

Datasets

The original paper tests several NLP datasets, including DBPedia, AG's News, Sogou News, and others.

tensorflow-datasets is used to support the AG's News dataset.
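For example, AG's News can be loaded through tensorflow-datasets; a minimal sketch, assuming a recent tfds version where the dataset is registered under the name "ag_news_subset":

```python
import tensorflow_datasets as tfds

# AG's News as registered in tfds: 120,000 train / 7,600 test examples
(train_ds, test_ds), info = tfds.load(
    "ag_news_subset", split=["train", "test"],
    as_supervised=True, with_info=True)
```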

Downloads of those NLP text classification datasets can be found here (Many thanks to ArdalanM):

| Dataset                | Classes | Train samples | Test samples | Source |
|------------------------|---------|---------------|--------------|--------|
| AG's News              | 4       | 120,000       | 7,600        | link   |
| Sogou News             | 5       | 450,000       | 60,000       | link   |
| DBPedia                | 14      | 560,000       | 70,000       | link   |
| Yelp Review Polarity   | 2       | 560,000       | 38,000       | link   |
| Yelp Review Full       | 5       | 650,000       | 50,000       | link   |
| Yahoo! Answers         | 10      | 1,400,000     | 60,000       | link   |
| Amazon Review Full     | 5       | 3,000,000     | 650,000      | link   |
| Amazon Review Polarity | 2       | 3,600,000     | 400,000      | link   |

Parameters Setting

The original paper suggests the following details for training:

  • SGD optimizer with learning rate 1e-2 and momentum 0.9.
  • 10-15 epochs for convergence.
  • He initialization.

Some additional parameter settings for this repo:

  • Gradient clipping with a norm value of 7.0, to stabilize the training.
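A minimal tf.keras sketch of these settings (variable names are illustrative, not the repo's actual API):

```python
import tensorflow as tf

# SGD as the paper suggests, plus this repo's gradient clipping
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2, momentum=0.9,
                                    clipnorm=7.0)

# He initialization for the convolutional layers
conv = tf.keras.layers.Conv1D(64, kernel_size=3, padding="same",
                              kernel_initializer="he_normal")
```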

Skip connections and pooling are now correctly implemented:

  • k-max pooling.
  • max pooling with kernel size 3 and stride 2.
  • convolutional pooling with a K_i convolutional layer.

For dotted skip connections:

  • Identity with zero padding.
  • Conv1D with kernel size of 1 (see the sketch below).
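To make the second option concrete, here is a minimal tf.keras sketch of a projection shortcut; the function name and the stride-2 assumption (matching the temporal dimension halved by pooling) are mine, not the repo's actual code:

```python
import tensorflow as tf

def projection_shortcut(x, residual, filters):
    """Dotted shortcut: a kernel-size-1 Conv1D projects x to the residual
    branch's channel count; stride 2 matches the pooled temporal length."""
    shortcut = tf.keras.layers.Conv1D(filters, kernel_size=1, strides=2,
                                      padding="same")(x)
    return tf.keras.layers.Add()([shortcut, residual])
```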

Please refer to Conneau et al. for their methodology and experiments in more detail.

Experiments

Results are reported as follows: (i) / (ii)

  • (i): Test set accuracy reported by the paper (acc = 100% - error_rate)
  • (ii): Test set accuracy reproduced by this Keras implementation

TODO: Feel free to report your own experimental results in the following format:

Results for "Identity" Shortcut, "k-max" Pooling:

| Depth     | ag_news         | DBPedia         | Sogou News      |
|-----------|-----------------|-----------------|-----------------|
| 9 layers  | 90.17 / xx.xxxx | 98.44 / xx.xxxx | 96.42 / xx.xxxx |
| 17 layers | 90.61 / xx.xxxx | 98.39 / xx.xxxx | 96.49 / xx.xxxx |
| 29 layers | 91.33 / xx.xxxx | 98.59 / xx.xxxx | 96.82 / xx.xxxx |
| 49 layers | xx.xx / xx.xxxx | xx.xx / xx.xxxx | xx.xx / xx.xxxx |

Reference

Original preprocessing code and VDCNN implementation by geduo15

Train script and data iterator from Convolutional Neural Network for Text Classification

NLP datasets gathered by ArdalanM and others


vdcnn's Issues

Training fails

Hi,

I ran your code on ag_news, but I encounter a training failure every time (test accuracy is 25%, i.e. chance level on the 4-class task).

It looks like the way the model is restored has some issues. Any ideas about this?

Unknown characters

In the original paper they allocate an encoding character for all characters outside the range they actually encode. It isn't obvious to me that you have done this in your code. Any reason? Or am I just not seeing where that is being done?
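For reference, the scheme the paper describes maps every out-of-alphabet character to one shared "unknown" index; a minimal sketch, with an alphabet approximating the paper's and index choices that are mine:

```python
# alphabet approximating the paper's; index 0 is reserved for padding
ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"
char_to_idx = {c: i + 1 for i, c in enumerate(ALPHABET)}
UNK_IDX = len(ALPHABET) + 1  # one shared id for every out-of-alphabet character

def encode(text, max_len=1024):
    ids = [char_to_idx.get(c, UNK_IDX) for c in text.lower()[:max_len]]
    return ids + [0] * (max_len - len(ids))  # pad with the reserved 0
```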

Regarding Data Set

Hello,

For the DBPedia dataset, and other datasets like Amazon Reviews, there are 3 columns. The first one is the class/target. Between the second and third columns, which one should we select? I have trained a different model on the third-column data and ignored the second column (what is the use of that?). I got 99.02% test accuracy, which beats the authors' results. Did I take the right columns?

Thanks

bug in load_csv_file()

The code at line 33 is:

```python
if i > sequence_max_length - 1:
```

I think it should be:

```python
if i >= sequence_max_length - 1:
```

There is a `for` loop in `Convolutional_Block` and only the last conv output is used, why?

The code is:

```python
for i in range(2):
    with tf.variable_scope("conv1d_%s" % str(i)):
        filter_shape = [3, inputs.get_shape()[2], num_filters]
        W = tf.get_variable(name='W', shape=filter_shape,
                            initializer=he_normal,
                            regularizer=regularizer)
        out = tf.nn.conv1d(inputs, W, stride=1, padding="SAME")
        out = tf.layers.batch_normalization(inputs=out, momentum=0.997, epsilon=1e-5,
                                            center=True, scale=True, training=is_training)
        out = tf.nn.relu(out)
        print("Conv1D:", out.get_shape())
```
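Presumably the second iteration is meant to consume the first one's output; a minimal sketch of that fix, keeping the snippet's own names (whether this matches the author's intent is an assumption):

```python
out = inputs
for i in range(2):
    with tf.variable_scope("conv1d_%s" % str(i)):
        filter_shape = [3, out.get_shape()[2], num_filters]
        W = tf.get_variable(name='W', shape=filter_shape,
                            initializer=he_normal, regularizer=regularizer)
        # feed the previous iteration's output in, instead of `inputs` twice
        out = tf.nn.conv1d(out, W, stride=1, padding="SAME")
        out = tf.layers.batch_normalization(out, momentum=0.997, epsilon=1e-5,
                                            center=True, scale=True,
                                            training=is_training)
        out = tf.nn.relu(out)
```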

Issue about running your code

I encountered the following error when running your code (tf == 1.1.0):

```
Traceback (most recent call last):
File "C:/Users/syrup/Documents/VDCNN-master/train.py", line 47, in
use_k_max_pooling=FLAGS.use_k_max_pooling)
File "C:\Users\syrup\Documents\VDCNN-master\vdcnn_9.py", line 54, in init
initializer=conv_initializer)
File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1049, in get_variable
use_resource=use_resource, custom_getter=custom_getter)
File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 948, in get_variable
use_resource=use_resource, custom_getter=custom_getter)
File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 356, in get_variable
validate_shape=validate_shape, use_resource=use_resource)
File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 341, in _true_getter
use_resource=use_resource)
File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 714, in _get_single_variable
validate_shape=validate_shape)
File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\ops\variables.py", line 197, in init
expected_shape=expected_shape)
File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\ops\variables.py", line 275, in _init_from_args
initial_value(), name="initial_value", dtype=dtype)
File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 690, in
shape.as_list(), dtype=dtype, partition_info=partition_info)
TypeError: call() got an unexpected keyword argument 'partition_info'
```

load_csv_file() only loads description field

```python
text = row['fields'][-1].lower()
```

This code in the function load_csv_file() doesn't load the title part of the text. However, in the paper, both the title and the description are used for training.
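A minimal sketch of the likely fix, assuming row['fields'] holds [class, title, description] as in the CSV layout discussed above:

```python
# concatenate title and description, matching the paper's use of both fields
text = " ".join(row['fields'][1:]).lower()
```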

Questions about the test error

  1. I've run your code on the AG News dataset, and I get high accuracy during training but a relatively lower and unstable accuracy at test time. If I set is_training=True in the test step, I get a good result. Is there a problem with the batch norm? (See the sketch after this list.)

  2. What is the use of fixed_padding after the pooling layers? I didn't see such an operation in the original paper.
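For what it's worth, a classic cause of this exact symptom in TF1-style code is that batch norm's moving statistics are never updated; the standard fix attaches the UPDATE_OPS collection to the train op (the optimizer/loss names here are placeholders):

```python
# without this dependency, moving_mean/moving_variance stay at their
# initial values and inference-mode batch norm produces garbage
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss, global_step=global_step)
```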

Minor Bug

The AG News test set doesn't reach the 1024-character threshold, but if you run the code on Yahoo! Answers, which goes beyond 1024, you'll get an error, because you substitute the character before checking whether 1024 has been reached. The check for reaching the max sequence length should come before everything else! (See the sketch below.)
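A sketch of the reordering the report asks for (names are hypothetical, not the repo's):

```python
for i, char in enumerate(text):
    if i >= sequence_max_length:              # length check first...
        break
    data[i] = char_to_idx.get(char, UNK_IDX)  # ...then substitute
```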

decay step problem

Hi,
In train.py, there is this line of code:

```python
lr_decay_fn = lambda lr, global_step: tf.train.exponential_decay(lr, global_step, 100, 0.95, staircase=True)
```

May I ask why the decay step is set to 100? I have seen other code set it to around 10,000. How do you decide this parameter?
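For context, with staircase=True the documented behavior of tf.train.exponential_decay reduces to:

```python
decayed_lr = lr * 0.95 ** (global_step // 100)
```

so with a decay step of 100 the learning rate roughly halves every ~1,400 steps (0.95 ** 14 ≈ 0.49); a larger decay step just stretches the same schedule out.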

Cannot save model

Hi, your code is a really interesting implementation.
However, I ran into problems saving the model when running model.to_json() in train.py.
The error message looks like:

File "train.py", line 71, in train
model_json = model.to_json()
File "/home/mzli/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2264, in to_json
model_config = self._updated_config()
File "/home/mzli/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2221, in _updated_config
config = self.get_config()
File "/home/mzli/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 598, in get_config
return copy.deepcopy(get_network_config(self))
File "/home/mzli/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 1278, in get_network_config
layer_config = serialize_layer_fn(layer)
File "/home/mzli/.local/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 250, in serialize_keras_object
raise e
File "/home/mzli/.local/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 245, in serialize_keras_object
config = instance.get_config()
File "/home/mzli/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 676, in get_config
raise NotImplementedError('Layer %s has arguments in __init__ and '
NotImplementedError: Layer KMaxPooling has arguments in __init__ and therefore must override get_config.

May I get some advice on this issue? Many thanks!!
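The error message itself names the standard fix: a custom Keras layer whose __init__ takes arguments must override get_config so the model can be serialized. A minimal sketch, assuming KMaxPooling takes a k argument:

```python
import tensorflow as tf

class KMaxPooling(tf.keras.layers.Layer):
    def __init__(self, k=8, **kwargs):
        super().__init__(**kwargs)
        self.k = k

    def call(self, inputs):
        # keep the k largest activations along the temporal axis
        x = tf.transpose(inputs, [0, 2, 1])
        top_k = tf.math.top_k(x, k=self.k, sorted=True).values
        return tf.transpose(top_k, [0, 2, 1])

    def get_config(self):
        # report constructor args so to_json()/save() can rebuild the layer
        config = super().get_config()
        config.update({"k": self.k})
        return config
```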

accuracy question when running vdcnn29

Hello,
While running VDCNN-9 on the AG dataset without k-max pooling, I got the same result as the paper (89.83%).
But with VDCNN-29 with k-max pooling, I only got 90.4% accuracy, while the paper reports 91.27%.
I'd like to know the accuracy you got when you ran this network, thanks!

confused about the embedding

Great work! Sorry if this is rude, I was just confused about the embedding layer in the code.

To my knowledge, it is usually Embedding(input_dim=$vocab_size, output_dim=$embedded_size, input_length=$input_size).
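For reference, the usual tf.keras pattern the question describes (sizes are placeholders; the paper uses 16-dimensional character embeddings):

```python
embedding = tf.keras.layers.Embedding(
    input_dim=vocab_size,             # number of distinct characters
    output_dim=16,                    # embedding size, 16 in the paper
    input_length=sequence_max_length)
```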

Issue Memory Error

I am getting a Memory Error when trying to run on the Yahoo! Answers dataset! I'll try to merge in the Dataset API in order to cache batches instead of loading the whole dataset into memory. If I get through this, maybe I'll fork your repo and contribute! (The bug isn't specified exactly, because all I get is a Memory Error and nothing more indicative.)
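If it helps, the tf.data route hinted at here can stream the CSV from disk instead of materializing everything; a sketch, with the three-column layout and file name assumed:

```python
import tensorflow as tf

# stream records straight from disk; only prefetched batches live in memory
ds = tf.data.experimental.CsvDataset(
    "train.csv", record_defaults=[tf.int32, tf.string, tf.string])
ds = ds.shuffle(10_000).batch(128).prefetch(tf.data.AUTOTUNE)
```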

Please state which license (Apache?)

Thanks for doing this work - it saves me from following the same paper and reproducing it myself. At least it will if you are willing to give this an Apache license, as I need to implement a commercial version of this and cannot without a license that allows me to.

Are you willing to issue under the Apache license and add that to the source code and MD?
