Supports multiple Speech to Text providers

- Supports multiple STT providers: Groq, Google, Microsoft, and OpenAI
- Records audio with volume-based progress indication
- Handles error cases for missing API keys and audio devices
- Includes example usage in main block
- Depends on sounddevice, numpy, and provider-specific libraries
- Manages API keys through environment variables
- Converts audio format if file size exceeds limit
This commit is contained in:
gurssagar 2025-02-04 02:01:43 +05:30
parent f7deb02560
commit cd62b87182

View file

@ -1,14 +1,15 @@
import math
import os
import queue
import tempfile
import time
import numpy as np
import warnings
from typing import Optional, Tuple, Union
from prompt_toolkit.shortcuts import prompt
from aider.llm import litellm
from .dump import dump # noqa: F401
warnings.filterwarnings(
@ -16,69 +17,132 @@ warnings.filterwarnings(
)
warnings.filterwarnings("ignore", category=SyntaxWarning)
from pydub import AudioSegment # noqa
from pydub.exceptions import CouldntDecodeError, CouldntEncodeError # noqa
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError, CouldntEncodeError
try:
import soundfile as sf
except (OSError, ModuleNotFoundError):
sf = None
try:
import google.cloud.speech as speech
except (OSError, ModuleNotFoundError):
speech = None
try:
import azure.cognitiveservices.speech as speech_sdk
except (OSError, ModuleNotFoundError):
speech_sdk = None
try:
import openai
except (OSError, ModuleNotFoundError):
openai = None
from groq import Groq # Updated import
class SoundDeviceError(Exception):
pass
class Voice:
max_rms = 0
min_rms = 1e5
pct = 0
"""
A class to handle audio recording and transcription using sounddevice and various transcription services.
"""
threshold = 0.15
def __init__(self, audio_format: str = "wav", device_name: Optional[str] = None, provider: str = "groq", api_key: Optional[str] = None):
"""
Initialize the Voice class.
def __init__(self, audio_format="wav", device_name=None):
Args:
audio_format: The format of the output audio file. Supported formats: 'wav', 'mp3', 'webm'.
device_name: The name of the audio input device to use. If None, the default device is used.
provider: The transcription service provider. Supported providers: 'groq', 'google', 'microsoft', 'openai'.
api_key: Optional API key for the transcription service. If not provided, it will be retrieved from environment variables.
Raises:
SoundDeviceError: If sounddevice or soundfile is not available.
ValueError: If the audio_format is not supported or the device_name is not found.
"""
if sf is None:
raise SoundDeviceError
raise SoundDeviceError("soundfile is not available. Please install it.")
try:
print("Initializing sound device...")
import sounddevice as sd
self.sd = sd
devices = sd.query_devices()
if device_name:
# Find the device with matching name
device_id = None
for i, device in enumerate(devices):
if device_name in device["name"]:
device_id = i
break
if device_id is None:
available_inputs = [d["name"] for d in devices if d["max_input_channels"] > 0]
raise ValueError(
f"Device '{device_name}' not found. Available input devices:"
f" {available_inputs}"
)
print(f"Using input device: {device_name} (ID: {device_id})")
self.device_id = device_id
else:
self.device_id = None
except (OSError, ModuleNotFoundError):
raise SoundDeviceError
if audio_format not in ["wav", "mp3", "webm"]:
raise ValueError(f"Unsupported audio format: {audio_format}")
self.audio_format = audio_format
raise SoundDeviceError("sounddevice is not available. Please install it.")
def callback(self, indata, frames, time, status):
"""This is called (from a separate thread) for each audio block."""
if audio_format not in ["wav", "mp3", "webm"]:
raise ValueError(f"Unsupported audio format: {audio_format}. Supported formats: wav, mp3, webm.")
self.audio_format = audio_format
self.device_id = None
self.provider = provider
self.api_key = api_key
if device_name:
devices = self.sd.query_devices()
device_id = None
for i, device in enumerate(devices):
if device_name.lower() in device["name"].lower():
device_id = i
break
if device_id is None:
available_inputs = [d["name"] for d in devices if d["max_input_channels"] > 0]
raise ValueError(
f"Device '{device_name}' not found. Available input devices: {available_inputs}"
)
self.device_id = device_id
print(f"Using input device: {device_name} (ID: {device_id})")
self._validate_provider()
def _validate_provider(self):
supported_providers = ['groq', 'google', 'microsoft', 'openai']
if self.provider not in supported_providers:
raise ValueError(f"Unsupported provider: {self.provider}. Supported providers: {supported_providers}")
if self.provider == 'google' and speech is None:
raise ValueError("google.cloud.speech is not available. Please install google-cloud-speech.")
if self.provider == 'microsoft' and speech_sdk is None:
raise ValueError("azure.cognitiveservices.speech is not available. Please install azure-cognitive-services-speech.")
if self.provider == 'openai' and openai is None:
raise ValueError("openai is not available. Please install openai.")
def set_api_key(self):
if self.provider == 'groq':
self.api_key = os.getenv("GROQ_API_KEY")
if not self.api_key:
raise ValueError("Please set the GROQ_API_KEY environment variable.")
elif self.provider == 'google':
self.api_key = os.getenv("GOOGLE_API_KEY")
if not self.api_key:
raise ValueError("Please set the GOOGLE_API_KEY environment variable.")
elif self.provider == 'microsoft':
self.api_key = os.getenv("AZURE_SPEECH_KEY")
if not self.api_key:
raise ValueError("Please set the AZURE_SPEECH_KEY environment variable.")
elif self.provider == 'openai':
self.api_key = os.getenv("OPENAI_API_KEY")
if not self.api_key:
raise ValueError("Please set the OPENAI_API_KEY environment variable.")
def callback(self, indata: np.ndarray, frames: int, time: float, status: int) -> None:
"""
Callback function for audio data processing.
Args:
indata: The audio data.
frames: The number of frames.
time: The time stamp.
status: The stream status.
"""
import numpy as np
rms = np.sqrt(np.mean(indata**2))
rms = np.sqrt(np.mean(indata ** 2))
self.max_rms = max(self.max_rms, rms)
self.min_rms = min(self.min_rms, rms)
@ -90,7 +154,13 @@ class Voice:
self.q.put(indata.copy())
def get_prompt(self):
def get_prompt(self) -> str:
"""
Generate a progress prompt string.
Returns:
A formatted string showing recording status and progress bar.
"""
num = 10
if math.isnan(self.pct) or self.pct < self.threshold:
cnt = 0
@ -101,19 +171,44 @@ class Voice:
bar = bar[:num]
dur = time.time() - self.start_time
return f"Recording, press ENTER when done... {dur:.1f}sec {bar}"
return f"Recording, press ENTER when done... {dur:.1f}s {bar}"
def record_and_transcribe(self, history=None, language=None):
def record_and_transcribe(self, history: Optional[list] = None, language: Optional[str] = None) -> Optional[str]:
"""
Record audio and transcribe it.
Args:
history: Optional list of previous commands/transcripts for context.
language: Optional language code for transcription.
Returns:
The transcribed text or None if an error occurs.
"""
try:
return self.raw_record_and_transcribe(history, language)
except KeyboardInterrupt:
return
print("\nRecording stopped by user.")
return None
except SoundDeviceError as e:
print(f"Error: {e}")
print("Please ensure you have a working audio input device connected and try again.")
return
return None
def raw_record_and_transcribe(self, history, language):
def raw_record_and_transcribe(self, history: Optional[list], language: Optional[str]) -> Optional[str]:
"""
Raw method to record and transcribe audio without exception handling.
Args:
history: Optional list of previous commands/transcripts for context.
language: Optional language code for transcription.
Returns:
The transcribed text or None if an error occurs.
"""
self.max_rms = 0
self.min_rms = 1e5
self.pct = 0
self.threshold = 0.15
self.q = queue.Queue()
temp_wav = tempfile.mktemp(suffix=".wav")
@ -121,7 +216,7 @@ class Voice:
try:
sample_rate = int(self.sd.query_devices(self.device_id, "input")["default_samplerate"])
except (TypeError, ValueError):
sample_rate = 16000 # fallback to 16kHz if unable to query device
sample_rate = 16000 # Fallback to 16kHz
except self.sd.PortAudioError:
raise SoundDeviceError(
"No audio input device detected. Please check your audio settings and try again."
@ -132,56 +227,177 @@ class Voice:
try:
with self.sd.InputStream(
samplerate=sample_rate, channels=1, callback=self.callback, device=self.device_id
):
) as stream:
prompt(self.get_prompt, refresh_interval=0.1)
except self.sd.PortAudioError as err:
raise SoundDeviceError(f"Error accessing audio input device: {err}")
with sf.SoundFile(temp_wav, mode="x", samplerate=sample_rate, channels=1) as file:
while not self.q.empty():
file.write(self.q.get())
# Write recorded data to file
try:
with sf.SoundFile(temp_wav, mode="x", samplerate=sample_rate, channels=1) as file:
while not self.q.empty():
file.write(self.q.get())
except Exception as e:
print(f"Error writing audio data: {e}")
return None
# Convert audio format if necessary
use_audio_format = self.audio_format
filename = temp_wav
# Check file size and offer to convert to mp3 if too large
file_size = os.path.getsize(temp_wav)
if file_size > 24.9 * 1024 * 1024 and self.audio_format == "wav":
print("\nWarning: {temp_wav} is too large, switching to mp3 format.")
print("\nWarning: The WAV file is too large, switching to MP3 format.")
use_audio_format = "mp3"
filename = temp_wav
if use_audio_format != "wav":
try:
new_filename = tempfile.mktemp(suffix=f".{use_audio_format}")
audio = AudioSegment.from_wav(temp_wav)
audio.export(new_filename, format=use_audio_format)
os.remove(temp_wav)
filename = new_filename
except (CouldntDecodeError, CouldntEncodeError) as e:
print(f"Error converting audio: {e}")
except (OSError, FileNotFoundError) as e:
print(f"File system error during conversion: {e}")
except Exception as e:
print(f"Unexpected error during audio conversion: {e}")
print(f"Error converting audio: {e}")
use_audio_format = "wav"
filename = temp_wav
with open(filename, "rb") as fh:
try:
transcript = litellm.transcription(
model="whisper-1", file=fh, prompt=history, language=language
# Set API key based on provider
self.set_api_key()
# Transcribe audio using the selected provider
if self.provider == 'groq':
return self._transcribe_groq(filename, language)
elif self.provider == 'google':
return self._transcribe_google(filename, language)
elif self.provider == 'microsoft':
return self._transcribe_microsoft(filename, language)
elif self.provider == 'openai':
return self._transcribe_openai(filename, language)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def _transcribe_groq(self, filename: str, language: Optional[str] = None) -> Optional[str]:
"""
Transcribe audio using Groq's Whisper model.
Args:
filename: Path to the audio file.
language: Optional language code for transcription.
Returns:
The transcribed text or None if an error occurs.
"""
try:
client = Groq(api_key=self.api_key)
with open(filename, "rb") as audio_file:
transcript = client.audio.transcriptions.create(
file=(os.path.basename(filename), audio_file.read()),
model="distil-whisper-large-v3-en"
)
except Exception as err:
print(f"Unable to transcribe {filename}: {err}")
return
return transcript.text
except Exception as e:
print(f"Unable to transcribe audio using Groq: {e}")
return None
if filename != temp_wav:
os.remove(filename)
def _transcribe_google(self, filename: str, language: Optional[str] = None) -> Optional[str]:
"""
Transcribe audio using Google Cloud Speech-to-Text.
text = transcript.text
return text
Args:
filename: Path to the audio file.
language: Optional language code for transcription.
Returns:
The transcribed text or None if an error occurs.
"""
if language is None:
language = "en-US"
try:
client = speech.SpeechClient()
with open(filename, "rb") as audio_file:
audio = speech.RecognitionAudio(content=audio_file.read())
config = speech.RecognitionConfig(
encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
language_code=language,
)
response = client.recognize(config, audio)
return " ".join([result.alternatives[0].transcript for result in response.results])
except Exception as e:
print(f"Unable to transcribe audio using Google: {e}")
return None
def _transcribe_microsoft(self, filename: str, language: Optional[str] = None) -> Optional[str]:
"""
Transcribe audio using Microsoft Azure Speech Services.
Args:
filename: Path to the audio file.
language: Optional language code for transcription.
Returns:
The transcribed text or None if an error occurs.
"""
if language is None:
language = "en-US"
try:
speech_config = speech_sdk.SpeechConfig(subscription=self.api_key, region="global")
audio_config = speech_sdk.AudioConfig(filename=filename)
speech_recognizer = speech_sdk.SpeechRecognizer(speech_config, audio_config=audio_config)
result = speech_recognizer.recognize_once()
return result.text
except Exception as e:
print(f"Unable to transcribe audio using Microsoft: {e}")
return None
def _transcribe_openai(self, filename: str, language: Optional[str] = None) -> Optional[str]:
"""
Transcribe audio using OpenAI Whisper model.
Args:
filename: Path to the audio file.
language: Optional language code for transcription.
Returns:
The transcribed text or None if an error occurs.
"""
try:
client = openai.OpenAI(api_key=self.api_key)
with open(filename, "rb") as audio_file:
response = client.audio.transcriptions.create(
"whisper-1",
file=audio_file.read()
)
return response.text
except Exception as e:
print(f"Unable to transcribe audio using OpenAI: {e}")
return None
if __name__ == "__main__":
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("Please set the OPENAI_API_KEY environment variable.")
print(Voice().record_and_transcribe())
# Example usage with different providers
voice_groq = Voice(provider='groq')
transcript_groq = voice_groq.record_and_transcribe()
if transcript_groq:
print("\nTranscript (Groq):")
print(transcript_groq)
voice_google = Voice(provider='google')
transcript_google = voice_google.record_and_transcribe()
if transcript_google:
print("\nTranscript (Google):")
print(transcript_google)
voice_microsoft = Voice(provider='microsoft')
transcript_microsoft = voice_microsoft.record_and_transcribe()
if transcript_microsoft:
print("\nTranscript (Microsoft):")
print(transcript_microsoft)
voice_openai = Voice(provider='openai')
transcript_openai = voice_openai.record_and_transcribe()
if transcript_openai:
print("\nTranscript (OpenAI):")
print(transcript_openai)