test: update voice tests to mock audio dependencies properly

This commit is contained in:
Paul Gauthier (aider) 2024-11-30 19:15:16 -08:00
parent 1a8d112055
commit 49c78f2797

View file

@ -1,39 +1,37 @@
import os import os
import queue import queue
from unittest.mock import patch from unittest.mock import patch, MagicMock
import numpy as np import numpy as np
import pytest import pytest
from aider.voice import SoundDeviceError, Voice from aider.voice import SoundDeviceError, Voice
# Mock the entire sounddevice module
@pytest.fixture @pytest.fixture
def mock_sounddevice(): def mock_sounddevice():
with patch("sounddevice.query_devices") as mock_query: with patch('aider.voice.sounddevice') as mock_sd:
mock_query.return_value = [ mock_sd.query_devices.return_value = [
{"name": "test_device", "max_input_channels": 2}, {"name": "test_device", "max_input_channels": 2},
{"name": "another_device", "max_input_channels": 1}, {"name": "another_device", "max_input_channels": 1},
] ]
yield mock_query yield mock_sd
@pytest.fixture @pytest.fixture
def mock_soundfile(): def mock_soundfile():
with patch("soundfile.SoundFile") as mock_sf: with patch('aider.voice.sf') as mock_sf:
yield mock_sf yield mock_sf
def test_voice_init_default_device(mock_sounddevice): def test_voice_init_default_device(mock_sounddevice):
voice = Voice() voice = Voice()
assert voice.device_id is None assert voice.device_id is None
assert voice.audio_format == "wav" assert voice.audio_format == "wav"
assert voice.sd == mock_sounddevice
def test_voice_init_specific_device(mock_sounddevice): def test_voice_init_specific_device(mock_sounddevice):
voice = Voice(device_name="test_device") voice = Voice(device_name="test_device")
assert voice.device_id == 0 assert voice.device_id == 0
assert voice.sd == mock_sounddevice
def test_voice_init_invalid_device(mock_sounddevice): def test_voice_init_invalid_device(mock_sounddevice):
with pytest.raises(ValueError) as exc: with pytest.raises(ValueError) as exc:
@ -41,56 +39,52 @@ def test_voice_init_invalid_device(mock_sounddevice):
assert "Device" in str(exc.value) assert "Device" in str(exc.value)
assert "not found" in str(exc.value) assert "not found" in str(exc.value)
def test_voice_init_invalid_format(): def test_voice_init_invalid_format():
with pytest.raises(ValueError) as exc: with patch('aider.voice.sf', MagicMock()): # Need to mock sf to avoid SoundDeviceError
Voice(audio_format="invalid") with pytest.raises(ValueError) as exc:
assert "Unsupported audio format" in str(exc.value) Voice(audio_format="invalid")
assert "Unsupported audio format" in str(exc.value)
def test_callback_processing(): def test_callback_processing():
voice = Voice() with patch('aider.voice.sf', MagicMock()): # Need to mock sf to avoid SoundDeviceError
voice.q = queue.Queue() voice = Voice()
voice.q = queue.Queue()
# Test with silence (low amplitude) # Test with silence (low amplitude)
test_data = np.zeros((1000, 1)) test_data = np.zeros((1000, 1))
voice.callback(test_data, None, None, None) voice.callback(test_data, None, None, None)
assert voice.pct == 0.5 # When range is too small (<=0.001), pct is set to 0.5 assert voice.pct == 0.5 # When range is too small (<=0.001), pct is set to 0.5
# Test with loud signal (high amplitude) # Test with loud signal (high amplitude)
test_data = np.ones((1000, 1)) test_data = np.ones((1000, 1))
voice.callback(test_data, None, None, None) voice.callback(test_data, None, None, None)
assert voice.pct > 0.9 assert voice.pct > 0.9
# Verify data is queued
assert not voice.q.empty()
# Verify data is queued
assert not voice.q.empty()
def test_get_prompt(): def test_get_prompt():
voice = Voice() with patch('aider.voice.sf', MagicMock()): # Need to mock sf to avoid SoundDeviceError
voice.start_time = voice.start_time = os.times().elapsed voice = Voice()
voice.pct = 0.5 # 50% volume level voice.start_time = os.times().elapsed
voice.pct = 0.5 # 50% volume level
prompt = voice.get_prompt() prompt = voice.get_prompt()
assert "Recording" in prompt assert "Recording" in prompt
assert "sec" in prompt assert "sec" in prompt
assert "" in prompt # Should contain some filled blocks assert "" in prompt # Should contain some filled blocks
assert "" in prompt # Should contain some empty blocks assert "" in prompt # Should contain some empty blocks
def test_record_and_transcribe_keyboard_interrupt():
with patch('aider.voice.sf', MagicMock()):
voice = Voice()
with patch.object(voice, 'raw_record_and_transcribe', side_effect=KeyboardInterrupt()):
result = voice.record_and_transcribe()
assert result is None
@patch("sounddevice.InputStream") def test_record_and_transcribe_device_error():
def test_record_and_transcribe_keyboard_interrupt(mock_stream): with patch('aider.voice.sf', MagicMock()):
voice = Voice() voice = Voice()
mock_stream.side_effect = KeyboardInterrupt() with patch.object(voice, 'raw_record_and_transcribe', side_effect=SoundDeviceError("Test error")):
result = voice.record_and_transcribe()
result = voice.record_and_transcribe() assert result is None
assert result is None
@patch("sounddevice.InputStream")
def test_record_and_transcribe_device_error(mock_stream):
voice = Voice()
mock_stream.side_effect = SoundDeviceError("Test error")
result = voice.record_and_transcribe()
assert result is None