candle_transformers/models/whisper/
mod.rs

//! Whisper Model Implementation
//!
//! Whisper is an automatic speech recognition (ASR) system trained on large amounts
//! of multilingual and multitask supervised data collected from the web. It can be used to
//! convert audio files (in the `.wav` format) to text. Supported features include
//! language detection as well as multilingual speech recognition.
//!
//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/lmz/candle-whisper)
//! - 💻 [Original OpenAI implementation on GitHub](https://github.com/openai/whisper)
//! - 💻 Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py)

13pub mod audio;
14pub mod model;
15pub mod quantized_model;
16
17use serde::Deserialize;
18
19// The names in comments correspond to the original implementation:
20// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
21#[derive(Debug, Clone, PartialEq, Deserialize)]
22pub struct Config {
23    pub num_mel_bins: usize,            // n_mels
24    pub max_source_positions: usize,    // n_audio_ctx
25    pub d_model: usize,                 // n_audio_state
26    pub encoder_attention_heads: usize, // n_audio_head
27    pub encoder_layers: usize,          // n_audio_layer
28    pub vocab_size: usize,              // n_vocab
29    pub max_target_positions: usize,    //  n_text_ctx
30    // pub n_text_state: usize,
31    pub decoder_attention_heads: usize, // n_text_head
32    pub decoder_layers: usize,          // n_text_layer
33    #[serde(default)]
34    pub suppress_tokens: Vec<u32>,
35}
36
37pub const DTYPE: candle::DType = candle::DType::F32;
38
// Audio parameters.

/// Audio sample rate, in samples per second.
pub const SAMPLE_RATE: usize = 16_000;

/// FFT size.
///
/// NOTE(review): presumably the analysis window length in samples; confirm against the
/// `audio` module.
pub const N_FFT: usize = 400;

/// Number of audio samples advanced per mel spectrogram frame (see [`N_FRAMES`]).
pub const HOP_LENGTH: usize = 160;

/// Length of one audio chunk, in seconds.
pub const CHUNK_LENGTH: usize = 30;

/// Number of samples in one chunk: 480000 samples in a 30-second chunk.
pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE;

/// Number of frames in a mel spectrogram input: 3000.
pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH;

// Decoding thresholds. How they are applied is not visible in this file; the notes below
// are assumptions to confirm against the decoding code that reads these constants.

/// NOTE(review): presumably the no-speech probability above which a segment is treated as
/// silence; confirm at the call site.
pub const NO_SPEECH_THRESHOLD: f64 = 0.6;

/// NOTE(review): presumably the average log-probability below which a decoding attempt is
/// rejected; confirm at the call site.
pub const LOGPROB_THRESHOLD: f64 = -1.0;

/// Six sampling temperatures in increasing order, from 0.0 to 1.0 in steps of 0.2.
///
/// NOTE(review): presumably tried in order as decoding fallbacks; confirm at the call site.
pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];

/// NOTE(review): presumably the text compression ratio above which output is considered
/// too repetitive; confirm at the call site.
pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;

// Tokenizer dependent bits: the literal text of the special tokens.

/// Text of the start-of-transcript token.
pub const SOT_TOKEN: &str = "<|startoftranscript|>";

/// Text of the transcribe token.
pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";

/// Text of the translate token.
pub const TRANSLATE_TOKEN: &str = "<|translate|>";

/// Text of the no-timestamps token.
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";

/// Text of the end-of-text token.
pub const EOT_TOKEN: &str = "<|endoftext|>";

/// Two alternative spellings of the no-speech token.
///
/// NOTE(review): presumably different tokenizer versions use different spellings, so callers
/// should try both; confirm against the tokenizer lookup code.
pub const NO_SPEECH_TOKENS: [&str; 2] = ["<|nocaptions|>", "<|nospeech|>"];