Whisper-Medusa: Using Multiple Decoding Heads to Achieve a 1.5X Speedup

TL;DR

OpenAI’s Whisper is an advanced encoder-decoder model for speech transcription and translation, processing audio through encoding and decoding stages. Because the model is large and its inference is slow, various optimization strategies such as Faster-Whisper and Speculative Decoding have been proposed to improve performance. Our Whisper-Medusa model builds on top of Whisper by predicting multiple tokens per iteration, which significantly improves speed with only a small degradation in Word Error Rate (WER). We train and evaluate the model on the LibriSpeech dataset, demonstrating a substantial speedup while keeping accuracy close to the original model.

Whisper-Medusa

Whisper is a cutting-edge encoder-decoder model designed specifically for speech transcription and translation tasks. But what exactly does that mean? Let’s break it down.

How Does Whisper Work?

At its core, Whisper processes audio input by breaking it down into two main parts: encoding and decoding. Here’s how it works:

Encoding the Speech Input:

When an audio clip is fed into Whisper, it is first converted into a log-Mel spectrogram. The encoder processes this representation through convolutional layers followed by a stack of self-attention layers, which capture temporal dependencies and phonetic features and map the audio signal from the time domain into a latent-space representation: a sequence of high-dimensional embeddings. These encoded features preserve the acoustic information needed for subsequent processing.
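To make the encoding step concrete, here is a minimal sketch using the open-source openai-whisper package; the checkpoint name and audio file are placeholders, and this is illustrative rather than the exact pipeline used in Whisper-Medusa.

```python
import whisper  # the open-source openai-whisper package

model = whisper.load_model("large-v2")

# Load the audio, pad/trim it to 30 seconds, and compute the log-Mel spectrogram
# that the encoder actually consumes (the file name is a placeholder).
audio = whisper.load_audio("example.wav")
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# The encoder maps the spectrogram to a sequence of high-dimensional embeddings.
audio_features = model.embed_audio(mel.unsqueeze(0))
print(audio_features.shape)  # (1, n_audio_frames, d_model)
```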

Decoding the Encoded Representations:

After the encoding phase, the high-dimensional embeddings are passed to the decoder. The decoder works autoregressively: it predicts output tokens one at a time, with each token conditioned on the previously generated context and, via cross-attention, on the relevant parts of the encoded audio representation. At every step, the transformer decoder carries forward the context and dependencies from earlier tokens, and this iterative token prediction is what produces coherent, contextually accurate transcriptions or translations.

Figure 1: Whisper’s architecture utilizes a transformer-based model that encodes audio features with a multi-layer encoder, processes these representations through a series of self-attention layers, and then decodes them into text tokens using a multi-layer decoder.
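For illustration, the token-by-token decoding loop described above can be sketched with the openai-whisper package roughly as follows. Greedy decoding is shown for simplicity (the real decoder also supports beam search, timestamps, and temperature fallback), and the checkpoint and file names are placeholders.

```python
import torch
import whisper
from whisper.tokenizer import get_tokenizer

model = whisper.load_model("large-v2")
tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")

with torch.no_grad():
    # Encode the audio exactly as in the encoding step above.
    audio = whisper.pad_or_trim(whisper.load_audio("example.wav"))
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    audio_features = model.embed_audio(mel.unsqueeze(0))

    # Start from the standard start-of-transcript prompt and decode one token per step.
    tokens = torch.tensor([list(tokenizer.sot_sequence)], device=model.device)
    for _ in range(224):                                          # cap on decoding iterations
        logits = model.logits(tokens, audio_features)             # (1, seq_len, vocab)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy choice
        tokens = torch.cat([tokens, next_token], dim=-1)
        if next_token.item() == tokenizer.eot:                    # stop at end-of-transcript
            break

print(tokenizer.decode(tokens[0, len(tokenizer.sot_sequence):].tolist()))
```

Note that each new token requires another forward pass of the decoder, which is exactly the bottleneck Whisper-Medusa targets.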

Speeding Up the Whisper Model

The most accurate Whisper model is very large, with approximately 1.5 billion parameters, and its substantial size poses a significant challenge for inference speed. Various approaches focus on improving Whisper’s speed, such as model optimization, efficient use of hardware, algorithmic improvements, and advanced pre- and post-processing techniques. One of the most effective solutions is Faster-Whisper [1], which reimplements OpenAI’s Whisper model using CTranslate2, a high-performance inference engine specifically optimized for Transformer models. Another approach is to compress the Whisper model using guided knowledge distillation and quantization [2].
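As a usage illustration, Faster-Whisper can be run roughly as follows, based on the project’s README; the checkpoint, device, and file names are placeholders.

```python
from faster_whisper import WhisperModel

# Whisper large-v2 re-implemented on top of CTranslate2; float16 keeps
# GPU memory use and latency low (model size, device, and file are placeholders).
model = WhisperModel("large-v2", device="cuda", compute_type="float16")

# Transcription is lazy: segments are yielded from a generator as they are decoded.
segments, info = model.transcribe("example.wav", beam_size=5)
for segment in segments:
    print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
```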

Speeding Up Whisper with Speculative Decoding: Whisper-Medusa

To tackle Whisper’s speed limitation, namely that it predicts just one token per iteration, we approached the problem from a different angle, using “Speculative Decoding”.

Speculative Decoding is a technique that uses a smaller, faster “assistant model” to propose several candidate outputs. The most promising output is then selected from these candidates, improving the overall efficiency of the decoding process.

Specifically, the assistant model generates multiple candidate sequences. Each candidate is then evaluated and scored according to an estimate of its likelihood or quality; the scores are used to filter the candidates, and the most probable or highest-quality sequence is chosen as the output.
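As a concrete example of this idea, Hugging Face transformers exposes assisted generation through the assistant_model argument of generate; the sketch below pairs Whisper large with a distilled Whisper as the draft model. The model names and loading code are illustrative, and this is a separate technique from the Whisper-Medusa approach described next.

```python
import librosa
import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

# Target model (large, accurate) and a smaller draft/assistant model.
processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2").to(device)
assistant = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2").to(device)

# Load a 16 kHz waveform (the file name is a placeholder).
audio, _ = librosa.load("example.wav", sr=16000)
inputs = processor(audio, sampling_rate=16000, return_tensors="pt").to(device)

# Passing `assistant_model` turns on assisted (speculative) generation:
# the draft model proposes tokens, and the main model verifies them in a single pass.
generated = model.generate(inputs.input_features, assistant_model=assistant)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])
```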

Our version of Speculative Decoding is inspired by Medusa [3]. Instead of adhering to the traditional single-token approach or relying on additional “assistant models”, we modify the model architecture itself to predict multiple tokens per step. Specifically, the model forecasts the next K + 1 tokens at every iteration, where K is the number of extra tokens we aim to predict. Since Whisper’s decoder has access to the entire encoded audio at every step, this is a natural and efficient way to speed things up.

Our architecture is an extension of the Whisper architecture. Specifically, the outputs from the final decoder layer are fed into K “Medusa heads.” Each Medusa head consists of a single linear layer with a residual connection.

As can be seen in Figure 2, the vocabulary projection layer is shared across all heads, minimizing the increase in parameter count. Additionally, during training we only update the final decoder layer and the Medusa heads, which simplifies the training process. A rough sketch of these heads follows the figure below.


Figure 2: Whisper’s architecture with Medusa heads.
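The following is a rough PyTorch sketch of the Medusa heads described above. It is not the released implementation: the head width, the choice of activation inside the residual block, and the weight tying with Whisper’s own output projection are simplified assumptions.

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One Medusa head: a single linear layer with a residual connection."""
    def __init__(self, d_model: int):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # The original Medusa paper [3] also applies an activation inside this block.
        return hidden + self.linear(hidden)

class MedusaHeads(nn.Module):
    """K heads on top of the final decoder layer, sharing one vocabulary projection."""
    def __init__(self, d_model: int, vocab_size: int, k: int = 10):
        super().__init__()
        self.heads = nn.ModuleList(MedusaHead(d_model) for _ in range(k))
        self.vocab_proj = nn.Linear(d_model, vocab_size, bias=False)  # shared across heads

    def forward(self, decoder_hidden: torch.Tensor) -> torch.Tensor:
        # decoder_hidden: (batch, seq, d_model) output of Whisper's last decoder layer.
        logits = [self.vocab_proj(decoder_hidden)]                           # token t+1 (base)
        logits += [self.vocab_proj(h(decoder_hidden)) for h in self.heads]   # tokens t+2 .. t+K+1
        return torch.stack(logits, dim=1)                                    # (batch, K+1, seq, vocab)

# Quick shape check with Whisper-large-like dimensions (illustrative values).
heads = MedusaHeads(d_model=1280, vocab_size=51865, k=10)
out = heads(torch.randn(2, 7, 1280))
print(out.shape)  # torch.Size([2, 11, 7, 51865])
```

Because the vocabulary projection is shared and each head is only a residual linear layer, the extra parameter count stays small relative to the 1.5B-parameter base model.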

Datasets:

Our model is trained using the LibriSpeech dataset, a widely used resource in the field of speech recognition. LibriSpeech consists of approximately 1,000 hours of English read speech, all derived from public domain audiobooks, and comes with corresponding transcriptions. Specifically, we train our model on the LibriSpeech-100, LibriSpeech-360, and LibriSpeech-500 subsets. Notably, the transcripts in LibriSpeech are in uppercase format, so we used transcripts from [4], which restore punctuation and capitalization for enhanced readability.
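For reference, one convenient way to obtain these subsets is the Hugging Face datasets library; this is shown as an assumption about tooling, not necessarily how our training pipeline loads the data.

```python
from datasets import load_dataset

# LibriSpeech on the Hugging Face Hub: the "clean" config holds the 100 h and
# 360 h training splits plus test-clean; the "other" config holds the 500 h split.
train_100 = load_dataset("librispeech_asr", "clean", split="train.100")
train_360 = load_dataset("librispeech_asr", "clean", split="train.360")
train_500 = load_dataset("librispeech_asr", "other", split="train.500")
test_clean = load_dataset("librispeech_asr", "clean", split="test")

sample = test_clean[0]
print(sample["text"])                    # reference transcript (uppercase, no punctuation)
print(sample["audio"]["sampling_rate"])  # 16000
```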

We evaluate our model using the LibriSpeech Test-Clean dataset, known for its high-quality audio where the speech is clear and well-articulated.

Results:

When evaluating our model, we focus on two key measures: accuracy and speedup. 
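Concretely, the two measures can be computed along these lines. The sketch below uses the jiwer package for WER/CER and simple wall-clock timing; vanilla_fn, medusa_fn, audio_batch, and references are hypothetical placeholders for the two transcription callables and the evaluation data.

```python
import time
import jiwer

def transcribe_and_time(transcribe_fn, audio_batch):
    """Run a transcription callable over a batch and return (hypotheses, seconds)."""
    start = time.perf_counter()
    hypotheses = [transcribe_fn(audio) for audio in audio_batch]
    return hypotheses, time.perf_counter() - start

# `vanilla_fn`, `medusa_fn`, `audio_batch`, and `references` are hypothetical
# placeholders for the two models' transcription functions and the eval data.
vanilla_hyps, vanilla_seconds = transcribe_and_time(vanilla_fn, audio_batch)
medusa_hyps, medusa_seconds = transcribe_and_time(medusa_fn, audio_batch)

print("speedup:", vanilla_seconds / medusa_seconds)
print("WER:", jiwer.wer(references, medusa_hyps))
print("CER:", jiwer.cer(references, medusa_hyps))
```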

Figure 3 illustrates the speedup of our Medusa model relative to target sequence length on the LibriSpeech Test-Clean subset. The results indicate that our Medusa architecture achieves a consistent speedup across all sequence lengths: the improvement is minor for very short targets, but there is a significant speedup for longer target sequences.


Figure 3: Medusa speedup by target sequence length on the LibriSpeech Test-Clean subset.

Next, we evaluate our model’s accuracy. Table 1 displays the WER and CER results for both the Whisper and Medusa models on the LibriSpeech Test-Clean subset. “Whisper vanilla” refers to the original model published by OpenAI. Our Medusa model shows a slight degradation in WER, but overall the performance remains very strong.

Dataset      Model               WER    CER
Test-Clean   Whisper vanilla     0.04   0.019
Test-Clean   Medusa (10 heads)   0.042  0.019

Table 1: WER and CER results for vanilla Whisper and our Medusa model on the LibriSpeech Test-Clean subset.

In conclusion, our Medusa architecture offers a promising solution to the speed limitations of OpenAI’s Whisper model, significantly enhancing its efficiency with only a small degradation in WER.

Our code is available here. Looking ahead, we aim to further refine the Medusa model to achieve even greater improvements in both accuracy and speed.

References:

[1] https://github.com/SYSTRAN/faster-whisper

[2] Shao, Hang, et al. “Whisper-KDQ: A Lightweight Whisper via Guided Knowledge Distillation and Quantization for Efficient ASR.” arXiv preprint arXiv:2305.10788 (2023).

[3] Cai, Tianle, et al. “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads.” arXiv preprint arXiv:2401.10774 (2024).

[4] Meister, Aleksandr, et al. “LibriSpeech-PC: Benchmark for Evaluation of Punctuation and Capitalization Capabilities of End-to-End ASR Models.” 2023 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU). IEEE, 2023.

Author
Yael Segal
Yael Segal-Feldman is an AI Researcher at aiOla and a PhD candidate in Computer Science at the Technion–Israel Institute of Technology. Her research specializes in machine learning and deep learning, with a particular focus on the detection and localization of speech objects. She has made significant contributions in areas such as keyword spotting, spoken term detection, and Diadochokinetic (DDK) tasks, advancing the field of AI-driven speech recognition technologies.