Whisper is an advanced encoder-decoder model for speech transcription and translation, processing audio through encoding and decoding stages. Given its large size and slow inference speed, various optimization strategies like Faster-Whisper and Speculative Decoding have been proposed to enhance performance. Our Medusa model builds on Whisper by predicting multiple tokens per iteration, which significantly improves speed with small degradation in WER. We train and evaluate our model on the LibriSpeech dataset, demonstrating strong performance with both speed and accuracy improvements.
Time comparison between Whisper Medusa and Whisper vanilla on librispeech dataset. Whisper Medusa generates the sequence X2 faster.
Start with creating a virtual environment and activating it:
conda create -n whisper-medusa python=3.11 -y
conda activate whisper-medusa
pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
Then install the package:
git clone https://github.com/aiola-lab/whisper-medusa.git
cd whisper-medusa
pip install -e .
Inference can be done using the following code:
import torch
import torchaudio
from whisper_medusa import WhisperMedusaModel
from transformers import WhisperProcessor
model_name = "aiola/whisper-medusa-v1"
model = WhisperMedusaModel.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)
path_to_audio = "path/to/audio.wav"
SAMPLING_RATE = 16000
language = "en"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_speech, sr = torchaudio.load(path_to_audio)
if sr != SAMPLING_RATE:
input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)
input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_features
input_features = input_features.to(device)
model = model.to(device)
model_output = model.generate(
input_features,
language=language,
)
predict_ids = model_output[0]
pred = processor.decode(predict_ids, skip_special_tokens=True)
print(pred)
To evaluate the model we assume a csv file with the following columns:
audio
: path to the audio file.sentence
: the corresponding transcript.language
: the language of the audio file.
Then run the following command:
python whisper_medusa/eval_whisper_medusa.py \
--model-name /path/to/model \
--data-path /path/to/data \
--out-file-path /path/to/output \
--language en
arguments description:
model-name
: path to local model / huggingface hub.data-path
: path to the data.out-file-path
: path to the output file.language
: default language fallback.
whisper-medusa
is based on Medusa fast decoding.
@article{cai2024medusa,
title={Medusa: Simple llm inference acceleration framework with multiple decoding heads},
author={Cai, Tianle and Li, Yuhong and Geng, Zhengyang and Peng, Hongwu and Lee, Jason D and Chen, Deming and Dao, Tri},
journal={arXiv preprint arXiv:2401.10774},
year={2024}
}