
SummaryMixing for SpeechBrain v1.0

Halve your VRAM requirements and train any speech model 30% faster while achieving equivalent or better Word Error Rates and SLU accuracies with SummaryMixing Conformers and Branchformers.

!! A word about using SummaryMixing with SpeechBrain V1.0 !!

The main branch of this repository will keep tracking the latest version of SpeechBrain. Unfortunately, the results reported in our publication and in the table below were obtained with SpeechBrain v0.5 and may not be exactly reproducible with the current code. If you want exactly the same results, please use our dedicated branch, which contains the code compatible with SpeechBrain v0.5!

In brief

This repository implements SummaryMixing, a simpler, faster and much cheaper replacement for self-attention in Conformers and Branchformers for automatic speech recognition, keyword spotting and intent classification (see the publication for further details). The code is fully compatible with version 0.5 of the SpeechBrain toolkit -- copy and paste is all you need to start using SummaryMixing in your setup, as sketched below. If you wish to run with the latest version of SpeechBrain (v1.0+), please go to the main branch of this repository. SummaryMixing is the first alternative to MHSA able to beat it on speech tasks while reducing its complexity significantly (from quadratic to linear).
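For a rough idea of what plugging it in looks like, here is a hedged Python sketch. It assumes the files from this repository have been copied over a SpeechBrain installation and that the TransformerASR interface accepts an attention_type option selecting SummaryMixing; the exact class and argument names should be checked against the YAML recipes shipped in this repository.

# Hypothetical usage sketch -- argument names and values are assumptions;
# the YAML hyperparameter files in this repository are the reference.
from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR

model = TransformerASR(
    tgt_vocab=5000,                  # output vocabulary size (example value)
    input_size=80,                   # e.g. 80-dim filterbank features
    d_model=512,
    nhead=4,
    num_encoder_layers=18,
    num_decoder_layers=6,
    d_ffn=3072,
    encoder_module="branchformer",   # or "conformer"
    attention_type="SummaryMixing",  # instead of the "regularMHA" self-attention
)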

A glance at SummaryMixing

SummaryMixing is a linear-time alternative to self-attention (SA) for speech processing models such as Transformers, Conformers or Branchformers. Instead of computing pair-wise scores between tokens (which gives SA its quadratic-time complexity), it summarises a whole utterance by taking the mean over transformed vectors for all time steps. SummaryMixing builds on recent findings showing that self-attention may be unnecessary for speech recognition, since the attention weights of trained ASR systems are almost uniformly distributed across the tokens of a sequence. SummaryMixing is also a generalisation of the recent HyperMixer and HyperConformer to better and simpler mixing functions. In a SummaryMixing cell, which takes the same inputs and produces the same outputs as self-attention, contributions from each time step are first transformed and then averaged globally before being fed back to each time step (see Figure 1 in the article). The time complexity is therefore reduced to linear.
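To make the mechanism concrete, here is a minimal, self-contained PyTorch sketch of the idea. It is illustrative only and not the repository's implementation; layer sizes, activations and the way the local and summary branches are combined are simplified assumptions, and masking of padded frames is omitted.

# Minimal SummaryMixing-style cell: per-time-step transform + global mean,
# fed back to every time step. Illustrative sketch, not the official module.
import torch
import torch.nn as nn

class SummaryMixingSketch(nn.Module):
    def __init__(self, dim: int, hidden: int = 512):
        super().__init__()
        self.local = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())    # per-step branch
        self.summary = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())  # summary branch
        self.combine = nn.Sequential(nn.Linear(2 * hidden, dim), nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, dim) -- same interface as a self-attention block
        local = self.local(x)                                 # (B, T, H)
        summary = self.summary(x).mean(dim=1, keepdim=True)   # (B, 1, H), linear in T
        summary = summary.expand_as(local)                    # broadcast back to every step
        # (padding/masking for variable-length utterances is omitted for simplicity)
        return self.combine(torch.cat([local, summary], dim=-1))  # (B, T, dim)

x = torch.randn(4, 100, 256)
print(SummaryMixingSketch(256)(x).shape)  # torch.Size([4, 100, 256])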

A few results

A SummaryMixing-equipped Conformer outperforms an equivalent self-attention model on LibriSpeech test-clean (2.1% vs 2.3% WER) and test-other (5.1% vs 5.4% WER). This is achieved with a 30% reduction in training time and less than half the memory budget (from 46 GB to 21 GB of VRAM). Similar gains are observed on CommonVoice, AISHELL-1 and Tedlium2. The benefit also holds at decoding time: the real-time factor (RTF) of a SummaryMixing Branchformer remains stable as sentence length grows, whereas the same model with self-attention sees its RTF increase quadratically. The SpeechBrain configuration files in this repository reproduce these numbers.

The following table gives an idea of the results observed on LibriSpeech. More results on CommonVoice, AISHELL, Tedlium, SLURP, and Google Speech Commands are available in the article.

Encoder        Variant             Dev-clean   Test-clean   Test-other   GPU      VRAM
                                   (WER %)     (WER %)      (WER %)      (hours)  (GB)
ContextNet     N.A.                3.3         2.3          5.9          160      25
Transformer    Self-attention      3.3         2.3          5.5          129      40
Conformer      Self-attention      2.8         2.3          5.4          137      46
Branchformer   Self-attention      2.9         2.2          5.1          132      45
Branchformer   CNN Only            3.1         2.4          5.7          83       22
Branchformer   HyperMixer          3.1         2.3          5.6          126      30
Branchformer   FastFormer          3.0         2.2          5.4          96       23
Proposed:
Conformer      SummaryMixing       2.8         2.1          5.1          98       21
Branchformer   SummaryMixing-lite  3.0         2.2          5.2          98       23
Branchformer   SummaryMixing       2.9         2.2          5.1          105      26
Branchformer   +Summary Decoder    3.1         2.3          5.3          104      26

RTF performance

(Figure: real-time factor as a function of utterance length for SummaryMixing versus self-attention Branchformers; SummaryMixing remains flat while self-attention grows quadratically. See the article for the plot.)

Citation

Please cite SummaryMixing as follows:

@inproceedings{parcollet24_interspeech,
  title     = {SummaryMixing: A Linear-Complexity Alternative to Self-Attention for Speech Recognition and Understanding},
  author    = {Titouan Parcollet and Rogier {van Dalen} and Shucong Zhang and Sourav Bhattacharya},
  year      = {2024},
  booktitle = {Interspeech 2024},
  pages     = {3460--3464},
  doi       = {10.21437/Interspeech.2024-40},
  issn      = {2958-1796},
}

Licence

This code is distributed under the CC-BY-NC 4.0 Licence. See the Licence file for further details.


Issues

The grad norm is nan

Hi, I'm getting the following warnings when training a Branchformer with summary_mixing:

[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:12,899 (ctc:67) WARNING: 13/34 samples got nan grad. These were ignored for CTC loss.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:13,133 (build_trainer:660) WARNING: The grad norm is nan. Skipping updating the model.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:13,263 (ctc:67) WARNING: 7/32 samples got nan grad. These were ignored for CTC loss.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:13,477 (build_trainer:660) WARNING: The grad norm is nan. Skipping updating the model.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:13,625 (ctc:67) WARNING: 21/45 samples got nan grad. These were ignored for CTC loss.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:13,858 (build_trainer:660) WARNING: The grad norm is nan. Skipping updating the model.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:14,022 (ctc:67) WARNING: 21/62 samples got nan grad. These were ignored for CTC loss.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:14,248 (build_trainer:660) WARNING: The grad norm is nan. Skipping updating the model.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:14,499 (ctc:67) WARNING: 37/105 samples got nan grad. These were ignored for CTC loss.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:14,735 (build_trainer:660) WARNING: The grad norm is nan. Skipping updating the model.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:14,875 (ctc:67) WARNING: 12/39 samples got nan grad. These were ignored for CTC loss.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:15,104 (build_trainer:660) WARNING: The grad norm is nan. Skipping updating the model.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:15,261 (ctc:67) WARNING: 23/56 samples got nan grad. These were ignored for CTC loss.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:15,479 (build_trainer:660) WARNING: The grad norm is nan. Skipping updating the model.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:15,623 (ctc:67) WARNING: 20/47 samples got nan grad. These were ignored for CTC loss.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:15,854 (build_trainer:660) WARNING: The grad norm is nan. Skipping updating the model.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:16,004 (ctc:67) WARNING: 15/53 samples got nan grad. These were ignored for CTC loss.
[autodl-container-4d6411b93c-8a044365] 2024-04-10 17:11:16,224 (build_trainer:660) WARNING: The grad norm is nan. Skipping updating the model.

Why is this happening?

Valid step generates a RuntimeError

Dear Team,

I want to compare the ASR results we have obtained with wav2vec2 and Whisper architectures against your SummaryMixing one.

We are performing a custom ASR training; our dataset is composed of 95,000 records for training, 16,000 for validation, and 17,000 for testing.

Training was successfully performed with the following parameters (A100 40 GB GPU):

  • precision=bf16
  • batch_size=12
  • number_of_epochs=60
  • token_type=bpe

However, at the epoch 1 validation step, we got the following error:

speechbrain.utils.epoch_loop - Going into epoch 1
  0%|          | 0/8660 [00:00<?, ?it/s]/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:5109: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.
  warnings.warn(
 44%|████▍     | 3803/8660 [15:03<18:27,  4.39it/s, train_loss=93.3]speechbrain.utils.checkpoints - Saving an empty state_dict for <torch.cuda.amp.grad_scaler.GradScaler object at 0x7fe0401657c0> in /data/outputs/save/CKPT+2024-05-31+14-19-54+00/scaler.ckpt.
 88%|████████▊ | 7586/8660 [30:12<04:18,  4.15it/s, train_loss=80.9]speechbrain.utils.checkpoints - Saving an empty state_dict for <torch.cuda.amp.grad_scaler.GradScaler object at 0x7fe0401657c0> in /data/outputs/save/CKPT+2024-05-31+14-35-03+00/scaler.ckpt.
100%|██████████| 8660/8660 [34:15<00:00,  4.21it/s, train_loss=78.7]
  0%|          | 0/1291 [00:00<?, ?it/s]
speechbrain.core - Exception:
Traceback (most recent call last):
  File "train.py", line 442, in <module>
    asr_brain.fit(
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/core.py", line 1556, in fit
    self._fit_valid(valid_set=valid_set, epoch=epoch, enable=enable)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/core.py", line 1462, in _fit_valid
    loss = self.evaluate_batch(batch, stage=Stage.VALID)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/core.py", line 1345, in evaluate_batch
    out = self.compute_forward(batch, stage=stage)
  File "train.py", line 68, in compute_forward
    enc_out, pred = self.modules.Transformer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/data_parallel.py", line 183, in forward
    return self.module(*inputs[0], **module_kwargs[0])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/transformer/TransformerASR.py", line 381, in forward
    encoder_out, _ = self.encoder(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/transformer/Branchformer.py", line 480, in forward
    output, attention = enc_layer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/transformer/Branchformer.py", line 277, in forward
    x2 = self._forward_cnn_branch(x2)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/transformer/Branchformer.py", line 293, in _forward_cnn_branch
    x = self.convolution_branch(x)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/transformer/Branchformer.py", line 94, in forward
    x = self.csgu(x)  # (B, T, D//2)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/lobes/models/convolution.py", line 99, in forward
    x2 = self.conv(x2)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/nnet/CNN.py", line 428, in forward
    x = self._manage_padding(
  File "/usr/local/lib/python3.8/dist-packages/speechbrain/nnet/CNN.py", line 480, in _manage_padding
    x = F.pad(x, padding, mode=self.padding_mode)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 4495, in pad
    return torch._C._nn.pad(input, pad, mode, value)
RuntimeError: Argument #4: Padding size should be less than the corresponding input dimension, but got: padding (15, 15) at dimension 2 of input [12, 1536, 13]

What's wrong?

Thanks for your support.

Issue Encoder-only SummaryMixing

Hi,

Thanks for this repo. I have been playing a bit with SummaryMixing and tried to plug it into a CTC-only recipe (https://github.com/speechbrain/speechbrain/blob/develop/recipes/LibriSpeech/ASR/CTC/train.py) implemented by Shucong. With this recipe, one must make sure that num_decoder_layers is equal to 0. Doing so creates an issue with SummaryMixing because of this line: https://github.com/SamsungLabs/SummaryMixing/blob/main/speechbrain/lobes/models/transformer/TransformerASR.py#L388 . Indeed, the forward call tgt = self.custom_tgt_module(tgt) tries to use custom_tgt_module, which is defined here: https://github.com/SamsungLabs/SummaryMixing/blob/main/speechbrain/lobes/models/transformer/TransformerASR.py#L334-L335 , but as you can see this attribute is only created when num_decoder_layers > 0. One fix is to simply return from this function right after the encoder, as we did in SpeechBrain here: https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/lobes/models/transformer/TransformerASR.py#L350-L352 (a sketch of this guard is given below).
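For illustration, here is a sketch of the early-return guard described above. It is an excerpt rather than a runnable example, and the attribute and argument names are assumptions based on the linked files.

# Sketch of the suggested fix inside TransformerASR.forward()
# (names assumed from the files linked above):
encoder_out, _ = self.encoder(
    src=src, src_key_padding_mask=src_key_padding_mask
)
if self.num_decoder_layers == 0:
    # Encoder-only (e.g. CTC-only) model: custom_tgt_module was never
    # created, so return right after the encoder instead of building
    # decoder inputs.
    return encoder_out, None
tgt = self.custom_tgt_module(tgt)  # only reached when a decoder exists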

Also, I found it a bit hard to use SummaryMixing because of this line: https://github.com/SamsungLabs/SummaryMixing/blob/main/speechbrain/lobes/models/transformer/Branchformer.py#L222 . Indeed, you define the VanillaDNN using local_proj_out_dim + summary_out_dim, but in my experiment with the CTC-only recipe it is "impossible" to define the input dimension externally because there seems to be some extra downsampling somewhere. I had to do some hardcoding to make it work. It may only be related to the CTC-only recipe, but I wanted to let you know in case you have already experienced this issue.

Thanks!

Adel
