2025-08-08 20:29:45 +02:00
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
2025-08-08 19:53:00 +02:00
//! Transcription backend selection and implementations (CPU/GPU) used by PolyScribe.
2025-08-08 16:52:18 +02:00
use crate ::OutputEntry ;
2025-08-11 06:59:24 +02:00
use crate ::progress ::ProgressMessage ;
2025-08-08 16:19:02 +02:00
use crate ::{ decode_audio_to_pcm_f32_ffmpeg , find_model_file } ;
2025-08-08 16:52:18 +02:00
use anyhow ::{ Context , Result , anyhow } ;
2025-08-08 16:19:02 +02:00
use std ::env ;
2025-08-08 16:52:18 +02:00
use std ::path ::Path ;
2025-08-11 06:59:24 +02:00
use std ::sync ::mpsc ::Sender ;
2025-08-08 16:19:02 +02:00
/// Kind of transcription backend to use.
///
/// This enum is public so the CLI layer can parse a backend choice directly into it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendKind {
    /// Automatically detect the best available backend (CUDA > HIP > Vulkan > CPU).
    Auto,
    /// Pure CPU backend using whisper-rs.
    Cpu,
    /// NVIDIA CUDA backend (requires CUDA runtime available at load time and proper feature build).
    Cuda,
    /// AMD ROCm/HIP backend (requires hip/rocBLAS libraries available and proper feature build).
    Hip,
    /// Vulkan backend (experimental; requires Vulkan loader/SDK and feature build).
    Vulkan,
}
2025-08-08 16:52:18 +02:00
/// Abstraction for a transcription backend implementation.
2025-08-08 16:19:02 +02:00
pub trait TranscribeBackend {
2025-08-08 16:52:18 +02:00
/// Return the backend kind for this implementation.
2025-08-08 16:19:02 +02:00
fn kind ( & self ) -> BackendKind ;
2025-08-08 16:52:18 +02:00
/// Transcribe the given audio file path and return transcript entries.
///
/// Parameters:
/// - audio_path: path to input media (audio or video) to be decoded/transcribed.
/// - speaker: label to attach to all produced segments.
/// - lang_opt: optional language hint (e.g., "en"); None means auto/multilingual model default.
/// - gpu_layers: optional GPU layer count if applicable (ignored by some backends).
fn transcribe (
& self ,
audio_path : & Path ,
speaker : & str ,
lang_opt : Option < & str > ,
2025-08-11 06:59:24 +02:00
progress_tx : Option < Sender < ProgressMessage > > ,
2025-08-08 16:52:18 +02:00
gpu_layers : Option < u32 > ,
) -> Result < Vec < OutputEntry > > ;
2025-08-08 16:19:02 +02:00
}
2025-08-08 20:16:44 +02:00
/// Probe whether any of the given shared libraries could be loaded at runtime.
///
/// Runtime `dlopen` probing is deliberately switched off, both in unit tests (to keep CI
/// away from system loaders) and in normal builds (to avoid loader instability). The
/// library names are therefore ignored and the answer is always "not found". Detection
/// relies on environment overrides instead.
fn check_lib(_names: &[&str]) -> bool {
    const RUNTIME_PROBING_ENABLED: bool = false;
    RUNTIME_PROBING_ENABLED
}
fn cuda_available ( ) -> bool {
2025-08-08 16:52:18 +02:00
if let Ok ( x ) = env ::var ( " POLYSCRIBE_TEST_FORCE_CUDA " ) {
return x = = " 1 " ;
}
check_lib ( & [
" libcudart.so " ,
" libcudart.so.12 " ,
" libcudart.so.11 " ,
" libcublas.so " ,
" libcublas.so.12 " ,
] )
2025-08-08 16:19:02 +02:00
}
fn hip_available ( ) -> bool {
2025-08-08 16:52:18 +02:00
if let Ok ( x ) = env ::var ( " POLYSCRIBE_TEST_FORCE_HIP " ) {
return x = = " 1 " ;
}
2025-08-08 16:19:02 +02:00
check_lib ( & [ " libhipblas.so " , " librocblas.so " ] )
}
fn vulkan_available ( ) -> bool {
2025-08-08 16:52:18 +02:00
if let Ok ( x ) = env ::var ( " POLYSCRIBE_TEST_FORCE_VULKAN " ) {
return x = = " 1 " ;
}
2025-08-08 16:19:02 +02:00
check_lib ( & [ " libvulkan.so.1 " , " libvulkan.so " ] )
}
2025-08-08 16:52:18 +02:00
/// CPU-based transcription backend using whisper-rs.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBackend;

/// CUDA-accelerated transcription backend for NVIDIA GPUs.
#[derive(Debug, Clone, Copy, Default)]
pub struct CudaBackend;

/// ROCm/HIP-accelerated transcription backend for AMD GPUs.
#[derive(Debug, Clone, Copy, Default)]
pub struct HipBackend;

/// Vulkan-based transcription backend (experimental/incomplete).
#[derive(Debug, Clone, Copy, Default)]
pub struct VulkanBackend;

impl CpuBackend {
    /// Create a new CPU backend instance (same as `CpuBackend::default()`).
    pub fn new() -> Self {
        CpuBackend
    }
}

impl CudaBackend {
    /// Create a new CUDA backend instance (same as `CudaBackend::default()`).
    pub fn new() -> Self {
        CudaBackend
    }
}

impl HipBackend {
    /// Create a new HIP backend instance (same as `HipBackend::default()`).
    pub fn new() -> Self {
        HipBackend
    }
}

impl VulkanBackend {
    /// Create a new Vulkan backend instance (same as `VulkanBackend::default()`).
    pub fn new() -> Self {
        VulkanBackend
    }
}
2025-08-08 16:19:02 +02:00
impl TranscribeBackend for CpuBackend {
2025-08-08 16:52:18 +02:00
fn kind ( & self ) -> BackendKind {
BackendKind ::Cpu
}
fn transcribe (
& self ,
audio_path : & Path ,
speaker : & str ,
lang_opt : Option < & str > ,
2025-08-11 06:59:24 +02:00
progress_tx : Option < Sender < ProgressMessage > > ,
2025-08-08 16:52:18 +02:00
_gpu_layers : Option < u32 > ,
) -> Result < Vec < OutputEntry > > {
2025-08-11 06:59:24 +02:00
transcribe_with_whisper_rs ( audio_path , speaker , lang_opt , progress_tx )
2025-08-08 16:19:02 +02:00
}
}
impl TranscribeBackend for CudaBackend {
2025-08-08 16:52:18 +02:00
fn kind ( & self ) -> BackendKind {
BackendKind ::Cuda
}
fn transcribe (
& self ,
audio_path : & Path ,
speaker : & str ,
lang_opt : Option < & str > ,
2025-08-11 06:59:24 +02:00
progress_tx : Option < Sender < ProgressMessage > > ,
2025-08-08 16:52:18 +02:00
_gpu_layers : Option < u32 > ,
) -> Result < Vec < OutputEntry > > {
2025-08-08 16:19:02 +02:00
// whisper-rs uses enabled CUDA feature at build time; call same code path
2025-08-11 06:59:24 +02:00
transcribe_with_whisper_rs ( audio_path , speaker , lang_opt , progress_tx )
2025-08-08 16:19:02 +02:00
}
}
impl TranscribeBackend for HipBackend {
2025-08-08 16:52:18 +02:00
fn kind ( & self ) -> BackendKind {
BackendKind ::Hip
}
fn transcribe (
& self ,
audio_path : & Path ,
speaker : & str ,
lang_opt : Option < & str > ,
2025-08-11 06:59:24 +02:00
progress_tx : Option < Sender < ProgressMessage > > ,
2025-08-08 16:52:18 +02:00
_gpu_layers : Option < u32 > ,
) -> Result < Vec < OutputEntry > > {
2025-08-11 06:59:24 +02:00
transcribe_with_whisper_rs ( audio_path , speaker , lang_opt , progress_tx )
2025-08-08 16:19:02 +02:00
}
}
impl TranscribeBackend for VulkanBackend {
2025-08-08 16:52:18 +02:00
fn kind ( & self ) -> BackendKind {
BackendKind ::Vulkan
}
fn transcribe (
& self ,
_audio_path : & Path ,
_speaker : & str ,
_lang_opt : Option < & str > ,
2025-08-11 06:59:24 +02:00
_progress_tx : Option < Sender < ProgressMessage > > ,
2025-08-08 16:52:18 +02:00
_gpu_layers : Option < u32 > ,
) -> Result < Vec < OutputEntry > > {
Err ( anyhow! (
" Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan. "
) )
2025-08-08 16:19:02 +02:00
}
}
2025-08-08 16:52:18 +02:00
/// Result of choosing a transcription backend.
2025-08-08 16:19:02 +02:00
pub struct SelectionResult {
2025-08-08 16:52:18 +02:00
/// The constructed backend instance to perform transcription with.
2025-08-08 16:19:02 +02:00
pub backend : Box < dyn TranscribeBackend + Send + Sync > ,
2025-08-08 16:52:18 +02:00
/// Which backend kind was ultimately selected.
2025-08-08 16:19:02 +02:00
pub chosen : BackendKind ,
2025-08-08 16:52:18 +02:00
/// Which backend kinds were detected as available on this system.
2025-08-08 16:19:02 +02:00
pub detected : Vec < BackendKind > ,
}
2025-08-08 16:52:18 +02:00
/// Select an appropriate backend based on user request and system detection.
///
/// If `requested` is `BackendKind::Auto`, the function prefers CUDA, then HIP,
/// then Vulkan, falling back to CPU when no GPU backend is detected. When a
/// specific GPU backend is requested but unavailable, an error is returned with
/// guidance on how to enable it.
///
/// Set `verbose` to true to print detection/selection info to stderr.
2025-08-12 02:57:42 +02:00
pub fn select_backend ( requested : BackendKind , config : & crate ::Config ) -> Result < SelectionResult > {
2025-08-08 16:19:02 +02:00
let mut detected = Vec ::new ( ) ;
2025-08-08 16:52:18 +02:00
if cuda_available ( ) {
detected . push ( BackendKind ::Cuda ) ;
}
if hip_available ( ) {
detected . push ( BackendKind ::Hip ) ;
}
if vulkan_available ( ) {
detected . push ( BackendKind ::Vulkan ) ;
}
2025-08-08 16:19:02 +02:00
let mk = | k : BackendKind | -> Box < dyn TranscribeBackend + Send + Sync > {
match k {
BackendKind ::Cpu = > Box ::new ( CpuBackend ::new ( ) ) ,
BackendKind ::Cuda = > Box ::new ( CudaBackend ::new ( ) ) ,
BackendKind ::Hip = > Box ::new ( HipBackend ::new ( ) ) ,
BackendKind ::Vulkan = > Box ::new ( VulkanBackend ::new ( ) ) ,
BackendKind ::Auto = > Box ::new ( CpuBackend ::new ( ) ) , // will be replaced
}
} ;
let chosen = match requested {
BackendKind ::Auto = > {
2025-08-08 16:52:18 +02:00
if detected . contains ( & BackendKind ::Cuda ) {
BackendKind ::Cuda
} else if detected . contains ( & BackendKind ::Hip ) {
BackendKind ::Hip
} else if detected . contains ( & BackendKind ::Vulkan ) {
BackendKind ::Vulkan
} else {
BackendKind ::Cpu
}
2025-08-08 16:19:02 +02:00
}
BackendKind ::Cuda = > {
2025-08-08 16:52:18 +02:00
if detected . contains ( & BackendKind ::Cuda ) {
BackendKind ::Cuda
} else {
return Err ( anyhow! (
" Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda. "
) ) ;
}
2025-08-08 16:19:02 +02:00
}
BackendKind ::Hip = > {
2025-08-08 16:52:18 +02:00
if detected . contains ( & BackendKind ::Hip ) {
BackendKind ::Hip
} else {
return Err ( anyhow! (
" Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip. "
) ) ;
}
2025-08-08 16:19:02 +02:00
}
BackendKind ::Vulkan = > {
2025-08-08 16:52:18 +02:00
if detected . contains ( & BackendKind ::Vulkan ) {
BackendKind ::Vulkan
} else {
return Err ( anyhow! (
" Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan. "
) ) ;
}
2025-08-08 16:19:02 +02:00
}
BackendKind ::Cpu = > BackendKind ::Cpu ,
} ;
2025-08-12 02:57:42 +02:00
if config . verbose > = 1 & & ! config . quiet {
2025-08-08 19:33:47 +02:00
crate ::dlog! ( 1 , " Detected backends: {:?} " , detected ) ;
crate ::dlog! ( 1 , " Selected backend: {:?} " , chosen ) ;
2025-08-08 16:19:02 +02:00
}
2025-08-08 16:52:18 +02:00
Ok ( SelectionResult {
backend : mk ( chosen ) ,
chosen ,
detected ,
} )
2025-08-08 16:19:02 +02:00
}
/// Internal helper: transcribe `audio_path` with whisper-rs and return one [`OutputEntry`] per
/// Whisper segment.
///
/// Whether inference runs on CPU or GPU depends on the features whisper-rs was built with. The
/// same code path serves the CPU, CUDA and HIP backends.
///
/// Steps:
/// 1. Decode the input with `decode_audio_to_pcm_f32_ffmpeg`.
/// 2. Locate the model with `find_model_file` and load it.
/// 3. Run a full greedy decode.
/// 4. Collect the segments.
///
/// - `speaker` is copied onto every produced entry.
/// - `lang_opt` forces the decoding language. When it is `None`, a multilingual model is asked
///   to auto-detect the language (`"auto"`). whisper's own default would otherwise pin decoding
///   to English. An English-only model (recognised by its file name) keeps the default.
/// - `progress_tx`, when present, receives best-effort [`ProgressMessage`] updates. Send
///   failures (e.g. a dropped receiver) are ignored.
///
/// Entries are emitted with `id: 0`. NOTE(review): ids are presumably assigned by the caller.
///
/// # Errors
/// Returns an error if any of the following occurs:
/// - audio decoding fails, or no model file can be found;
/// - an English-only model is combined with a non-`"en"` language hint;
/// - the model path is not valid UTF-8;
/// - model/state creation, inference, or segment retrieval fails.
#[cfg(feature = "whisper")]
pub(crate) fn transcribe_with_whisper_rs(
    audio_path: &Path,
    speaker: &str,
    lang_opt: Option<&str>,
    progress_tx: Option<Sender<ProgressMessage>>,
) -> Result<Vec<OutputEntry>> {
    // Best-effort progress reporting: a missing sender or disconnected receiver is not an error.
    let report = |fraction, stage: &str, note: &str| {
        if let Some(tx) = &progress_tx {
            let _ = tx.send(ProgressMessage {
                fraction,
                stage: Some(stage.to_string()),
                note: Some(note.to_string()),
            });
        }
    };

    report(0.0, "load_model", &audio_path.display().to_string());

    let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
    let model = find_model_file()?;
    report(0.05, "load_model", "model selected");

    // English-only ggml models are named like `ggml-base.en.bin`, or `ggml-base.en-q5_1.bin`
    // when quantized.
    let is_en_only = model
        .file_name()
        .and_then(|name| name.to_str())
        .map(|name| name.contains(".en.") || name.contains(".en-"))
        .unwrap_or(false);
    if let Some(lang) = lang_opt {
        if is_en_only && lang != "en" {
            return Err(anyhow!(
                "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
                model.display(),
                lang
            ));
        }
    }
    let model_str = model
        .to_str()
        .ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?;

    // Try to reduce native library logging via environment variables when not super-verbose.
    // These env vars are recognized by ggml/whisper in many builds; harmless if unknown.
    if crate::verbose_level() < 2 {
        // SAFETY: `set_var` is only sound while no other thread reads or writes the process
        // environment. NOTE(review): a progress-consumer thread may be running (see
        // `progress_tx`); confirm nothing touches the environment concurrently.
        unsafe {
            std::env::set_var("GGML_LOG_LEVEL", "0");
            std::env::set_var("WHISPER_PRINT_PROGRESS", "0");
        }
    }

    // Suppress stderr from whisper/ggml during model load and inference.
    // Binding `_ctx` (rather than `_`) keeps the context alive until the function returns.
    let (_ctx, mut state) = crate::with_suppressed_stderr(|| {
        let cparams = whisper_rs::WhisperContextParameters::default();
        let ctx = whisper_rs::WhisperContext::new_with_params(model_str, cparams)
            .with_context(|| format!("Failed to load Whisper model at {}", model.display()))?;
        let state = ctx
            .create_state()
            .map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
        Ok::<_, anyhow::Error>((ctx, state))
    })?;
    report(0.15, "encode", "state ready");

    let mut params =
        whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
    let n_threads = std::thread::available_parallelism()
        .map(|n| n.get() as i32)
        .unwrap_or(1);
    params.set_n_threads(n_threads);
    params.set_translate(false);
    match lang_opt {
        Some(lang) => params.set_language(Some(lang)),
        // The default full params decode as English; ask multilingual models to auto-detect
        // so that `None` really means "no language hint".
        None if !is_en_only => params.set_language(Some("auto")),
        None => {}
    }

    report(0.20, "decode", "inference");
    crate::with_suppressed_stderr(|| {
        state
            .full(params, &pcm)
            .map_err(|e| anyhow!("Whisper full() failed: {:?}", e))
    })?;

    let num_segments = state
        .full_n_segments()
        .map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
    let items = (0..num_segments)
        .map(|i| -> Result<OutputEntry> {
            let text = state
                .full_get_segment_text(i)
                .map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
            let t0 = state
                .full_get_segment_t0(i)
                .map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
            let t1 = state
                .full_get_segment_t1(i)
                .map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
            Ok(OutputEntry {
                id: 0,
                speaker: speaker.to_string(),
                // Segment timestamps are converted from 10 ms ticks to seconds.
                start: (t0 as f64) * 0.01,
                end: (t1 as f64) * 0.01,
                text: text.trim().to_string(),
            })
        })
        .collect::<Result<Vec<_>>>()?;

    // Report completion only once the result is actually ready.
    report(1.0, "done", "transcription finished");
    Ok(items)
}
2025-08-12 02:43:20 +02:00
/// Fallback used when the crate is built without the `whisper` feature.
///
/// All arguments are ignored and an error is always returned. The error explains how to
/// rebuild with transcription support.
#[cfg(not(feature = "whisper"))]
pub(crate) fn transcribe_with_whisper_rs(
    _audio_path: &Path,
    _speaker: &str,
    _lang_opt: Option<&str>,
    _progress_tx: Option<Sender<ProgressMessage>>,
) -> Result<Vec<OutputEntry>> {
    Err(anyhow!(
        "Transcription requires the 'whisper' feature. Rebuild with --features whisper (and optional gpu-cuda/gpu-hip)."
    ))
}