
speech-transformer's Introduction

Speech Transformer: End-to-End ASR with Transformer

A PyTorch implementation of Speech Transformer [1], an end-to-end automatic speech recognition model based on the Transformer network, which directly converts acoustic features to a character sequence using a single neural network.

Install

  • Python3 (recommend Anaconda)
  • PyTorch 0.4.1+
  • Kaldi (just for feature extraction)
  • pip install -r requirements.txt
  • cd tools; make KALDI=/path/to/kaldi
  • If you want to run egs/aishell/run.sh, download the aishell dataset (available for free).

Usage

Quick start

$ cd egs/aishell
# Modify the aishell data path at the beginning of run.sh
$ bash run.sh

That's all!

You can change a parameter with $ bash run.sh --parameter_name parameter_value, e.g., $ bash run.sh --stage 3. See the parameter names in egs/aishell/run.sh before the line . utils/parse_options.sh.

Workflow

Workflow of egs/aishell/run.sh:

  • Stage 0: Data Preparation
  • Stage 1: Feature Generation
  • Stage 2: Dictionary and Json Data Preparation
  • Stage 3: Network Training
  • Stage 4: Decoding

More detail

egs/aishell/run.sh provides example usage.

# Set PATH and PYTHONPATH
$ cd egs/aishell/; . ./path.sh
# Train
$ train.py -h
# Decode
$ recognize.py -h

How to visualize loss?

If you want to visualize your loss, you can use visdom to do that:

  1. Open a new terminal on your remote server (tmux is recommended) and run $ visdom.
  2. Open a new terminal and run $ bash run.sh --visdom 1 --visdom_id "<any-string>" or $ train.py ... --visdom 1 --visdom_id "<any-string>".
  3. Open your browser and go to <your-remote-server-ip>:8097, e.g., 127.0.0.1:8097.
  4. On the visdom page, choose <any-string> under Environment to see your loss.

How to resume training?

$ bash run.sh --continue_from <model-path>

How to solve out of memory?

If this happens during training, try reducing batch_size: $ bash run.sh --batch_size <lower-value>.

Results

Model                      CER     Config
LSTMP                      9.85    4x(1024-512). See kaldi-ktnet1
Listen, Attend and Spell   13.2    See Listen-Attend-Spell's egs/aishell/run.sh
SpeechTransformer          12.8    See egs/aishell/run.sh

Reference

  • [1] Yuanyuan Zhao, Jie Li, Xiaorui Wang, and Yan Li. "The SpeechTransformer for Large-scale Mandarin Chinese Speech Recognition." ICASSP 2019.

speech-transformer's People

Contributors

begeekmyfriend, kaituoxu


speech-transformer's Issues

train.py

Hi, Xu. Thanks for your contribution of this excellent work.
However, in stage -3 of run.sh, I didn't find where train.py's code is.
Can you tell me?
Thank you !

Speech-Transformer keeps using Kaldi's Python 2.x

While Speech-Transformer is obviously coded in Python 3, Kaldi's scripts keep invoking Python 2.x, so strange errors occur, e.g.:
def init(self, *args, LFR_m=1, LFR_n=1, **kwargs):
SyntaxError: invalid syntax

How to solve it?

Thanks in advance.

decode question

Hi, I ran into some problems when the network reached stage 4 (decoding):

run.pl: job failed , log is in exp/............/decoder_test_beam5_nbest1_ml100/decode.log

the log is :
UnicodeEncodeError: 'ascii' codec can't encode character '\u4e00' in position 29: ordinal not in range(128)

No such file or directory: exp/................/decode_test_beam5_nbest1_ml100/data.json

I am using Python 3.6.8, and the default encoding prints as utf-8.

How can I solve this?

Thank you

About reference

Hi, could you please share the reference paper with me? Looking forward to your reply, thanks!

Training time

Hi kaituo,

Thank you very much for your nice code. I was wondering how long it takes to train?

Decoding error

Traceback (most recent call last):
    File "../../src/bin/recognize.py", line 69, in <module>
        recognize(args)
    File "../../src/bin/recognize.py", line 58, in recognize
        nbest_hyps = model.recognize(input, input_length, char_list, args)
    File "/home/tt/Speech-Transformer/src/transformer/transformer.py", line 46, in recognize
        args)
    File "/home/tt/Speech-Transformer/src/transformer/decode.py", line 222, in recognize_beam
        for x in hyp['yseq'][0, 1:]]))
UnicodeEncodeError: 'ascii' codec can't encode character '\u751a' in position 6: ordinal not in range(128)

Any solution?

question about "build_LFR_features"

I found that every feature Kaldi produces is passed through the function build_LFR_features. Does it have any special purpose? What is meant by "stacking frames and skipping frames"?
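For readers asking the same thing: "low frame rate" (LFR) processing stacks LFR_m consecutive frames into one super-frame and then keeps every LFR_n-th super-frame, shortening the sequence the encoder must attend over while widening each frame's feature dimension. A minimal sketch of that idea (not the repo's exact code; padding the tail by repeating the last frame is an assumption):

```python
import numpy as np

def build_lfr_features_sketch(inputs, m=4, n=3):
    """Stack m consecutive frames, then keep every n-th stacked frame."""
    T, d = inputs.shape
    out = []
    for i in range(0, T, n):            # skip: advance n frames per output
        frames = inputs[i:i + m]        # stack: take m consecutive frames
        if len(frames) < m:             # pad the tail by repeating the last frame
            frames = np.vstack([frames] + [inputs[-1:]] * (m - len(frames)))
        out.append(frames.reshape(-1))  # one super-frame of dimension m * d
    return np.stack(out)

feats = np.arange(10 * 2, dtype=np.float32).reshape(10, 2)  # 10 frames, dim 2
lfr = build_lfr_features_sketch(feats, m=4, n=3)
print(lfr.shape)  # (4, 8): frame rate reduced 3x, feature dim grown 4x
```

With m=4 and n=3 a 10-frame utterance becomes 4 super-frames, which is why the encoder input is both shorter and wider after this step.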

question about the non_pad_mask

Hi kaituo, thanks for sharing such a useful speech transformer.
I have run your code and reproduced the reported results, but I have a question about non_pad_mask. I notice that when calculating the loss, we already mask the padded part. Since the loss calculation is restricted, why do we also need to apply the mask step by step during forwarding?

  • forward
    def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):

        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        dec_output *= non_pad_mask

        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        dec_output *= non_pad_mask

        dec_output = self.pos_ffn(dec_output)
        dec_output *= non_pad_mask
  • loss calculating
        loss = loss.masked_select(non_pad_mask).sum() / n_word
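One common rationale (my reading, not an authoritative answer from the authors): the loss mask alone already gives correct gradients at padded target positions, but zeroing after every sub-layer keeps garbage activations at padded steps from leaking into subsequent computations such as the feed-forward layer and the attention values. A tiny NumPy illustration of what dec_output *= non_pad_mask does:

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, D = 1, 4, 3                      # batch, time, model dim; last step is padding
dec_output = rng.standard_normal((B, T, D))
non_pad_mask = np.array([[[1.0], [1.0], [1.0], [0.0]]])  # 0 marks the padded step

# What `dec_output *= non_pad_mask` does after each decoder sub-layer:
# broadcasting zeroes every feature at the padded time step, leaving real
# steps untouched.
masked = dec_output * non_pad_mask
print(masked[0, -1])   # [0. 0. 0.] -- padded activations are zeroed
```

So the per-layer masking is largely numerical hygiene rather than a requirement for the loss itself.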

Training on small data set

Hi Kaituo,
I'm doing a very naive test. I selected one speaker from aishell (spkr = 2, #snt = 364, #wrd = 4,989) and used this speaker for training, dev, and test. I expected to get decent results because I'm overfitting the model. However, here are the results I got (#epoch = 50):

SPKR      # Snt   # Wrd   Corr   Sub    Del   Ins   Err     S.Err
Sum/Avg   364     4989    0.2    95.9   3.9   3.9   103.7   100.0

I understand an NN needs much more data to work, but I wonder why the results are so poor for this toy example when we use the same data for train, dev, and test. Shouldn't we at least get better results?

academic exchange

Hello, Mr. Xu:
I am a postgraduate student in speech recognition at Jiangnan University. I have long heard of your name. Could you add me on WeChat to facilitate future academic exchanges (my WeChat ID is 'zw76859420')?
By Wei Zhang

problem in decoding

(1/7176) decoding BAC009S0764W0121
remeined hypothes: 5
Traceback (most recent call last):
File "/data/gjj/Speech-Transformer/egs/aishell/../../src/bin/recognize.py", line 70, in
recognize(args)
File "/data/gjj/Speech-Transformer/egs/aishell/../../src/bin/recognize.py", line 59, in recognize
nbest_hyps = model.recognize(input, input_length, char_list, args)
File "/data/gjj/Speech-Transformer/src/transformer/transformer.py", line 46, in recognize
args)
File "/data/gjj/Speech-Transformer/src/transformer/decoder.py", line 222, in recognize_beam
for x in hyp['yseq'][0, 1:]]))
UnicodeEncodeError: 'ascii' codec can't encode character '\u751a' in position 6: ordinal not in range

Hello, I hit this problem after training, and I don't know what this error means or how to solve it.

Encoding problem during decoding

During decoding (stage 4), I ran into the following problem: UnicodeEncodeError: 'ascii' codec can't encode character '\u751a' in position 6: ordinal not in range(128). I traced it to the line print('hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][0, 1:]])) in decoder.py.
Suggestions online say to set utf-8 encoding, but Python 3.6 already defaults to it, so I am stuck on this problem.
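Several issues above hit this same UnicodeEncodeError when decoder.py prints Chinese hypotheses under an ASCII locale. A common workaround (not from the repo itself) is to force UTF-8 on Python's standard streams before decoding:

```shell
# Force Python's stdout/stderr to UTF-8 regardless of the system locale.
export PYTHONIOENCODING=utf-8
# Sanity check: printing a CJK character should now succeed.
python3 -c "print('\u751a')"
```

The default string encoding is indeed utf-8 in Python 3, but the encoding of stdout is taken from the locale, which is why the print of Chinese characters can still fail; PYTHONIOENCODING overrides that.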

missing file

Line 64 of run.sh is . utils/parse_options.sh || exit 1;
but there is no such folder or file. How can I fix this?

Request for Pre-trained model

Hi, Your work is excellent.
Since the Transformer is very difficult to train, would you be willing to provide a pre-trained model? Thank you very much!

CUDA out of memory

Hi,
sorry to bother you,
I want to ask: even after I changed the batch size to 2,
it still reports out of CUDA memory during training.
Is that possible? My input dimension is only 23.
Please help me if you have time.
Thanks very much

How to set continue_from

Hello, I want to continue training from the model final.pth.tar saved after running 2 epochs. How should continue_from be set?

Error when running bash run.sh

When I run bash run.sh, I get an error: line 65 of run.sh refers to a path that was not provided. How should I solve this?

Request for Pre-trained model

Excuse me! Is there a pre-trained model to share? I couldn't get good results.
Also, how long a sentence can the model recognize? Is there a length restriction on the model?
Sorry to bother you! Hoping for a reply!

ImportError: No module named torch

Hello, thank you for open-sourcing this. When I run training, stages 0, 1, and 2 all work. Why does stage 3 seem to use Python 2? After I changed the system's default python symlink to Python 3.7, it reports:
Unknown option: --
usage: python3 [option] ... [-c cmd | -m mod | file | -] [arg] ...
What is your training environment like?

Unable to run

Hello, I tried to run your code and found it doesn't run. Could you give me some help? The error is:

run.sh: line 64: utils/parse_options.sh: No such file or directory

speech transformer

Hello, is this an implementation of reference [1]? I didn't see the conv operations in it.

A question about data.py

Hello, in data.py, the return value includes an ilens entry. It should hold the lengths of xs before padding, right? But what I get are lengths that are already all identical.

I tested with batch_size=4:

get batch of lengths of input sequences

ilens = np.array([x.shape[0] for x in xs])
ilens = torch.from_numpy(ilens) # [235, 235, 235, 235] [250, 250, 250, 250]...

It seems that earlier, in
batch = load_inputs_and_targets(batch[0], LFR_m=LFR_m, LFR_n=LFR_n)
xs, ys = batch
xs has already been padded?

This ilens problem causes the later enc_dec_attn_mask to be all false...
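The issue above is an ordering problem: the true lengths must be recorded from the raw variable-length sequences first, and only then should the batch be padded. A hedged sketch with invented shapes (40-dim features, three utterances):

```python
import numpy as np

# Hypothetical variable-length feature sequences (frames x dim), before padding.
xs = [np.ones((235, 40)), np.ones((180, 40)), np.ones((90, 40))]

# 1) Record the true lengths BEFORE padding...
ilens = np.array([x.shape[0] for x in xs])           # [235 180 90]

# 2) ...then zero-pad every sequence to the max length and stack the batch.
T_max = ilens.max()
xs_pad = np.stack(
    [np.pad(x, ((0, T_max - x.shape[0]), (0, 0))) for x in xs])

print(ilens)          # [235 180  90]
print(xs_pad.shape)   # (3, 235, 40)
```

With the lengths captured first, a correct attention mask can later be derived from ilens instead of from the already-uniform padded shapes.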

run time long in stage 1

Hi, I found that stage 1 takes a long time. Is there anything wrong with my operation? How long should this process take?

error: unrecognized arguments: --continue_from

When I trained the model for the first time, train.log reported an error. At that point I had no trained model to load for resuming, and continue_from="" was left at its default empty value.
error: unrecognized arguments: --continue_from

model train on thchs30 is bad

I find that the loss on the thchs dataset is smaller than on the aishell dataset, yet the predictions on the thchs test set are bad (it almost can't recognize anything). What is the reason? Can you explain? Thanks.

question about loss curve

Hi kaituo, I'm trying to train this network on LibriSpeech. The loss curve of epoch 1 shows that the model tends to saturate after the first few steps (there are approximately 4k iterations per epoch, and my loss dropped from 4 to 3 after 100 iterations and then stayed the same).
I have not made any changes to the model. The only change I made is to use my own dataloader (for loading the LibriSpeech corpus). So I wonder whether you saw the same loss-decline trend when training on the aishell corpus?
