Sequence Generation Model for Multi-label Classification

This is the code for our paper SGM: Sequence Generation Model for Multi-label Classification [pdf]


Note

In general, this code is more suitable for the following application scenarios:

  • The dataset is relatively large:
    • The performance of the seq2seq model depends on the size of the dataset.
  • There exist orders or dependencies between the labels:
    • A reasonable prior order of labels tends to be helpful (see the sketch below for one way to construct such an order).
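
For example, the paper sorts each sample's label sequence by label frequency in the training set, with high-frequency labels first. Below is a minimal sketch of that idea; the file names and the helper itself are illustrative assumptions (one sample per line in train.tgt, labels space-separated), not part of this repository:

from collections import Counter

# Hypothetical helper: rewrite a .tgt file so that each sample's labels
# are ordered by their global frequency in the training set.
def frequency_sorted_targets(tgt_path, out_path):
    with open(tgt_path) as f:
        samples = [line.split() for line in f]
    freq = Counter(label for labels in samples for label in labels)
    with open(out_path, "w") as f:
        for labels in samples:
            # Most frequent labels first; ties broken alphabetically.
            ordered = sorted(set(labels), key=lambda l: (-freq[l], l))
            f.write(" ".join(ordered) + "\n")

frequency_sorted_targets("data/train.tgt", "data/train.sorted.tgt")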

Requirements

  • Ubuntu 16.04
  • Python version >= 3.5
  • PyTorch version >= 1.0.0

Dataset

The RCV1-V2 dataset we used can be downloaded from Google Drive via this link. The structure of the folders on the drive is:

Google Drive Root		   # The compressed zip file
 |-- data                          # The unprocessed raw data files
 |    |-- train.src        
 |    |-- train.tgt
 |    |-- valid.src
 |    |-- valid.tgt
 |    |-- test.src
 |    |-- test.tgt
 |    |-- topic_sorted.json        # The json file of label set for evaluation
 |-- checkpoints                   # The pre-trained model checkpoints
 |    |-- sgm.pt
 |    |-- sgmge.pt
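
The .src and .tgt files are presumably line-aligned: each line of a .src file holds one tokenized document, and the matching line of the .tgt file holds that document's labels. A quick way to inspect a pair (this alignment is an assumption; verify against the actual files):

# Print the first 10 tokens of the first document and its label set.
with open("data/train.src") as fs, open("data/train.tgt") as ft:
    for doc, labels in zip(fs, ft):
        print(doc.split()[:10], labels.split())
        break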

We found that the validation set in the previous version was so small that the model tended to overfit it, resulting in unstable performance. We have therefore expanded the validation set. In addition, we filtered out samples containing more than 500 words from the original RCV1-V2 dataset.
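
As an illustration of that filtering step (a sketch under assumed file names, not the authors' actual script), dropping parallel src/tgt pairs whose source side exceeds 500 words while keeping the two files line-aligned could look like this:

# Keep only the samples whose source text has at most max_words tokens,
# writing the surviving src/tgt pairs to new, still line-aligned files.
def filter_long_samples(src_in, tgt_in, src_out, tgt_out, max_words=500):
    with open(src_in) as fs, open(tgt_in) as ft, \
         open(src_out, "w") as gs, open(tgt_out, "w") as gt:
        for src, tgt in zip(fs, ft):
            if len(src.split()) <= max_words:
                gs.write(src)
                gt.write(tgt)

filter_long_samples("data/train.src", "data/train.tgt",
                    "data/train.filtered.src", "data/train.filtered.tgt")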


Reproducibility

We provide pretrained checkpoints of the SGM and SGM+GE models on the RCV1-V2 dataset to help you reproduce our reported experimental results. The detailed reproduction steps are as follows:

  • Please download the RCV1-V2 dataset and checkpoints first via the link above, then put them in the same directory as the code. The correct structure of the folders should be:
Root
 |-- data                          
 |    |-- ...        
 |-- checkpoints                   
 |    |-- ...
 |-- models                   
 |    |-- ...
 |-- utils                   
 |    |-- ...
 |-- preprocess.py
 |-- train.py
 |-- ...
  • Preprocess the downloaded data:
python3 preprocess.py -load_data ./data/ -save_data ./data/save_data/ -src_vocab_size 50000

All the preprocessed data will be stored in the folder ./data/save_data/

  • Perform prediction and evaluation:
python3 predict.py -gpus gpu_id -data ./data/save_data/ -batch_size 64 -restore ./checkpoints/sgm.pt -log results/

The predicted labels and evaluation scores will be stored in the folder results


Training from scratch

Preprocessing

You can preprocess the dataset with the following command:

python3 preprocess.py \
	-load_data load_data_path \       # input file dir for the data
	-save_data save_data_path \       # output file dir for the processed data
	-src_vocab_size 50000             # size of the source vocabulary

Note that all data paths must end with /. Other parameter descriptions can be found in preprocess.py
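
For intuition about -src_vocab_size: it caps the source vocabulary at the most frequent tokens. A minimal sketch of such frequency-based pruning (the special tokens and the exact behaviour of preprocess.py are assumptions here):

from collections import Counter

# Keep the `size` most frequent source tokens, plus assumed special symbols.
def build_vocab(src_path, size=50000, specials=("<pad>", "<unk>", "<s>", "</s>")):
    freq = Counter()
    with open(src_path) as f:
        for line in f:
            freq.update(line.split())
    return list(specials) + [tok for tok, _ in freq.most_common(size)]

vocab = build_vocab("data/train.src")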


Training

You can perform model training with the following command:

python3 train.py -gpus gpu_id -config model_config -log save_path

All log files and checkpoints during training will be saved in save_path. The detailed parameter descriptions can be found in train.py


Testing

You can perform testing with the following command:

python3 predict.py -gpus gpu_id -data save_data_path -batch_size batch_size -restore checkpoint_path -log log_path

The predicted labels and evaluation scores will be stored in the folder log_path. The detailed parameter descriptions can be found in predict.py
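
Among the reported scores are hamming loss and micro-F1 (these appear in the evaluation output quoted in the issues below). Here is a hedged sketch of how the two metrics are conventionally computed from binary label-indicator matrices; the repository's own evaluation code may differ in details:

import numpy as np

def hamming_loss(y_true, y_pred):
    # Fraction of label slots that disagree between truth and prediction.
    return np.mean(y_true != y_pred)

def micro_f1(y_true, y_pred):
    # Pool true/false positives and false negatives over all labels, then take F1.
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

y_true = np.array([[1, 0, 1], [0, 1, 0]])
y_pred = np.array([[1, 0, 0], [0, 1, 1]])
print(hamming_loss(y_true, y_pred), micro_f1(y_true, y_pred))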


Citation

If you use the above code for your research, please cite the paper:

@inproceedings{YangCOLING2018,
  author    = {Pengcheng Yang and
               Xu Sun and
               Wei Li and
               Shuming Ma and
               Wei Wu and
               Houfeng Wang},
  title     = {{SGM:} Sequence Generation Model for Multi-label Classification},
  booktitle = {Proceedings of the 27th International Conference on Computational
               Linguistics, {COLING} 2018, Santa Fe, New Mexico, USA, August 20-26,
               2018},
  pages     = {3915--3926},
  year      = {2018}
}

Issues

I can't open Google and can't get the dataset

I can't open Google. I've tried many approaches, such as editing the hosts file and using Google mirror sites, but none of them work. Is there another link from which I can download the dataset? Many thanks. Or you could send me a copy; my QQ is 945900130. Looking forward to your help.

In the decoder's forward pass, why is the RNN input the label embedding and the state, without the context vector c(t) obtained from attention? I don't understand.

I don't quite understand the decoder's forward pass in your code:
def forward(self, inputs, init_state, contexts):
    if not self.config.global_emb:
        # Plain path: feed each gold label embedding to the RNN in turn.
        embs = self.embedding(inputs)
        outputs, state, attns = [], init_state, []
        for emb in embs.split(1):
            output, state = self.rnn(emb.squeeze(0), state)
            output, attn_weights = self.attention(output, contexts)
            output = self.dropout(output)
            outputs += [output]
            attns += [attn_weights]
        outputs = torch.stack(outputs)
        attns = torch.stack(attns)
        return outputs, state
    else:
        # Global-embedding path: after the first step, each input embedding
        # is a gated mix of the gold embedding and a soft average of all
        # label embeddings weighted by the previous step's output scores.
        outputs, state, attns = [], init_state, []
        embs = self.embedding(inputs).split(1)
        max_time_step = len(embs)
        emb = embs[0]
        output, state = self.rnn(emb.squeeze(0), state)
        output, attn_weights = self.attention(output, contexts)
        output = self.dropout(output)
        soft_score = F.softmax(self.linear(output))
        outputs += [output]
        attns += [attn_weights]

        batch_size = soft_score.size(0)
        a, b = self.embedding.weight.size()

        for i in range(max_time_step - 1):
            emb1 = torch.bmm(soft_score.unsqueeze(1),
                             self.embedding.weight.expand((batch_size, a, b)))
            emb2 = embs[i + 1]
            gamma = F.sigmoid(self.gated1(emb1.squeeze()) + self.gated2(emb2.squeeze()))
            emb = gamma * emb1.squeeze() + (1 - gamma) * emb2.squeeze()
            output, state = self.rnn(emb, state)
            output, attn_weights = self.attention(output, contexts)
            output = self.dropout(output)
            soft_score = F.softmax(self.linear(output))
            outputs += [output]
            attns += [attn_weights]
        outputs = torch.stack(outputs)
        attns = torch.stack(attns)
        return outputs, state

Error when running preprocess.py

[dhy@mu01 SGM-master]$ python preprocess.py
Building source vocabulary...
max: 64, min: 1, avg: 5.17
Traceback (most recent call last):
  File "preprocess.py", line 261, in <module>
    main()
  File "preprocess.py", line 234, in main
    opt.src_vocab_size)
  File "preprocess.py", line 118, in initVocabulary
    genWordVocab = makeVocabulary(dataFile, vocabSize, char=char)
  File "preprocess.py", line 99, in makeVocabulary
    vocab = vocab.prune(size)
  File "/home/dhy/multi_label/SGM-master/data/dict.py", line 123, in prune
    newDict.add(self.idxToLabel[i])
KeyError: tensor(46)
What is causing this?

Is config.yaml the configuration file for SGM or for SGM+GE?

Hello, thank you very much for open-sourcing the data and code for this work. 👍
However, while using your code I ran into some small problems: #20 (an error is thrown when running python train.py).

Also, I noticed that the SGM paper mentions the GE component and that it plays a very important role. However, this repository provides only a single config.yaml, and I'm not sure whether that configuration is for the SGM model or for SGM+GE. Am I missing something? Could you provide configuration files for both models?

Thanks again for taking the time to answer my questions. :smile:

What is the format of label_test?

Does it mean that samples can have different label lengths?
If my dataset's labels are:

id1 -2 -2 -2 -2 1 -2 -2 -2 -2 1 -2 -2 -2 -2 -2 -2 1 -2 1 -2
id2 -2 -2 -2 -2 -2 -2 -2 0 -2 1 0 0 0 0 1 -2 -2 -2 1 -2

do I just need to make the training label line:
-2 -2 -2 -2 1 -2 -2 -2 -2 1 -2 -2 -2 -2 -2 -2 1 -2 1 -2
Thanks!

Question about the beam setting

Does the beam size need to be set for multi-label training? When I set it to a value other than 1, the code fails at
candidate += [tgt_vocab.convertToLabels(s.tolist(), utils.EOS) for s in samples]
with an error saying that s is a list object and does not support tolist.

Where is the tgt seq in loss computation?

Hello!
Where in the code are the predicted labels used in the loss computation? I would like to modify the loss-computation function in loss.py, which could help address the label-order problem.
Thanks a lot!

Definition of loss function

Hello, can you explain the meaning of the 'sim_score' parameter in the 'cross_entropy_loss' function of the 'models/loss.py' file? I am confused about this. Thanks.

The loss is not decreasing

Why is it that, after running your code, the first eval() gives hamming_loss: 0.03198517 | micro_f1: 0.2195, and the result stays exactly the same afterwards?

Attention layer

The code for the attention layer is a bit different from the one reported in the paper.
Could you please let me know if it has a reference?

RuntimeError: While copying the parameter named decoder.embedding.weight, whose dimensions in the model are torch.Size([58, 256]) and whose dimensions in the checkpoint are torch.Size([107, 256]).

mldl@ub1604:~/ub16_prj/SGM$ python3 predict.py -gpus 0 -log log_name
loading checkpoint...

loading data...

loading time cost: 3.669
building model...

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 482, in load_state_dict
    own_state[name].copy_(param)
RuntimeError: invalid argument 2: sizes do not match at /pytorch/torch/lib/THC/generic/THCTensorCopy.c:101

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "predict.py", line 91, in <module>
    model.load_state_dict(checkpoints['model'])
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 487, in load_state_dict
    .format(name, own_state[name].size(), param.size()))
RuntimeError: While copying the parameter named decoder.embedding.weight, whose dimensions in the model are torch.Size([58, 256]) and whose dimensions in the checkpoint are torch.Size([107, 256]).
mldl@ub1604:~/ub16_prj/SGM$

Why does sorting matter?

In the paper, you mention: "In addition, the proposed models are trained using the maximum likelihood estimation method and the cross-entropy loss function, which requires humans to predefine the order of the output labels."
This doesn't fully explain it for me. The order matters because we use a seq2seq model, not because of MLE and cross-entropy as such; it is intrinsic to the model that the prediction order should follow the true order, which is why a human-defined order has an impact.
Is the reason that high-frequency labels are more likely to be root labels in a hierarchical category taxonomy, and thus more likely to have correlations with other labels?

About train/test split of RCV1

Hi, thanks for your nice paper and code!
I have noticed that the standard RCV1 train/test split in the original paper is 23,149/781,265, but in the data downloaded from your link, the train file is much bigger than the test file.
Is this correct? Thanks in advance.

A question about evaluation

When computing accuracy, is the prediction order also taken into account?
Suppose the label sports should be predicted at y1, but y1 predicts tennis and sports is only predicted at y2.
Does that still count as correct?

Two small questions about the paper

Hello, I have two small questions about your paper. Since I don't know how to type the formulas, I have put them in an image; I hope you can answer them, thank you.
[attached image containing the two questions]
