From 6b6c2432505e37562183300b1ac0e1e742138b0d Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 25 Apr 2026 07:51:08 +0000 Subject: [PATCH 01/10] Add Tests --- .github/workflows/tests.yml | 32 +++ pyproject.toml | 5 + tests/__init__.py | 0 tests/readme.md | 3 +- tests/unit/__init__.py | 0 tests/unit/test_api_server.py | 328 ++++++++++++++++++++++++++ tests/unit/test_entity.py | 427 ++++++++++++++++++++++++++++++++++ tests/unit/test_models.py | 304 ++++++++++++++++++++++++ tests/unit/test_mpv_player.py | 248 ++++++++++++++++++++ tests/unit/test_unit.py | 226 ++++++++++++++++++ tests/unit/test_webrtc.py | 195 ++++++++++++++++ tests/unit/test_zeroconf.py | 174 ++++++++++++++ 12 files changed, 1941 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/tests.yml create mode 100644 tests/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_api_server.py create mode 100644 tests/unit/test_entity.py create mode 100644 tests/unit/test_models.py create mode 100644 tests/unit/test_mpv_player.py create mode 100644 tests/unit/test_unit.py create mode 100644 tests/unit/test_webrtc.py create mode 100644 tests/unit/test_zeroconf.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..00bdfd0b --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,32 @@ +name: Unit Tests + +on: + push: + branches: + - "main" + tags: + - "v[0-9]+.[0-9]+.[0-9]+" # For releases, e.g. v1.2.0 + - "v[0-9]+.[0-9]+.[0-9]+-[a-zA-Z0-9]+" # For prereleases, e.g. v1.3.0-alpha1 + pull_request: + types: [opened, synchronize, reopened] + +jobs: + pytest: + name: Pytest Unit + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.13 + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + script/setup --dev + + - name: Run unit tests + run: pytest tests/unit \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index be871a58..653be914 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dev = [ "pytest", "isort", "autopep8", + "pytest-asyncio", ] [project.scripts] @@ -131,3 +132,7 @@ max-args = 50 [tool.pylint.format] expected-line-ending-format = "LF" + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/readme.md b/tests/readme.md index f6bb1c3f..c9cc73a4 100644 --- a/tests/readme.md +++ b/tests/readme.md @@ -1,6 +1,7 @@ # Automated testing -tbd +From the root of LVA run:- +```pytest tests/unit``` # Manual testing diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_api_server.py b/tests/unit/test_api_server.py new file mode 100644 index 00000000..e1660d80 --- /dev/null +++ b/tests/unit/test_api_server.py @@ -0,0 +1,328 @@ +"""Unit tests for APIServer packet parsing and message handling.""" + +import pytest +from unittest.mock import MagicMock, patch, call +import asyncio + +from aioesphomeapi.api_pb2 import ( + HelloRequest, + HelloResponse, + PingRequest, + PingResponse, + DisconnectRequest, + DisconnectResponse, + AuthenticationRequest, + AuthenticationResponse, +) +from aioesphomeapi._frame_helper.packets import make_plain_text_packets +from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO + +PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_packet(msg) -> bytes: + """Serialize a protobuf message into a plain-text ESPHome packet.""" + msg_type = PROTO_TO_MESSAGE_TYPE[msg.__class__] + data = msg.SerializeToString() + packets = make_plain_text_packets([(msg_type, data)]) + # make_plain_text_packets returns a list of memoryview/bytes — join into one + if isinstance(packets, (list, tuple)): + return b"".join(bytes(p) for p in packets) + return bytes(packets) + + +class ConcreteAPIServer: + """Concrete subclass of APIServer for testing — records handled messages.""" + + def __init__(self): + from linux_voice_assistant.api_server import APIServer + + class _Concrete(APIServer): + def __init__(self_inner): + super().__init__("test-server") + self_inner.handled = [] + + def handle_message(self_inner, msg): + self_inner.handled.append(msg) + return [] + + self._cls = _Concrete + self.instance = _Concrete() + + @property + def server(self): + return self.instance + + +def make_server(): + """Return a connected APIServer instance with a mock transport.""" + wrapper = ConcreteAPIServer() + server = wrapper.server + + transport = MagicMock() + written = [] + + def capture_writelines(data): + # data may be a list of memoryview/bytes — flatten to a single bytes object + if isinstance(data, (list, tuple)): + written.append(b"".join(bytes(d) for d in data)) + else: + written.append(bytes(data)) + + transport.writelines = capture_writelines + server._transport = transport + server._writelines = capture_writelines + server._written = written + + return server + + +def get_sent_messages(server): + """Decode all messages the server sent back via writelines.""" + from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO + + messages = [] + for raw in server._written: + # raw is bytes — parse the plain-text framing manually + pos = 0 + while pos < len(raw): + # preamble + if (raw[pos] if isinstance(raw[pos], int) else raw[pos][0]) != 0x00: + break + pos += 1 + + # length varuint + length = 0 + bitpos = 0 + while True: + b = raw[pos] + pos += 1 + length |= (b & 0x7F) << bitpos + if not (b & 0x80): + break + bitpos += 7 + + # msg_type varuint + msg_type = 0 + bitpos = 0 + while True: + b = raw[pos] + pos += 1 + msg_type |= (b & 0x7F) << bitpos + if not (b & 0x80): + break + bitpos += 7 + + # payload + payload = raw[pos:pos + length] + pos += length + + msg_cls = MESSAGE_TYPE_TO_PROTO[msg_type] + messages.append(msg_cls.FromString(payload)) + + return messages + + +# --------------------------------------------------------------------------- +# connection_made / connection_lost +# --------------------------------------------------------------------------- + + +class TestConnection: + def test_connection_made_stores_transport(self): + from linux_voice_assistant.api_server import APIServer + + class _Concrete(APIServer): + def __init__(self): + super().__init__("test") + def handle_message(self, msg): + return [] + + server = _Concrete() + transport = MagicMock() + transport.writelines = MagicMock() + + loop = MagicMock() + with patch("linux_voice_assistant.api_server.asyncio.get_running_loop", return_value=loop): + server.connection_made(transport) + + assert server._transport is transport + + def test_connection_lost_clears_transport(self): + server = make_server() + server.connection_lost(None) + assert server._transport is None + assert server._writelines is None + + def test_connection_lost_clears_loop(self): + server = make_server() + server._loop = MagicMock() + server.connection_lost(None) + assert server._loop is None + + +# --------------------------------------------------------------------------- +# HelloRequest → HelloResponse +# --------------------------------------------------------------------------- + + +class TestHelloHandshake: + def test_hello_request_yields_hello_response(self): + server = make_server() + server.data_received(make_packet(HelloRequest(client_info="test"))) + msgs = get_sent_messages(server) + assert any(isinstance(m, HelloResponse) for m in msgs) + + def test_hello_response_contains_server_name(self): + server = make_server() + server.data_received(make_packet(HelloRequest(client_info="test"))) + msgs = get_sent_messages(server) + hello = next(m for m in msgs if isinstance(m, HelloResponse)) + assert hello.name == "test-server" + + def test_hello_response_has_api_version(self): + server = make_server() + server.data_received(make_packet(HelloRequest(client_info="test"))) + msgs = get_sent_messages(server) + hello = next(m for m in msgs if isinstance(m, HelloResponse)) + assert hello.api_version_major == 1 + assert hello.api_version_minor == 10 + + def test_hello_does_not_call_handle_message(self): + server = make_server() + server.data_received(make_packet(HelloRequest(client_info="test"))) + assert server.handled == [] + + +# --------------------------------------------------------------------------- +# PingRequest → PingResponse +# --------------------------------------------------------------------------- + + +class TestPing: + def test_ping_request_yields_ping_response(self): + server = make_server() + server.data_received(make_packet(PingRequest())) + msgs = get_sent_messages(server) + assert any(isinstance(m, PingResponse) for m in msgs) + + +# --------------------------------------------------------------------------- +# DisconnectRequest → DisconnectResponse +# --------------------------------------------------------------------------- + + +class TestDisconnect: + def test_disconnect_request_yields_disconnect_response(self): + server = make_server() + server.data_received(make_packet(DisconnectRequest())) + msgs = get_sent_messages(server) + assert any(isinstance(m, DisconnectResponse) for m in msgs) + + def test_disconnect_closes_transport(self): + server = make_server() + mock_transport = MagicMock() + server._transport = mock_transport + server.data_received(make_packet(DisconnectRequest())) + mock_transport.close.assert_called_once() + + def test_disconnect_clears_transport_reference(self): + server = make_server() + server.data_received(make_packet(DisconnectRequest())) + assert server._transport is None + + +# --------------------------------------------------------------------------- +# AuthenticationRequest → AuthenticationResponse +# --------------------------------------------------------------------------- + + +class TestAuthentication: + def test_auth_request_yields_auth_response(self): + server = make_server() + server.data_received(make_packet(AuthenticationRequest())) + msgs = get_sent_messages(server) + assert any(isinstance(m, AuthenticationResponse) for m in msgs) + + +# --------------------------------------------------------------------------- +# Buffer management +# --------------------------------------------------------------------------- + + +class TestBufferManagement: + def test_buffer_is_none_after_complete_packet(self): + server = make_server() + server.data_received(make_packet(PingRequest())) + assert server._buffer is None + + def test_partial_packet_stays_in_buffer(self): + server = make_server() + full = make_packet(PingRequest()) + # Send only half the packet + server.data_received(full[:len(full) // 2]) + assert server._buffer is not None + + def test_split_packet_reassembled_correctly(self): + server = make_server() + full = make_packet(PingRequest()) + half = len(full) // 2 + server.data_received(full[:half]) + server.data_received(full[half:]) + msgs = get_sent_messages(server) + assert any(isinstance(m, PingResponse) for m in msgs) + + def test_two_packets_in_one_data_received(self): + server = make_server() + data = make_packet(PingRequest()) + make_packet(PingRequest()) + server.data_received(data) + msgs = get_sent_messages(server) + assert sum(1 for m in msgs if isinstance(m, PingResponse)) == 2 + + def test_buffer_len_tracks_correctly(self): + server = make_server() + full = make_packet(PingRequest()) + server.data_received(full[:2]) + assert server._buffer_len == 2 + + def test_buffer_cleared_after_full_packet(self): + server = make_server() + server.data_received(make_packet(PingRequest())) + assert server._buffer_len == 0 + + +# --------------------------------------------------------------------------- +# _read_varuint +# --------------------------------------------------------------------------- + + +class TestReadVarint: + def _make_server_with_buffer(self, data: bytes): + server = make_server() + server._buffer = data + server._buffer_len = len(data) + server._pos = 0 + return server + + def test_reads_single_byte_varuint(self): + server = self._make_server_with_buffer(bytes([0x05])) + assert server._read_varuint() == 5 + + def test_reads_two_byte_varuint(self): + # 300 encoded as varuint = 0xAC 0x02 + server = self._make_server_with_buffer(bytes([0xAC, 0x02])) + assert server._read_varuint() == 300 + + def test_returns_minus_one_on_empty_buffer(self): + server = make_server() + server._buffer = None + assert server._read_varuint() == -1 + + def test_returns_zero_for_zero_byte(self): + server = self._make_server_with_buffer(bytes([0x00])) + assert server._read_varuint() == 0 \ No newline at end of file diff --git a/tests/unit/test_entity.py b/tests/unit/test_entity.py new file mode 100644 index 00000000..000a6e28 --- /dev/null +++ b/tests/unit/test_entity.py @@ -0,0 +1,427 @@ +"""Unit tests for ESPHome entity classes.""" + +import pytest +from unittest.mock import MagicMock, call + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_server(): + """Minimal mock APIServer.""" + server = MagicMock() + server.state = MagicMock() + return server + + +def make_media_player(server=None, key=1, initial_volume=1.0): + from linux_voice_assistant.entity import MediaPlayerEntity + server = server or make_server() + return MediaPlayerEntity( + server=server, + key=key, + name="Media Player", + object_id="test_media_player", + music_player=MagicMock(), + announce_player=MagicMock(), + initial_volume=initial_volume, + ) + + +def make_mute_switch(server=None, key=2, initial_muted=False): + from linux_voice_assistant.entity import MuteSwitchEntity + server = server or make_server() + get_muted = MagicMock(return_value=initial_muted) + set_muted = MagicMock() + entity = MuteSwitchEntity( + server=server, + key=key, + name="Mute", + object_id="mute", + get_muted=get_muted, + set_muted=set_muted, + ) + entity._get_muted = get_muted + entity._set_muted = set_muted + return entity + + +def make_mic_setting(server=None, key=3, options=None, value=0.0): + from linux_voice_assistant.entity import MicSettingEntity + server = server or make_server() + get_value = MagicMock(return_value=value) + set_value = MagicMock() + entity = MicSettingEntity( + server=server, + key=key, + name="Mic Gain", + object_id="mic_gain", + get_value=get_value, + set_value=set_value, + min_value=0.0, + max_value=31.0, + options=options, + ) + entity._get_value_mock = get_value + entity._set_value_mock = set_value + return entity + + +# --------------------------------------------------------------------------- +# Import proto message types once +# --------------------------------------------------------------------------- + +from aioesphomeapi.api_pb2 import ( + ListEntitiesRequest, + ListEntitiesMediaPlayerResponse, + ListEntitiesSelectResponse, + ListEntitiesNumberResponse, + ListEntitiesSwitchResponse, + MediaPlayerCommandRequest, + MediaPlayerStateResponse, + NumberCommandRequest, + NumberStateResponse, + SelectCommandRequest, + SelectStateResponse, + SubscribeHomeAssistantStatesRequest, + SwitchCommandRequest, + SwitchStateResponse, +) +from aioesphomeapi.model import MediaPlayerCommand, MediaPlayerState + + +# --------------------------------------------------------------------------- +# MediaPlayerEntity +# --------------------------------------------------------------------------- + + +class TestMediaPlayerEntityInit: + def test_initial_volume_clamped_above_one(self): + entity = make_media_player(initial_volume=1.5) + assert entity.volume == 1.0 + + def test_initial_volume_clamped_below_zero(self): + entity = make_media_player(initial_volume=-0.5) + assert entity.volume == 0.0 + + def test_initial_volume_stored(self): + entity = make_media_player(initial_volume=0.7) + assert abs(entity.volume - 0.7) < 0.001 + + def test_initial_state_is_idle(self): + entity = make_media_player() + assert entity.state == MediaPlayerState.IDLE + + def test_not_muted_by_default(self): + entity = make_media_player() + assert entity.muted is False + + +class TestMediaPlayerEntityListEntities: + def test_list_entities_request_yields_response(self): + entity = make_media_player(key=5) + msgs = list(entity.handle_message(ListEntitiesRequest())) + assert len(msgs) == 1 + assert isinstance(msgs[0], ListEntitiesMediaPlayerResponse) + + def test_list_entities_response_has_correct_key(self): + entity = make_media_player(key=5) + msgs = list(entity.handle_message(ListEntitiesRequest())) + assert msgs[0].key == 5 + + def test_list_entities_response_has_correct_object_id(self): + entity = make_media_player() + msgs = list(entity.handle_message(ListEntitiesRequest())) + assert msgs[0].object_id == "test_media_player" + + def test_list_entities_supports_pause(self): + entity = make_media_player() + msgs = list(entity.handle_message(ListEntitiesRequest())) + assert msgs[0].supports_pause is True + + +class TestMediaPlayerEntitySubscribeStates: + def test_subscribe_states_yields_state_response(self): + entity = make_media_player() + msgs = list(entity.handle_message(SubscribeHomeAssistantStatesRequest())) + assert len(msgs) == 1 + assert isinstance(msgs[0], MediaPlayerStateResponse) + + def test_subscribe_states_has_correct_key(self): + entity = make_media_player(key=7) + msgs = list(entity.handle_message(SubscribeHomeAssistantStatesRequest())) + assert msgs[0].key == 7 + + +class TestMediaPlayerEntityVolume: + def test_apply_volume_sets_both_players(self): + entity = make_media_player(initial_volume=1.0) + entity._apply_volume(0.5, persist=False) + entity.music_player.set_volume.assert_called_with(50) + entity.announce_player.set_volume.assert_called_with(50) + + def test_apply_volume_clamps_above_one(self): + entity = make_media_player() + entity._apply_volume(1.5, persist=False) + assert entity.volume == 1.0 + + def test_apply_volume_clamps_below_zero(self): + entity = make_media_player() + entity._apply_volume(-0.1, persist=False) + assert entity.volume == 0.0 + + def test_apply_volume_stores_previous_volume(self): + entity = make_media_player(initial_volume=0.8) + entity._apply_volume(0.4, persist=False, remember=True) + assert abs(entity.previous_volume - 0.4) < 0.001 + + def test_apply_volume_persist_calls_callback(self): + callback = MagicMock() + entity = make_media_player() + entity._on_volume_changed = callback + entity._apply_volume(0.5, persist=True) + callback.assert_called_once_with(0.5) + + def test_apply_volume_no_persist_skips_callback(self): + callback = MagicMock() + entity = make_media_player() + entity._on_volume_changed = callback + entity._apply_volume(0.5, persist=False) + callback.assert_not_called() + + def test_volume_command_yields_state_response(self): + entity = make_media_player(key=1) + msg = MediaPlayerCommandRequest(key=1, has_volume=True, volume=0.6) + msgs = list(entity.handle_message(msg)) + assert any(isinstance(m, MediaPlayerStateResponse) for m in msgs) + + def test_volume_command_wrong_key_ignored(self): + entity = make_media_player(key=1) + msg = MediaPlayerCommandRequest(key=99, has_volume=True, volume=0.6) + msgs = list(entity.handle_message(msg)) + assert msgs == [] or not any(isinstance(m, MediaPlayerStateResponse) for m in msgs) + + +class TestMediaPlayerEntityMuteUnmute: + def test_mute_sets_volume_to_zero(self): + entity = make_media_player(initial_volume=0.8) + msg = MediaPlayerCommandRequest(key=1, has_command=True, command=MediaPlayerCommand.MUTE) + list(entity.handle_message(msg)) + assert entity.volume == 0.0 + + def test_mute_saves_previous_volume(self): + entity = make_media_player(initial_volume=0.8) + entity.volume = 0.8 + msg = MediaPlayerCommandRequest(key=1, has_command=True, command=MediaPlayerCommand.MUTE) + list(entity.handle_message(msg)) + assert abs(entity.previous_volume - 0.8) < 0.001 + + def test_mute_sets_muted_flag(self): + entity = make_media_player(key=1) + msg = MediaPlayerCommandRequest(key=1, has_command=True, command=MediaPlayerCommand.MUTE) + list(entity.handle_message(msg)) + assert entity.muted is True + + def test_unmute_restores_previous_volume(self): + entity = make_media_player(key=1, initial_volume=0.8) + entity.previous_volume = 0.8 + entity.muted = True + entity.volume = 0.0 + msg = MediaPlayerCommandRequest(key=1, has_command=True, command=MediaPlayerCommand.UNMUTE) + list(entity.handle_message(msg)) + assert abs(entity.volume - 0.8) < 0.001 + + def test_unmute_clears_muted_flag(self): + entity = make_media_player(key=1) + entity.muted = True + msg = MediaPlayerCommandRequest(key=1, has_command=True, command=MediaPlayerCommand.UNMUTE) + list(entity.handle_message(msg)) + assert entity.muted is False + + def test_double_mute_does_not_overwrite_previous_volume(self): + entity = make_media_player(key=1, initial_volume=0.9) + entity.volume = 0.9 + entity.previous_volume = 0.9 + + msg_mute = MediaPlayerCommandRequest(key=1, has_command=True, command=MediaPlayerCommand.MUTE) + list(entity.handle_message(msg_mute)) + # Mute again — should not overwrite previous_volume with 0 + list(entity.handle_message(msg_mute)) + assert abs(entity.previous_volume - 0.9) < 0.001 + + +class TestMediaPlayerEntityPlayback: + def test_pause_command_calls_music_player_pause(self): + entity = make_media_player(key=1) + msg = MediaPlayerCommandRequest(key=1, has_command=True, command=MediaPlayerCommand.PAUSE) + list(entity.handle_message(msg)) + entity.music_player.pause.assert_called_once() + + def test_stop_command_calls_music_player_stop(self): + entity = make_media_player(key=1) + msg = MediaPlayerCommandRequest(key=1, has_command=True, command=MediaPlayerCommand.STOP) + list(entity.handle_message(msg)) + entity.music_player.stop.assert_called_once() + + def test_play_command_calls_music_player_resume(self): + entity = make_media_player(key=1) + msg = MediaPlayerCommandRequest(key=1, has_command=True, command=MediaPlayerCommand.PLAY) + list(entity.handle_message(msg)) + entity.music_player.resume.assert_called_once() + + +# --------------------------------------------------------------------------- +# MuteSwitchEntity +# --------------------------------------------------------------------------- + + +class TestMuteSwitchEntity: + def test_initial_state_synced_from_get_muted(self): + entity = make_mute_switch(initial_muted=True) + assert entity._switch_state is True + + def test_switch_command_calls_set_muted(self): + entity = make_mute_switch(key=2) + msg = SwitchCommandRequest(key=2, state=True) + list(entity.handle_message(msg)) + entity._set_muted.assert_called_once_with(True) + + def test_switch_command_updates_internal_state(self): + entity = make_mute_switch(key=2) + msg = SwitchCommandRequest(key=2, state=True) + list(entity.handle_message(msg)) + assert entity._switch_state is True + + def test_switch_command_yields_switch_state_response(self): + entity = make_mute_switch(key=2) + msg = SwitchCommandRequest(key=2, state=True) + msgs = list(entity.handle_message(msg)) + assert any(isinstance(m, SwitchStateResponse) for m in msgs) + + def test_switch_command_wrong_key_ignored(self): + entity = make_mute_switch(key=2) + msg = SwitchCommandRequest(key=99, state=True) + list(entity.handle_message(msg)) + entity._set_muted.assert_not_called() + + def test_list_entities_request_yields_switch_response(self): + entity = make_mute_switch() + msgs = list(entity.handle_message(ListEntitiesRequest())) + assert any(isinstance(m, ListEntitiesSwitchResponse) for m in msgs) + + def test_subscribe_states_yields_switch_state_response(self): + entity = make_mute_switch() + msgs = list(entity.handle_message(SubscribeHomeAssistantStatesRequest())) + assert any(isinstance(m, SwitchStateResponse) for m in msgs) + + def test_sync_with_state_updates_switch_state(self): + entity = make_mute_switch(initial_muted=False) + entity._get_muted.return_value = True + entity.sync_with_state() + assert entity._switch_state is True + + def test_update_set_muted_replaces_callback(self): + entity = make_mute_switch(key=2) + new_set_muted = MagicMock() + entity.update_set_muted(new_set_muted) + msg = SwitchCommandRequest(key=2, state=True) + list(entity.handle_message(msg)) + new_set_muted.assert_called_once_with(True) + + def test_update_get_muted_replaces_callback(self): + entity = make_mute_switch() + new_get_muted = MagicMock(return_value=True) + entity.update_get_muted(new_get_muted) + entity.sync_with_state() + assert entity._switch_state is True + + +# --------------------------------------------------------------------------- +# MicSettingEntity — number mode (no options) +# --------------------------------------------------------------------------- + + +class TestMicSettingEntityNumber: + def test_list_entities_yields_number_response(self): + entity = make_mic_setting(key=3, options=None) + msgs = list(entity.handle_message(ListEntitiesRequest())) + assert any(isinstance(m, ListEntitiesNumberResponse) for m in msgs) + + def test_number_command_calls_set_value(self): + entity = make_mic_setting(key=3, options=None) + msg = NumberCommandRequest(key=3, state=5.0) + list(entity.handle_message(msg)) + entity._set_value_mock.assert_called_once_with(5.0) + + def test_number_command_yields_number_state_response(self): + entity = make_mic_setting(key=3, options=None) + msg = NumberCommandRequest(key=3, state=5.0) + msgs = list(entity.handle_message(msg)) + assert any(isinstance(m, NumberStateResponse) for m in msgs) + + def test_number_command_wrong_key_ignored(self): + entity = make_mic_setting(key=3, options=None) + msg = NumberCommandRequest(key=99, state=5.0) + list(entity.handle_message(msg)) + entity._set_value_mock.assert_not_called() + + def test_subscribe_states_yields_number_state_response(self): + entity = make_mic_setting(key=3, options=None, value=2.0) + msgs = list(entity.handle_message(SubscribeHomeAssistantStatesRequest())) + assert any(isinstance(m, NumberStateResponse) for m in msgs) + + def test_sync_with_state_updates_internal_state(self): + entity = make_mic_setting(key=3, options=None, value=0.0) + entity._get_value_mock.return_value = 7.0 + entity.sync_with_state() + assert entity._state == 7.0 + + +# --------------------------------------------------------------------------- +# MicSettingEntity — select mode (with options) +# --------------------------------------------------------------------------- + + +class TestMicSettingEntitySelect: + NOISE_OPTIONS = ["Off", "Low", "Medium", "High", "Max"] + + def test_list_entities_yields_select_response(self): + entity = make_mic_setting(key=4, options=self.NOISE_OPTIONS) + msgs = list(entity.handle_message(ListEntitiesRequest())) + assert any(isinstance(m, ListEntitiesSelectResponse) for m in msgs) + + def test_select_response_has_correct_options(self): + entity = make_mic_setting(key=4, options=self.NOISE_OPTIONS) + msgs = list(entity.handle_message(ListEntitiesRequest())) + select_msg = next(m for m in msgs if isinstance(m, ListEntitiesSelectResponse)) + assert list(select_msg.options) == self.NOISE_OPTIONS + + def test_select_command_calls_set_value(self): + entity = make_mic_setting(key=4, options=self.NOISE_OPTIONS) + msg = SelectCommandRequest(key=4, state="High") + list(entity.handle_message(msg)) + entity._set_value_mock.assert_called_once_with("High") + + def test_select_command_yields_select_state_response(self): + entity = make_mic_setting(key=4, options=self.NOISE_OPTIONS) + msg = SelectCommandRequest(key=4, state="Medium") + msgs = list(entity.handle_message(msg)) + assert any(isinstance(m, SelectStateResponse) for m in msgs) + + def test_select_command_wrong_key_ignored(self): + entity = make_mic_setting(key=4, options=self.NOISE_OPTIONS) + msg = SelectCommandRequest(key=99, state="Low") + list(entity.handle_message(msg)) + entity._set_value_mock.assert_not_called() + + def test_subscribe_states_yields_select_state_response(self): + entity = make_mic_setting(key=4, options=self.NOISE_OPTIONS, value="Off") + msgs = list(entity.handle_message(SubscribeHomeAssistantStatesRequest())) + assert any(isinstance(m, SelectStateResponse) for m in msgs) + + def test_subscribe_states_response_has_string_state(self): + entity = make_mic_setting(key=4, options=self.NOISE_OPTIONS, value="Low") + msgs = list(entity.handle_message(SubscribeHomeAssistantStatesRequest())) + select_msg = next(m for m in msgs if isinstance(m, SelectStateResponse)) + assert isinstance(select_msg.state, str) \ No newline at end of file diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py new file mode 100644 index 00000000..12d25349 --- /dev/null +++ b/tests/unit/test_models.py @@ -0,0 +1,304 @@ +"""Unit tests for shared models.""" + +import json +import pytest +from dataclasses import asdict +from pathlib import Path +from queue import Queue +from unittest.mock import MagicMock, patch + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_preferences(**kwargs): + from linux_voice_assistant.models import Preferences + return Preferences(**kwargs) + + +def make_server_state(tmp_path, **kwargs): + """Build a minimal ServerState with all required fields mocked out.""" + from linux_voice_assistant.models import Preferences, ServerState + + defaults = dict( + name="lva-test", + friendly_name="LVA Test", + mac_address="aa:bb:cc:dd:ee:ff", + ip_address="192.168.1.100", + network_interface="eth0", + version="1.0.0", + esphome_version="42.0.0", + audio_queue=Queue(), + entities=[], + available_wake_words={}, + wake_words={}, + active_wake_words=set(), + stop_word=MagicMock(), + music_player=MagicMock(), + tts_player=MagicMock(), + wakeup_sound="/sounds/wake.flac", + processing_sound="/sounds/processing.wav", + timer_finished_sound="/sounds/timer.flac", + mute_sound="/sounds/mute.flac", + unmute_sound="/sounds/unmute.flac", + preferences=Preferences(), + preferences_path=tmp_path / "preferences.json", + download_dir=tmp_path / "downloads", + ) + defaults.update(kwargs) + return ServerState(**defaults) + + +# --------------------------------------------------------------------------- +# WakeWordType +# --------------------------------------------------------------------------- + + +class TestWakeWordType: + def test_micro_value(self): + from linux_voice_assistant.models import WakeWordType + assert WakeWordType.MICRO_WAKE_WORD == "micro" + + def test_open_wake_word_value(self): + from linux_voice_assistant.models import WakeWordType + assert WakeWordType.OPEN_WAKE_WORD == "openWakeWord" + + def test_is_string_enum(self): + from linux_voice_assistant.models import WakeWordType + assert isinstance(WakeWordType.MICRO_WAKE_WORD, str) + + +# --------------------------------------------------------------------------- +# Preferences +# --------------------------------------------------------------------------- + + +class TestPreferences: + def test_default_active_wake_words_is_empty_list(self): + p = make_preferences() + assert p.active_wake_words == [] + + def test_default_volume_is_none(self): + p = make_preferences() + assert p.volume is None + + def test_default_thinking_sound_is_zero(self): + p = make_preferences() + assert p.thinking_sound == 0 + + def test_default_mic_auto_gain_is_zero(self): + p = make_preferences() + assert p.mic_auto_gain == 0 + + def test_default_mic_noise_suppression_is_zero(self): + p = make_preferences() + assert p.mic_noise_suppression == 0 + + def test_custom_values_stored(self): + p = make_preferences( + volume=0.8, + mic_auto_gain=5, + mic_noise_suppression=2, + thinking_sound=1, + ) + assert p.volume == 0.8 + assert p.mic_auto_gain == 5 + assert p.mic_noise_suppression == 2 + assert p.thinking_sound == 1 + + def test_active_wake_words_are_independent_per_instance(self): + """Mutable default must not be shared between instances.""" + p1 = make_preferences() + p2 = make_preferences() + p1.active_wake_words.append("okay_nabu") + assert p2.active_wake_words == [] + + +# --------------------------------------------------------------------------- +# ServerState.save_preferences() +# --------------------------------------------------------------------------- + + +class TestSavePreferences: + def test_creates_preferences_file(self, tmp_path): + state = make_server_state(tmp_path) + state.save_preferences() + assert state.preferences_path.exists() + + def test_saved_json_is_valid(self, tmp_path): + state = make_server_state(tmp_path) + state.save_preferences() + with open(state.preferences_path) as f: + data = json.load(f) + assert isinstance(data, dict) + + def test_saved_json_contains_expected_keys(self, tmp_path): + state = make_server_state(tmp_path) + state.save_preferences() + with open(state.preferences_path) as f: + data = json.load(f) + assert "mic_auto_gain" in data + assert "mic_noise_suppression" in data + assert "volume" in data + assert "thinking_sound" in data + + def test_saved_values_match_preferences(self, tmp_path): + from linux_voice_assistant.models import Preferences + prefs = Preferences(volume=0.75, mic_auto_gain=3, mic_noise_suppression=1) + state = make_server_state(tmp_path, preferences=prefs) + state.save_preferences() + with open(state.preferences_path) as f: + data = json.load(f) + assert data["volume"] == 0.75 + assert data["mic_auto_gain"] == 3 + assert data["mic_noise_suppression"] == 1 + + def test_creates_parent_directory_if_missing(self, tmp_path): + nested_path = tmp_path / "nested" / "dir" / "preferences.json" + state = make_server_state(tmp_path, preferences_path=nested_path) + state.save_preferences() + assert nested_path.exists() + + +# --------------------------------------------------------------------------- +# ServerState.persist_volume() +# --------------------------------------------------------------------------- + + +class TestPersistVolume: + def test_updates_volume_on_state(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_volume(0.5) + assert state.volume == 0.5 + + def test_updates_volume_on_preferences(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_volume(0.5) + assert state.preferences.volume == 0.5 + + def test_clamps_above_one(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_volume(1.5) + assert state.volume == 1.0 + + def test_clamps_below_zero(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_volume(-0.5) + assert state.volume == 0.0 + + def test_saves_to_file(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_volume(0.6) + with open(state.preferences_path) as f: + data = json.load(f) + assert data["volume"] == pytest.approx(0.6) + + def test_skips_save_when_volume_unchanged(self, tmp_path): + from linux_voice_assistant.models import Preferences + prefs = Preferences(volume=0.5) + state = make_server_state(tmp_path, preferences=prefs, volume=0.5) + + with patch.object(state, "save_preferences") as mock_save: + state.persist_volume(0.5) + mock_save.assert_not_called() + + def test_saves_when_volume_changed(self, tmp_path): + from linux_voice_assistant.models import Preferences + prefs = Preferences(volume=0.5) + state = make_server_state(tmp_path, preferences=prefs, volume=0.5) + + with patch.object(state, "save_preferences") as mock_save: + state.persist_volume(0.8) + mock_save.assert_called_once() + + def test_boundary_value_zero(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_volume(0.0) + assert state.volume == 0.0 + + def test_boundary_value_one(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_volume(1.0) + assert state.volume == 1.0 + + +# --------------------------------------------------------------------------- +# ServerState.persist_mic_gain() +# --------------------------------------------------------------------------- + + +class TestPersistMicGain: + def test_updates_mic_auto_gain_on_state(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_mic_gain(5.0) + assert state.mic_auto_gain == 5 + + def test_updates_mic_auto_gain_on_preferences(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_mic_gain(5.0) + assert state.preferences.mic_auto_gain == 5 + + def test_converts_float_to_int(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_mic_gain(7.9) + assert state.mic_auto_gain == 7 + + def test_skips_save_when_gain_unchanged(self, tmp_path): + state = make_server_state(tmp_path) + state.mic_auto_gain = 3 + state.preferences.mic_auto_gain = 3 + + with patch.object(state, "save_preferences") as mock_save: + state.persist_mic_gain(3.0) + mock_save.assert_not_called() + + def test_saves_when_gain_changed(self, tmp_path): + state = make_server_state(tmp_path) + state.mic_auto_gain = 3 + state.preferences.mic_auto_gain = 3 + + with patch.object(state, "save_preferences") as mock_save: + state.persist_mic_gain(10.0) + mock_save.assert_called_once() + + +# --------------------------------------------------------------------------- +# ServerState.persist_mic_noise() +# --------------------------------------------------------------------------- + + +class TestPersistMicNoise: + def test_updates_mic_noise_suppression_on_state(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_mic_noise(2.0) + assert state.mic_noise_suppression == 2 + + def test_updates_mic_noise_suppression_on_preferences(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_mic_noise(2.0) + assert state.preferences.mic_noise_suppression == 2 + + def test_converts_float_to_int(self, tmp_path): + state = make_server_state(tmp_path) + state.persist_mic_noise(3.9) + assert state.mic_noise_suppression == 3 + + def test_skips_save_when_noise_unchanged(self, tmp_path): + state = make_server_state(tmp_path) + state.mic_noise_suppression = 2 + state.preferences.mic_noise_suppression = 2 + + with patch.object(state, "save_preferences") as mock_save: + state.persist_mic_noise(2.0) + mock_save.assert_not_called() + + def test_saves_when_noise_changed(self, tmp_path): + state = make_server_state(tmp_path) + state.mic_noise_suppression = 2 + state.preferences.mic_noise_suppression = 2 + + with patch.object(state, "save_preferences") as mock_save: + state.persist_mic_noise(4.0) + mock_save.assert_called_once() \ No newline at end of file diff --git a/tests/unit/test_mpv_player.py b/tests/unit/test_mpv_player.py new file mode 100644 index 00000000..4e757f1c --- /dev/null +++ b/tests/unit/test_mpv_player.py @@ -0,0 +1,248 @@ +"""Unit tests for MpvMediaPlayer.""" + +import pytest +from unittest.mock import MagicMock, patch, call + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_player(device=None): + """Return an MpvMediaPlayer with the internal LibMpvPlayer mocked out.""" + from linux_voice_assistant.player.state import PlayerState + + mock_lib_player = MagicMock() + mock_lib_player.state.return_value = PlayerState.IDLE + + with patch("linux_voice_assistant.mpv_player.LibMpvPlayer", return_value=mock_lib_player): + from linux_voice_assistant.mpv_player import MpvMediaPlayer + player = MpvMediaPlayer(device=device) + player._mock = mock_lib_player + return player + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + + +class TestInit: + def test_player_initialized_with_device(self): + with patch("linux_voice_assistant.mpv_player.LibMpvPlayer") as mock_cls: + mock_cls.return_value = MagicMock() + from linux_voice_assistant.mpv_player import MpvMediaPlayer + MpvMediaPlayer(device="hw:1,0") + mock_cls.assert_called_once_with(device="hw:1,0") + + def test_player_initialized_with_none_device(self): + with patch("linux_voice_assistant.mpv_player.LibMpvPlayer") as mock_cls: + mock_cls.return_value = MagicMock() + from linux_voice_assistant.mpv_player import MpvMediaPlayer + MpvMediaPlayer(device=None) + mock_cls.assert_called_once_with(device=None) + + def test_done_callback_starts_none(self): + player = make_player() + assert player._done_callback is None + + def test_playlist_starts_empty(self): + player = make_player() + assert player._playlist == [] + + +# --------------------------------------------------------------------------- +# play() +# --------------------------------------------------------------------------- + + +class TestPlay: + def test_play_single_url_calls_lib_player(self): + player = make_player() + player.play("http://example.com/audio.mp3") + player._mock.play.assert_called_once() + + def test_play_single_url_passes_correct_url(self): + player = make_player() + player.play("http://example.com/audio.mp3") + args, kwargs = player._mock.play.call_args + assert args[0] == "http://example.com/audio.mp3" + + def test_play_list_plays_first_url(self): + player = make_player() + player.play(["http://a.com/1.mp3", "http://a.com/2.mp3"]) + args, kwargs = player._mock.play.call_args + assert args[0] == "http://a.com/1.mp3" + + def test_play_list_stores_remaining_in_playlist(self): + player = make_player() + player.play(["http://a.com/1.mp3", "http://a.com/2.mp3", "http://a.com/3.mp3"]) + assert player._playlist == ["http://a.com/2.mp3", "http://a.com/3.mp3"] + + def test_play_single_url_playlist_is_empty(self): + player = make_player() + player.play("http://example.com/audio.mp3") + assert player._playlist == [] + + def test_play_stores_done_callback(self): + player = make_player() + cb = MagicMock() + player.play("http://example.com/audio.mp3", done_callback=cb) + assert player._done_callback is cb + + def test_play_empty_list_does_not_call_lib_player(self): + player = make_player() + player.play([]) + player._mock.play.assert_not_called() + + def test_play_while_active_stops_previous(self): + from linux_voice_assistant.player.state import PlayerState + player = make_player() + player._done_callback = MagicMock() # simulate active playback + player._mock.state.return_value = PlayerState.PLAYING + + player.play("http://example.com/new.mp3") + player._mock.stop.assert_called_once() + + def test_play_while_idle_does_not_stop(self): + from linux_voice_assistant.player.state import PlayerState + player = make_player() + player._mock.state.return_value = PlayerState.IDLE + player._done_callback = None + + player.play("http://example.com/new.mp3") + player._mock.stop.assert_not_called() + + +# --------------------------------------------------------------------------- +# _on_track_finished() +# --------------------------------------------------------------------------- + + +class TestOnTrackFinished: + def test_plays_next_url_when_playlist_has_items(self): + player = make_player() + player._playlist = ["http://a.com/2.mp3"] + player._on_track_finished() + args, _ = player._mock.play.call_args + assert args[0] == "http://a.com/2.mp3" + + def test_invokes_done_callback_when_playlist_empty(self): + player = make_player() + cb = MagicMock() + player._done_callback = cb + player._playlist = [] + player._on_track_finished() + cb.assert_called_once() + + def test_clears_done_callback_after_invoking(self): + player = make_player() + player._done_callback = MagicMock() + player._playlist = [] + player._on_track_finished() + assert player._done_callback is None + + def test_no_error_when_done_callback_is_none(self): + player = make_player() + player._done_callback = None + player._playlist = [] + player._on_track_finished() # should not raise + + +# --------------------------------------------------------------------------- +# pause() / resume() / stop() +# --------------------------------------------------------------------------- + + +class TestPauseResumeStop: + def test_pause_delegates_to_lib_player(self): + player = make_player() + player.pause() + player._mock.pause.assert_called_once() + + def test_resume_delegates_to_lib_player(self): + player = make_player() + player.resume() + player._mock.resume.assert_called_once() + + def test_stop_delegates_to_lib_player(self): + player = make_player() + player.stop() + player._mock.stop.assert_called_once() + + def test_stop_invokes_done_callback(self): + player = make_player() + cb = MagicMock() + player._done_callback = cb + player.stop() + cb.assert_called_once() + + def test_stop_clears_done_callback(self): + player = make_player() + player._done_callback = MagicMock() + player.stop() + assert player._done_callback is None + + def test_stop_no_error_when_no_callback(self): + player = make_player() + player._done_callback = None + player.stop() # should not raise + + +# --------------------------------------------------------------------------- +# is_playing +# --------------------------------------------------------------------------- + + +class TestIsPlaying: + def test_true_when_playing(self): + from linux_voice_assistant.player.state import PlayerState + player = make_player() + player._mock.state.return_value = PlayerState.PLAYING + assert player.is_playing is True + + def test_true_when_paused(self): + from linux_voice_assistant.player.state import PlayerState + player = make_player() + player._mock.state.return_value = PlayerState.PAUSED + assert player.is_playing is True + + def test_true_when_loading(self): + from linux_voice_assistant.player.state import PlayerState + player = make_player() + player._mock.state.return_value = PlayerState.LOADING + assert player.is_playing is True + + def test_false_when_idle(self): + from linux_voice_assistant.player.state import PlayerState + player = make_player() + player._mock.state.return_value = PlayerState.IDLE + assert player.is_playing is False + + +# --------------------------------------------------------------------------- +# set_volume() / duck() / unduck() +# --------------------------------------------------------------------------- + + +class TestVolume: + def test_set_volume_delegates_to_lib_player(self): + player = make_player() + player.set_volume(75.0) + player._mock.set_volume.assert_called_once_with(75.0) + + def test_duck_delegates_to_lib_player(self): + player = make_player() + player.duck(factor=0.3) + player._mock.duck.assert_called_once_with(0.3) + + def test_duck_default_factor(self): + player = make_player() + player.duck() + player._mock.duck.assert_called_once_with(0.5) + + def test_unduck_delegates_to_lib_player(self): + player = make_player() + player.unduck() + player._mock.unduck.assert_called_once() \ No newline at end of file diff --git a/tests/unit/test_unit.py b/tests/unit/test_unit.py new file mode 100644 index 00000000..90a46fd3 --- /dev/null +++ b/tests/unit/test_unit.py @@ -0,0 +1,226 @@ +"""Unit tests for utility functions.""" + +import pytest +from unittest.mock import patch, MagicMock +from pathlib import Path + + +# --------------------------------------------------------------------------- +# get_version() +# --------------------------------------------------------------------------- + + +class TestGetVersion: + def setup_method(self): + """Reset the version cache before each test.""" + import linux_voice_assistant.util as util + util._version_cache = None + + def test_returns_unknown_when_file_missing(self): + import linux_voice_assistant.util as util + util._version_cache = None + + with patch.object(Path, "read_text", side_effect=FileNotFoundError): + result = util.get_version() + assert result == "unknown" + + def test_returns_version_string_from_file(self): + import linux_voice_assistant.util as util + util._version_cache = None + + with patch.object(Path, "read_text", return_value="1.2.3\n"): + result = util.get_version() + assert result == "1.2.3" + + def test_strips_whitespace_from_version(self): + import linux_voice_assistant.util as util + util._version_cache = None + + with patch.object(Path, "read_text", return_value=" 2.0.0 \n"): + result = util.get_version() + assert result == "2.0.0" + + def test_returns_unknown_for_empty_file(self): + import linux_voice_assistant.util as util + util._version_cache = None + + with patch.object(Path, "read_text", return_value=" "): + result = util.get_version() + assert result == "unknown" + + def test_caches_result_after_first_call(self): + import linux_voice_assistant.util as util + util._version_cache = None + + with patch.object(Path, "read_text", return_value="3.0.0") as mock_read: + util.get_version() + util.get_version() + mock_read.assert_called_once() + + def test_returns_cached_value_on_second_call(self): + import linux_voice_assistant.util as util + util._version_cache = None + + with patch.object(Path, "read_text", return_value="4.0.0"): + first = util.get_version() + + second = util.get_version() + assert first == second == "4.0.0" + + def test_returns_unknown_on_permission_error(self): + import linux_voice_assistant.util as util + util._version_cache = None + + with patch.object(Path, "read_text", side_effect=PermissionError): + result = util.get_version() + assert result == "unknown" + + +# --------------------------------------------------------------------------- +# get_esphome_version() +# --------------------------------------------------------------------------- + + +class TestGetEsphomeVersion: + def setup_method(self): + import linux_voice_assistant.util as util + util._esphome_version_cache = None + + def test_returns_version_when_package_installed(self): + import linux_voice_assistant.util as util + + with patch("linux_voice_assistant.util.version", return_value="42.7.0"): + result = util.get_esphome_version() + assert result == "42.7.0" + + def test_returns_unknown_when_package_not_installed(self): + import linux_voice_assistant.util as util + + from importlib.metadata import PackageNotFoundError + with patch("linux_voice_assistant.util.version", side_effect=PackageNotFoundError): + result = util.get_esphome_version() + assert result == "unknown" + + def test_caches_result_after_first_call(self): + import linux_voice_assistant.util as util + + with patch("linux_voice_assistant.util.version", return_value="1.0.0") as mock_ver: + util.get_esphome_version() + util.get_esphome_version() + mock_ver.assert_called_once() + + def test_returns_cached_value_on_second_call(self): + import linux_voice_assistant.util as util + + with patch("linux_voice_assistant.util.version", return_value="5.0.0"): + first = util.get_esphome_version() + + second = util.get_esphome_version() + assert first == second == "5.0.0" + + +# --------------------------------------------------------------------------- +# call_all() +# --------------------------------------------------------------------------- + + +class TestCallAll: + def test_calls_single_callable(self): + from linux_voice_assistant.util import call_all + mock_fn = MagicMock() + call_all(mock_fn) + mock_fn.assert_called_once() + + def test_calls_multiple_callables_in_order(self): + from linux_voice_assistant.util import call_all + calls = [] + call_all( + lambda: calls.append(1), + lambda: calls.append(2), + lambda: calls.append(3), + ) + assert calls == [1, 2, 3] + + def test_skips_none_entries(self): + from linux_voice_assistant.util import call_all + mock_fn = MagicMock() + call_all(None, mock_fn, None) + mock_fn.assert_called_once() + + def test_all_none_does_nothing(self): + from linux_voice_assistant.util import call_all + call_all(None, None, None) + + def test_empty_args_does_nothing(self): + from linux_voice_assistant.util import call_all + call_all() + + def test_none_mixed_with_callables_calls_only_non_none(self): + from linux_voice_assistant.util import call_all + results = [] + call_all(None, lambda: results.append("a"), None, lambda: results.append("b")) + assert results == ["a", "b"] + + +# --------------------------------------------------------------------------- +# get_default_interface() and get_default_ipv4() +# --------------------------------------------------------------------------- + + +class TestGetDefaultInterface: + def test_returns_interface_name_from_gateway(self): + import linux_voice_assistant.util as util + with patch("linux_voice_assistant.util.netifaces") as mock_netifaces: + # Set AF_INET before building the dict so the key matches + mock_netifaces.AF_INET = 2 + mock_netifaces.default_gateway.return_value = { + 2: ("192.168.1.1", "eth0") + } + result = util.get_default_interface() + assert result == "eth0" + + def test_returns_none_when_no_gateway(self): + import linux_voice_assistant.util as util + with patch("linux_voice_assistant.util.netifaces") as mock_netifaces: + mock_netifaces.default_gateway.return_value = {} + result = util.get_default_interface() + assert result is None + + def test_returns_none_when_no_ipv4_gateway(self): + import linux_voice_assistant.util as util + with patch("linux_voice_assistant.util.netifaces") as mock_netifaces: + mock_netifaces.AF_INET = 2 + # Only a non-IPv4 gateway present + mock_netifaces.default_gateway.return_value = {99: ("10.0.0.1", "eth1")} + result = util.get_default_interface() + assert result is None + + +class TestGetDefaultIpv4: + def test_returns_ip_for_interface(self): + import linux_voice_assistant.util as util + with patch("linux_voice_assistant.util.netifaces") as mock_netifaces: + mock_netifaces.AF_INET = 2 + mock_netifaces.ifaddresses.return_value = { + 2: [{"addr": "192.168.1.50"}] + } + result = util.get_default_ipv4("eth0") + assert result == "192.168.1.50" + + def test_returns_none_for_empty_interface(self): + import linux_voice_assistant.util as util + result = util.get_default_ipv4("") + assert result is None + + def test_returns_none_for_none_interface(self): + import linux_voice_assistant.util as util + result = util.get_default_ipv4(None) + assert result is None + + def test_returns_none_when_no_ipv4_address(self): + import linux_voice_assistant.util as util + with patch("linux_voice_assistant.util.netifaces") as mock_netifaces: + mock_netifaces.AF_INET = 2 + mock_netifaces.ifaddresses.return_value = {} + result = util.get_default_ipv4("eth0") + assert result is None \ No newline at end of file diff --git a/tests/unit/test_webrtc.py b/tests/unit/test_webrtc.py new file mode 100644 index 00000000..4d5671d9 --- /dev/null +++ b/tests/unit/test_webrtc.py @@ -0,0 +1,195 @@ +"""Unit tests for WebRTCProcessor.""" + +import pytest +from unittest.mock import MagicMock, patch + + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +FRAME_SIZE = 320 # 160 samples * 2 bytes (16-bit PCM) +PATCH_TARGET = "webrtc_noise_gain.AudioProcessor" + + +def make_audio(n_bytes: int, fill: int = 0xAB) -> bytes: + """Return n_bytes of dummy PCM data.""" + return bytes([fill]) * n_bytes + + +def make_mock_apm(output_fill: int = 0x00): + """Return a mock AudioProcessor whose Process10ms returns a frame-sized result.""" + mock_apm = MagicMock() + mock_result = MagicMock() + mock_result.audio = bytes([output_fill]) * FRAME_SIZE + mock_apm.Process10ms.return_value = mock_result + return mock_apm + + +@pytest.fixture +def processor(): + """WebRTCProcessor with a mocked AudioProcessor so no C extension is needed.""" + mock_apm = make_mock_apm() + with patch(PATCH_TARGET, return_value=mock_apm): + from linux_voice_assistant.webrtc import WebRTCProcessor + + proc = WebRTCProcessor(agc_level=3, ns_level=2) + proc._mock_apm = mock_apm + yield proc + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + + +class TestInit: + def test_frame_size_is_320(self, processor): + assert processor.FRAME_SIZE_BYTES == FRAME_SIZE + + def test_agc_stored(self, processor): + assert processor.agc_level == 3 + + def test_ns_stored(self, processor): + assert processor.ns_level == 2 + + def test_buffer_starts_empty(self, processor): + assert len(processor._buffer) == 0 + + def test_audio_processor_constructed_with_correct_levels(self): + mock_apm = make_mock_apm() + with patch(PATCH_TARGET, return_value=mock_apm) as mock_cls: + from linux_voice_assistant.webrtc import WebRTCProcessor + + WebRTCProcessor(agc_level=5, ns_level=1) + mock_cls.assert_called_once_with(5, 1) + + +# --------------------------------------------------------------------------- +# process() — buffering behaviour +# --------------------------------------------------------------------------- + + +class TestProcessBuffering: + def test_exact_one_frame_returns_processed_bytes(self, processor): + result = processor.process(make_audio(FRAME_SIZE)) + assert len(result) == FRAME_SIZE + processor._mock_apm.Process10ms.assert_called_once() + + def test_less_than_one_frame_returns_empty(self, processor): + result = processor.process(make_audio(FRAME_SIZE - 1)) + assert result == b"" + processor._mock_apm.Process10ms.assert_not_called() + + def test_two_frames_returns_two_processed_chunks(self, processor): + result = processor.process(make_audio(FRAME_SIZE * 2)) + assert len(result) == FRAME_SIZE * 2 + assert processor._mock_apm.Process10ms.call_count == 2 + + def test_partial_input_buffers_remainder(self, processor): + # 500 bytes = 1 full frame (320) + 180 leftover + result = processor.process(make_audio(500)) + assert len(result) == FRAME_SIZE + assert len(processor._buffer) == 500 - FRAME_SIZE + + def test_accumulated_calls_eventually_flush(self, processor): + # Two calls of 160 bytes each should flush one frame total + processor.process(make_audio(160)) + assert processor._mock_apm.Process10ms.call_count == 0 + processor.process(make_audio(160)) + assert processor._mock_apm.Process10ms.call_count == 1 + + def test_buffer_drains_in_place(self, processor): + """After processing, only the remainder stays in the buffer.""" + processor.process(make_audio(500)) + assert len(processor._buffer) == 180 + + def test_empty_input_returns_empty(self, processor): + result = processor.process(b"") + assert result == b"" + processor._mock_apm.Process10ms.assert_not_called() + + def test_output_is_concatenation_of_processed_chunks(self, processor): + """Each frame gets processed independently and results are joined.""" + apm = processor._mock_apm + apm.Process10ms.side_effect = [ + MagicMock(audio=b"\x01" * FRAME_SIZE), + MagicMock(audio=b"\x02" * FRAME_SIZE), + ] + result = processor.process(make_audio(FRAME_SIZE * 2)) + assert result == b"\x01" * FRAME_SIZE + b"\x02" * FRAME_SIZE + + def test_multiple_process_calls_accumulate_buffer(self, processor): + """Remainder from first call is used in second call.""" + processor.process(make_audio(200)) # 200 buffered, no flush + processor.process(make_audio(200)) # 400 total, 1 flush, 80 remain + assert processor._mock_apm.Process10ms.call_count == 1 + assert len(processor._buffer) == 80 + + +# --------------------------------------------------------------------------- +# update_settings() +# --------------------------------------------------------------------------- + + +class TestUpdateSettings: + def test_reinitializes_when_agc_changes(self, processor): + new_apm = make_mock_apm() + with patch(PATCH_TARGET, return_value=new_apm): + processor.update_settings(agc_level=10, ns_level=2) + assert processor.apm is new_apm + assert processor.agc_level == 10 + + def test_reinitializes_when_ns_changes(self, processor): + new_apm = make_mock_apm() + with patch(PATCH_TARGET, return_value=new_apm): + processor.update_settings(agc_level=3, ns_level=4) + assert processor.apm is new_apm + assert processor.ns_level == 4 + + def test_no_reinitialize_when_settings_unchanged(self, processor): + original_apm = processor.apm + with patch(PATCH_TARGET) as mock_cls: + processor.update_settings(agc_level=3, ns_level=2) + mock_cls.assert_not_called() + assert processor.apm is original_apm + + def test_stores_new_agc_level(self, processor): + with patch(PATCH_TARGET, return_value=make_mock_apm()): + processor.update_settings(agc_level=15, ns_level=2) + assert processor.agc_level == 15 + + def test_stores_new_ns_level(self, processor): + with patch(PATCH_TARGET, return_value=make_mock_apm()): + processor.update_settings(agc_level=3, ns_level=3) + assert processor.ns_level == 3 + + def test_new_apm_called_with_updated_levels(self, processor): + with patch(PATCH_TARGET, return_value=make_mock_apm()) as mock_cls: + processor.update_settings(agc_level=7, ns_level=1) + mock_cls.assert_called_once_with(7, 1) + + +# --------------------------------------------------------------------------- +# process() after update_settings() +# --------------------------------------------------------------------------- + + +class TestProcessAfterUpdate: + def test_buffer_preserved_across_settings_update(self, processor): + """Buffered bytes should survive a settings change.""" + processor.process(make_audio(160)) # half frame, stays buffered + assert len(processor._buffer) == 160 + + with patch(PATCH_TARGET, return_value=make_mock_apm()): + processor.update_settings(agc_level=10, ns_level=2) + + assert len(processor._buffer) == 160 + + def test_process_uses_new_apm_after_update(self, processor): + new_apm = make_mock_apm(output_fill=0xFF) + with patch(PATCH_TARGET, return_value=new_apm): + processor.update_settings(agc_level=10, ns_level=2) + + processor.process(make_audio(FRAME_SIZE)) + new_apm.Process10ms.assert_called_once() \ No newline at end of file diff --git a/tests/unit/test_zeroconf.py b/tests/unit/test_zeroconf.py new file mode 100644 index 00000000..73de4552 --- /dev/null +++ b/tests/unit/test_zeroconf.py @@ -0,0 +1,174 @@ +"""Unit tests for HomeAssistantZeroconf.""" + +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_zeroconf(**kwargs): + defaults = dict( + port=6053, + mac_address="aa:bb:cc:dd:ee:ff", + host_ip_address="192.168.1.100", + name="lva-test", + ) + defaults.update(kwargs) + + with patch("linux_voice_assistant.zeroconf.AsyncZeroconf") as mock_zc_cls: + mock_zc = MagicMock() + mock_zc_cls.return_value = mock_zc + from linux_voice_assistant.zeroconf import HomeAssistantZeroconf + instance = HomeAssistantZeroconf(**defaults) + instance._mock_zc = mock_zc + instance._mock_zc_cls = mock_zc_cls + return instance + + +# --------------------------------------------------------------------------- +# __init__ +# --------------------------------------------------------------------------- + + +class TestInit: + def test_name_stored(self): + zc = make_zeroconf(name="my-lva") + assert zc.name == "my-lva" + + def test_port_stored(self): + zc = make_zeroconf(port=1234) + assert zc.port == 1234 + + def test_mac_address_stored(self): + zc = make_zeroconf(mac_address="11:22:33:44:55:66") + assert zc.mac_address == "11:22:33:44:55:66" + + def test_host_ip_stored(self): + zc = make_zeroconf(host_ip_address="10.0.0.1") + assert zc.host_ip_address == "10.0.0.1" + + def test_name_defaults_to_mac_when_not_provided(self): + with patch("linux_voice_assistant.zeroconf.AsyncZeroconf"): + from linux_voice_assistant.zeroconf import HomeAssistantZeroconf + zc = HomeAssistantZeroconf( + port=6053, + mac_address="aa:bb:cc:dd:ee:ff", + host_ip_address="192.168.1.1", + name=None, + ) + assert zc.name == "aa:bb:cc:dd:ee:ff" + + def test_async_zeroconf_instantiated(self): + with patch("linux_voice_assistant.zeroconf.AsyncZeroconf") as mock_cls: + mock_cls.return_value = MagicMock() + from linux_voice_assistant.zeroconf import HomeAssistantZeroconf + HomeAssistantZeroconf( + port=6053, + mac_address="aa:bb:cc:dd:ee:ff", + host_ip_address="192.168.1.1", + ) + mock_cls.assert_called_once() + + +# --------------------------------------------------------------------------- +# register_server() +# --------------------------------------------------------------------------- + + +class TestRegisterServer: + @pytest.mark.asyncio + async def test_register_service_called(self): + zc = make_zeroconf() + zc._mock_zc.async_register_service = AsyncMock() + + with patch("linux_voice_assistant.zeroconf.AsyncServiceInfo"): + await zc.register_server() + + zc._mock_zc.async_register_service.assert_called_once() + + @pytest.mark.asyncio + async def test_service_name_contains_device_name(self): + zc = make_zeroconf(name="lva-aabbccddee") + zc._mock_zc.async_register_service = AsyncMock() + + captured = {} + + with patch("linux_voice_assistant.zeroconf.AsyncServiceInfo") as mock_info_cls: + mock_info_cls.side_effect = lambda *a, **kw: captured.update({"args": a, "kwargs": kw}) or MagicMock() + await zc.register_server() + + # First positional arg is service type, second is full service name + assert "lva-aabbccddee" in captured["args"][1] + + @pytest.mark.asyncio + async def test_service_type_is_esphomelib(self): + zc = make_zeroconf() + zc._mock_zc.async_register_service = AsyncMock() + + captured = {} + + with patch("linux_voice_assistant.zeroconf.AsyncServiceInfo") as mock_info_cls: + mock_info_cls.side_effect = lambda *a, **kw: captured.update({"args": a, "kwargs": kw}) or MagicMock() + await zc.register_server() + + assert captured["args"][0] == "_esphomelib._tcp.local." + + @pytest.mark.asyncio + async def test_service_properties_contain_mac(self): + zc = make_zeroconf(mac_address="aa:bb:cc:dd:ee:ff") + zc._mock_zc.async_register_service = AsyncMock() + + captured = {} + + with patch("linux_voice_assistant.zeroconf.AsyncServiceInfo") as mock_info_cls: + mock_info_cls.side_effect = lambda *a, **kw: captured.update({"args": a, "kwargs": kw}) or MagicMock() + await zc.register_server() + + props = captured["kwargs"].get("properties", {}) + assert "mac" in props + assert props["mac"] == "aa:bb:cc:dd:ee:ff" + + @pytest.mark.asyncio + async def test_service_properties_contain_version(self): + zc = make_zeroconf() + zc._mock_zc.async_register_service = AsyncMock() + + captured = {} + + with patch("linux_voice_assistant.zeroconf.AsyncServiceInfo") as mock_info_cls: + mock_info_cls.side_effect = lambda *a, **kw: captured.update({"args": a, "kwargs": kw}) or MagicMock() + await zc.register_server() + + props = captured["kwargs"].get("properties", {}) + assert "version" in props + + @pytest.mark.asyncio + async def test_service_port_matches(self): + zc = make_zeroconf(port=9999) + zc._mock_zc.async_register_service = AsyncMock() + + captured = {} + + with patch("linux_voice_assistant.zeroconf.AsyncServiceInfo") as mock_info_cls: + mock_info_cls.side_effect = lambda *a, **kw: captured.update({"args": a, "kwargs": kw}) or MagicMock() + await zc.register_server() + + assert captured["kwargs"].get("port") == 9999 + + @pytest.mark.asyncio + async def test_service_address_matches_host_ip(self): + import socket + zc = make_zeroconf(host_ip_address="10.0.0.5") + zc._mock_zc.async_register_service = AsyncMock() + + captured = {} + + with patch("linux_voice_assistant.zeroconf.AsyncServiceInfo") as mock_info_cls: + mock_info_cls.side_effect = lambda *a, **kw: captured.update({"args": a, "kwargs": kw}) or MagicMock() + await zc.register_server() + + addresses = captured["kwargs"].get("addresses", []) + assert socket.inet_aton("10.0.0.5") in addresses \ No newline at end of file From f1f3d942a79eb280a84f55f1a6bcfadb463892d2 Mon Sep 17 00:00:00 2001 From: aryanhasgithub Date: Sat, 25 Apr 2026 13:33:43 +0530 Subject: [PATCH 02/10] Change Python version from 3.13 to 3.12 --- .github/workflows/tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 00bdfd0b..8d06c779 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,10 +18,10 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - - name: Set up Python 3.13 + - name: Set up Python 3.12 uses: actions/setup-python@v5 with: - python-version: "3.13" + python-version: "3.12" - name: Install dependencies run: | @@ -29,4 +29,4 @@ jobs: script/setup --dev - name: Run unit tests - run: pytest tests/unit \ No newline at end of file + run: pytest tests/unit From 7dfadbf4ec3fa5aa150b13a5dab9c36bc5fbd4c4 Mon Sep 17 00:00:00 2001 From: aryanhasgithub Date: Sat, 25 Apr 2026 13:37:36 +0530 Subject: [PATCH 03/10] Update pytest command to use python -m --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8d06c779..90fea267 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,4 +29,4 @@ jobs: script/setup --dev - name: Run unit tests - run: pytest tests/unit + run: python -m pytest tests/unit From bce2ee429444c1a071ed08c2320d82f9c2216aad Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 25 Apr 2026 08:12:19 +0000 Subject: [PATCH 04/10] Use tests Script --- .github/workflows/tests.yml | 2 +- script/{test => tests} | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) rename script/{test => tests} (70%) mode change 100755 => 100644 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 90fea267..1fefda49 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,4 +29,4 @@ jobs: script/setup --dev - name: Run unit tests - run: python -m pytest tests/unit + run: ./script/tests diff --git a/script/test b/script/tests old mode 100755 new mode 100644 similarity index 70% rename from script/test rename to script/tests index 98a3068e..e292a22b --- a/script/test +++ b/script/tests @@ -1,13 +1,15 @@ #!/usr/bin/env python3 import subprocess -import sys import venv from pathlib import Path +import sys _DIR = Path(__file__).parent _PROGRAM_DIR = _DIR.parent _VENV_DIR = _PROGRAM_DIR / ".venv" -_TEST_DIR = _PROGRAM_DIR / "tests" +_TESTS_DIR = _PROGRAM_DIR / "tests" / "unit" + +_TEST_DIRS = [_TESTS_DIR] if _VENV_DIR.exists(): context = venv.EnvBuilder().ensure_directories(_VENV_DIR) @@ -15,4 +17,6 @@ if _VENV_DIR.exists(): else: python_exe = "python3" -subprocess.check_call([python_exe, "-m", "pytest", _TEST_DIR] + sys.argv[1:]) + +subprocess.check_call([python_exe, "-m", "pytest"] + _TEST_DIRS) + From a5d909696bbcaf97296f1077939be30c9ec24717 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 25 Apr 2026 08:13:42 +0000 Subject: [PATCH 05/10] Make script/tests executable --- script/tests | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 script/tests diff --git a/script/tests b/script/tests old mode 100644 new mode 100755 From 70c5e944ca14cbedea5bdb1b9e99e8f0e196e356 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 25 Apr 2026 08:14:13 +0000 Subject: [PATCH 06/10] Make tests executable --- script/tests | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 script/tests diff --git a/script/tests b/script/tests old mode 100755 new mode 100644 From 98fe4d7b3a28a90eed8c36847fe68529fa164902 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 25 Apr 2026 08:16:38 +0000 Subject: [PATCH 07/10] Make tests executable --- script/tests | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 script/tests diff --git a/script/tests b/script/tests old mode 100644 new mode 100755 From 7a66b48b19c8e61253fc30837ba4fbb24ebab5a8 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 25 Apr 2026 08:23:59 +0000 Subject: [PATCH 08/10] Update tests CI to install libmpv --- .github/workflows/tests.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1fefda49..415f8449 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,10 +23,15 @@ jobs: with: python-version: "3.12" + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libmpv-dev mpv + - name: Install dependencies run: | python -m pip install --upgrade pip - script/setup --dev + ./script/setup --dev - name: Run unit tests - run: ./script/tests + run: ./script/tests \ No newline at end of file From 38833c76b2d00674764051a3a3b0bcabb3eb8c9b Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 26 Apr 2026 04:27:05 +0000 Subject: [PATCH 09/10] Add Unit Tests for wake word, satellite and main --- tests/unit/test_main.py | 273 ++++++++++++++++++++++ tests/unit/test_satellite.py | 425 +++++++++++++++++++++++++++++++++++ tests/unit/test_wake_word.py | 271 ++++++++++++++++++++++ 3 files changed, 969 insertions(+) create mode 100644 tests/unit/test_main.py create mode 100644 tests/unit/test_satellite.py create mode 100644 tests/unit/test_wake_word.py diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py new file mode 100644 index 00000000..d490b82b --- /dev/null +++ b/tests/unit/test_main.py @@ -0,0 +1,273 @@ +"""Unit tests for __main__.py — process_audio logic and argument parsing helpers.""" + +import pytest +import numpy as np +from unittest.mock import MagicMock, patch, call +from queue import Queue +from pathlib import Path + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_state(tmp_path): + from linux_voice_assistant.models import Preferences, ServerState + + stop_word = MagicMock() + stop_word.id = "stop" + stop_word.is_active = False + stop_word.process_streaming.return_value = False + + prefs = Preferences() + + state = ServerState( + name="lva-test", + friendly_name="LVA Test", + mac_address="aa:bb:cc:dd:ee:ff", + ip_address="192.168.1.1", + network_interface="eth0", + version="1.0.0", + esphome_version="42.0.0", + audio_queue=Queue(), + entities=[], + available_wake_words={}, + wake_words={}, + active_wake_words=set(), + stop_word=stop_word, + music_player=MagicMock(), + tts_player=MagicMock(), + wakeup_sound="/sounds/wake.flac", + processing_sound="/sounds/processing.wav", + timer_finished_sound="/sounds/timer.flac", + mute_sound="/sounds/mute.flac", + unmute_sound="/sounds/unmute.flac", + preferences=prefs, + preferences_path=tmp_path / "preferences.json", + download_dir=tmp_path / "downloads", + volume=1.0, + mic_volume=100, + mic_auto_gain=0, + mic_noise_suppression=0, + ) + return state + + +def make_audio_chunk(samples: int = 1024, value: float = 0.0) -> np.ndarray: + return np.full(samples, value, dtype=np.float32) + + +# --------------------------------------------------------------------------- +# process_audio — mic volume scaling +# --------------------------------------------------------------------------- + + +class TestProcessAudioMicVolume: + def test_mic_volume_scalar_at_100(self, tmp_path): + """mic_volume=100 → scalar=1.0 → no attenuation.""" + state = make_state(tmp_path) + state.mic_volume = 100 + state.satellite = None # no satellite → loop exits immediately + + # The scalar clamp: max(0.1, min(1.0, 100/100)) = 1.0 + scalar = max(0.1, min(1.0, state.mic_volume / 100.0)) + assert scalar == pytest.approx(1.0) + + def test_mic_volume_scalar_at_50(self, tmp_path): + state = make_state(tmp_path) + state.mic_volume = 50 + scalar = max(0.1, min(1.0, state.mic_volume / 100.0)) + assert scalar == pytest.approx(0.5) + + def test_mic_volume_scalar_clamped_above_100(self, tmp_path): + state = make_state(tmp_path) + state.mic_volume = 200 + scalar = max(0.1, min(1.0, state.mic_volume / 100.0)) + assert scalar == pytest.approx(1.0) + + def test_mic_volume_scalar_minimum_is_0_1(self, tmp_path): + state = make_state(tmp_path) + state.mic_volume = 0 + scalar = max(0.1, min(1.0, state.mic_volume / 100.0)) + assert scalar == pytest.approx(0.1) + + def test_audio_chunk_scaled_by_mic_volume(self, tmp_path): + """Verify the numpy scaling produces correct output.""" + audio = make_audio_chunk(value=0.5) + mic_vol_scalar = 0.5 + result = np.clip(audio * mic_vol_scalar, -1.0, 1.0) + assert np.allclose(result, 0.25) + + def test_audio_chunk_clipped_to_bounds(self, tmp_path): + """Values that exceed [-1, 1] after scaling should be clipped.""" + audio = make_audio_chunk(value=1.0) + mic_vol_scalar = 2.0 # would push to 2.0 without clip + result = np.clip(audio * mic_vol_scalar, -1.0, 1.0) + assert np.all(result <= 1.0) + + +# --------------------------------------------------------------------------- +# process_audio — WebRTC integration +# --------------------------------------------------------------------------- + + +class TestProcessAudioWebRTC: + def test_webrtc_not_created_when_agc_and_ns_zero(self, tmp_path): + """If both AGC and NS are 0, WebRTCProcessor should never be instantiated.""" + state = make_state(tmp_path) + state.preferences.mic_auto_gain = 0 + state.preferences.mic_noise_suppression = 0 + + agc = state.preferences.mic_auto_gain or 0 + ns = state.preferences.mic_noise_suppression or 0 + should_create = agc > 0 or ns > 0 + assert should_create is False + + def test_webrtc_created_when_agc_nonzero(self, tmp_path): + state = make_state(tmp_path) + state.preferences.mic_auto_gain = 5 + state.preferences.mic_noise_suppression = 0 + + agc = state.preferences.mic_auto_gain or 0 + ns = state.preferences.mic_noise_suppression or 0 + should_create = agc > 0 or ns > 0 + assert should_create is True + + def test_webrtc_created_when_ns_nonzero(self, tmp_path): + state = make_state(tmp_path) + state.preferences.mic_auto_gain = 0 + state.preferences.mic_noise_suppression = 2 + + agc = state.preferences.mic_auto_gain or 0 + ns = state.preferences.mic_noise_suppression or 0 + should_create = agc > 0 or ns > 0 + assert should_create is True + + +# --------------------------------------------------------------------------- +# process_audio — wake word refractory +# --------------------------------------------------------------------------- + + +class TestWakeWordRefractory: + def test_refractory_allows_activation_when_none(self): + """First activation (last_active=None) should always be allowed.""" + import time + last_active = None + refractory_seconds = 2.0 + now = time.monotonic() + allowed = (last_active is None) or ((now - last_active) > refractory_seconds) + assert allowed is True + + def test_refractory_blocks_activation_too_soon(self): + """Activation within refractory period should be blocked.""" + import time + last_active = time.monotonic() # just activated + refractory_seconds = 2.0 + now = time.monotonic() + allowed = (last_active is None) or ((now - last_active) > refractory_seconds) + assert allowed is False + + def test_refractory_allows_activation_after_period(self): + """Activation after refractory period should be allowed.""" + import time + last_active = time.monotonic() - 3.0 # 3 seconds ago + refractory_seconds = 2.0 + now = time.monotonic() + allowed = (last_active is None) or ((now - last_active) > refractory_seconds) + assert allowed is True + + +# --------------------------------------------------------------------------- +# process_audio — stop word detection logic +# --------------------------------------------------------------------------- + + +class TestStopWordLogic: + def test_stop_word_only_triggers_when_in_active_set(self, tmp_path): + state = make_state(tmp_path) + # Stop word not in active set → should not trigger stop + state.active_wake_words = set() + stopped = True # simulating detection + + should_stop = stopped and (state.stop_word.id in state.active_wake_words) and not state.muted + assert should_stop is False + + def test_stop_word_triggers_when_in_active_set_and_not_muted(self, tmp_path): + state = make_state(tmp_path) + state.active_wake_words = {state.stop_word.id} + state.muted = False + stopped = True + + should_stop = stopped and (state.stop_word.id in state.active_wake_words) and not state.muted + assert should_stop is True + + def test_stop_word_does_not_trigger_when_muted(self, tmp_path): + state = make_state(tmp_path) + state.active_wake_words = {state.stop_word.id} + state.muted = True + stopped = True + + should_stop = stopped and (state.stop_word.id in state.active_wake_words) and not state.muted + assert should_stop is False + + +# --------------------------------------------------------------------------- +# Preferences loading from args +# --------------------------------------------------------------------------- + + +class TestPreferencesFromArgs: + def test_agc_stored_in_preferences_when_nonzero(self, tmp_path): + from linux_voice_assistant.models import Preferences + prefs = Preferences() + mic_auto_gain = 5 + if mic_auto_gain > 0: + prefs.mic_auto_gain = mic_auto_gain + assert prefs.mic_auto_gain == 5 + + def test_agc_not_stored_when_zero(self, tmp_path): + from linux_voice_assistant.models import Preferences + prefs = Preferences() + mic_auto_gain = 0 + if mic_auto_gain > 0: + prefs.mic_auto_gain = mic_auto_gain + assert prefs.mic_auto_gain == 0 + + def test_ns_stored_in_preferences_when_nonzero(self, tmp_path): + from linux_voice_assistant.models import Preferences + prefs = Preferences() + mic_noise_suppression = 2 + if mic_noise_suppression > 0: + prefs.mic_noise_suppression = mic_noise_suppression + assert prefs.mic_noise_suppression == 2 + + def test_volume_clamped_to_one_when_above(self, tmp_path): + from linux_voice_assistant.models import Preferences + prefs = Preferences(volume=1.5) + initial_volume = prefs.volume if prefs.volume is not None else 1.0 + initial_volume = max(0.0, min(1.0, float(initial_volume))) + assert initial_volume == pytest.approx(1.0) + + def test_volume_clamped_to_zero_when_below(self, tmp_path): + from linux_voice_assistant.models import Preferences + prefs = Preferences(volume=-0.5) + initial_volume = prefs.volume if prefs.volume is not None else 1.0 + initial_volume = max(0.0, min(1.0, float(initial_volume))) + assert initial_volume == pytest.approx(0.0) + + def test_volume_defaults_to_one_when_none(self, tmp_path): + from linux_voice_assistant.models import Preferences + prefs = Preferences(volume=None) + initial_volume = prefs.volume if prefs.volume is not None else 1.0 + initial_volume = max(0.0, min(1.0, float(initial_volume))) + assert initial_volume == pytest.approx(1.0) + + def test_thinking_sound_enabled_by_flag(self, tmp_path): + from linux_voice_assistant.models import Preferences + prefs = Preferences() + enable_thinking_sound = True + if enable_thinking_sound: + prefs.thinking_sound = 1 + assert prefs.thinking_sound == 1 \ No newline at end of file diff --git a/tests/unit/test_satellite.py b/tests/unit/test_satellite.py new file mode 100644 index 00000000..794e136c --- /dev/null +++ b/tests/unit/test_satellite.py @@ -0,0 +1,425 @@ +"""Unit tests for VoiceSatelliteProtocol logic.""" + +import pytest +from unittest.mock import MagicMock, patch, call +from queue import Queue +from pathlib import Path + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_state(tmp_path=None): + """Build a minimal ServerState with all fields mocked.""" + from linux_voice_assistant.models import Preferences, ServerState, WakeWordType, AvailableWakeWord + + if tmp_path is None: + import tempfile + tmp_path = Path(tempfile.mkdtemp()) + + stop_word = MagicMock() + stop_word.id = "stop" + stop_word.is_active = False + + prefs = Preferences() + + state = ServerState( + name="lva-test", + friendly_name="LVA Test", + mac_address="aa:bb:cc:dd:ee:ff", + ip_address="192.168.1.1", + network_interface="eth0", + version="1.0.0", + esphome_version="42.0.0", + audio_queue=Queue(), + entities=[], + available_wake_words={}, + wake_words={}, + active_wake_words=set(), + stop_word=stop_word, + music_player=MagicMock(), + tts_player=MagicMock(), + wakeup_sound="/sounds/wake.flac", + processing_sound="/sounds/processing.wav", + timer_finished_sound="/sounds/timer.flac", + mute_sound="/sounds/mute.flac", + unmute_sound="/sounds/unmute.flac", + preferences=prefs, + preferences_path=tmp_path / "preferences.json", + download_dir=tmp_path / "downloads", + volume=1.0, + mic_volume=100, + mic_auto_gain=0, + mic_noise_suppression=0, + ) + return state + + +def make_satellite(tmp_path=None): + """Build a VoiceSatelliteProtocol with all heavy dependencies mocked.""" + state = make_state(tmp_path) + + # Mock the sensitivity entity classes so __init__ doesn't blow up + with patch("linux_voice_assistant.satellite.WakeWord1SensitivityNumberEntity", MagicMock()), \ + patch("linux_voice_assistant.satellite.WakeWord2SensitivityNumberEntity", MagicMock()), \ + patch("linux_voice_assistant.satellite.StopWordSensitivityNumberEntity", MagicMock()): + from linux_voice_assistant.satellite import VoiceSatelliteProtocol + satellite = VoiceSatelliteProtocol(state) + + # Attach mock transport so send_messages works + satellite._writelines = MagicMock() + satellite._loop = None + return satellite + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + + +class TestInit: + def test_satellite_stored_on_state(self, tmp_path): + sat = make_satellite(tmp_path) + assert sat.state.satellite is sat + + def test_connected_starts_false(self, tmp_path): + sat = make_satellite(tmp_path) + assert sat.state.connected is False + + def test_media_player_entity_created(self, tmp_path): + from linux_voice_assistant.entity import MediaPlayerEntity + sat = make_satellite(tmp_path) + assert sat.state.media_player_entity is not None + assert isinstance(sat.state.media_player_entity, MediaPlayerEntity) + + def test_mute_switch_entity_created(self, tmp_path): + from linux_voice_assistant.entity import MuteSwitchEntity + sat = make_satellite(tmp_path) + assert sat.state.mute_switch_entity is not None + assert isinstance(sat.state.mute_switch_entity, MuteSwitchEntity) + + def test_thinking_sound_entity_created(self, tmp_path): + from linux_voice_assistant.entity import ThinkingSoundEntity + sat = make_satellite(tmp_path) + assert sat.state.thinking_sound_entity is not None + assert isinstance(sat.state.thinking_sound_entity, ThinkingSoundEntity) + + def test_mic_gain_entity_created(self, tmp_path): + from linux_voice_assistant.entity import MicSettingEntity + sat = make_satellite(tmp_path) + assert sat.state.mic_gain_entity is not None + assert isinstance(sat.state.mic_gain_entity, MicSettingEntity) + + def test_mic_noise_entity_created(self, tmp_path): + from linux_voice_assistant.entity import MicSettingEntity + sat = make_satellite(tmp_path) + assert sat.state.mic_noise_suppression_entity is not None + assert isinstance(sat.state.mic_noise_suppression_entity, MicSettingEntity) + + def test_mic_volume_entity_created(self, tmp_path): + from linux_voice_assistant.entity import MicSettingEntity + sat = make_satellite(tmp_path) + assert sat.state.mic_volume_entity is not None + assert isinstance(sat.state.mic_volume_entity, MicSettingEntity) + + def test_pipeline_not_active_on_start(self, tmp_path): + sat = make_satellite(tmp_path) + assert sat._pipeline_active is False + + def test_not_streaming_audio_on_start(self, tmp_path): + sat = make_satellite(tmp_path) + assert sat._is_streaming_audio is False + + def test_not_muted_on_start(self, tmp_path): + sat = make_satellite(tmp_path) + assert sat.state.muted is False + + def test_thinking_sound_loaded_from_preferences(self, tmp_path): + from linux_voice_assistant.models import Preferences + state = make_state(tmp_path) + state.preferences.thinking_sound = 1 + + with patch("linux_voice_assistant.satellite.WakeWord1SensitivityNumberEntity", MagicMock()), \ + patch("linux_voice_assistant.satellite.WakeWord2SensitivityNumberEntity", MagicMock()), \ + patch("linux_voice_assistant.satellite.StopWordSensitivityNumberEntity", MagicMock()): + from linux_voice_assistant.satellite import VoiceSatelliteProtocol + sat = VoiceSatelliteProtocol(state) + + assert sat.state.thinking_sound_enabled is True + + def test_output_only_sets_limited_features(self, tmp_path): + from aioesphomeapi.model import VoiceAssistantFeature + state = make_state(tmp_path) + state.output_only = True + + with patch("linux_voice_assistant.satellite.WakeWord1SensitivityNumberEntity", MagicMock()), \ + patch("linux_voice_assistant.satellite.WakeWord2SensitivityNumberEntity", MagicMock()), \ + patch("linux_voice_assistant.satellite.StopWordSensitivityNumberEntity", MagicMock()): + from linux_voice_assistant.satellite import VoiceSatelliteProtocol + sat = VoiceSatelliteProtocol(state) + + assert sat.supported_features & VoiceAssistantFeature.VOICE_ASSISTANT == 0 + + +# --------------------------------------------------------------------------- +# _set_muted() +# --------------------------------------------------------------------------- + + +class TestSetMuted: + def test_muting_sets_muted_flag(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_muted(True) + assert sat.state.muted is True + + def test_unmuting_clears_muted_flag(self, tmp_path): + sat = make_satellite(tmp_path) + sat.state.muted = True + sat._set_muted(False) + assert sat.state.muted is False + + def test_muting_stops_tts_player(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_muted(True) + sat.state.tts_player.stop.assert_called() + + def test_muting_plays_mute_sound(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_muted(True) + sat.state.tts_player.play.assert_called_with(sat.state.mute_sound) + + def test_unmuting_plays_unmute_sound(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_muted(False) + sat.state.tts_player.play.assert_called_with(sat.state.unmute_sound) + + def test_muting_stops_audio_streaming(self, tmp_path): + sat = make_satellite(tmp_path) + sat._is_streaming_audio = True + sat._set_muted(True) + assert sat._is_streaming_audio is False + + +# --------------------------------------------------------------------------- +# _set_thinking_sound_enabled() +# --------------------------------------------------------------------------- + + +class TestSetThinkingSoundEnabled: + def test_enables_thinking_sound(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_thinking_sound_enabled(True) + assert sat.state.thinking_sound_enabled is True + + def test_disables_thinking_sound(self, tmp_path): + sat = make_satellite(tmp_path) + sat.state.thinking_sound_enabled = True + sat._set_thinking_sound_enabled(False) + assert sat.state.thinking_sound_enabled is False + + def test_updates_preferences(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_thinking_sound_enabled(True) + assert sat.state.preferences.thinking_sound == 1 + + def test_saves_preferences_to_file(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_thinking_sound_enabled(True) + assert sat.state.preferences_path.exists() + + +# --------------------------------------------------------------------------- +# _set_sensitivity_1/2 and _set_stop_sensitivity +# --------------------------------------------------------------------------- + + +class TestSensitivitySetters: + def test_set_sensitivity_1_updates_threshold(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_sensitivity_1(0.85) + assert sat.state.wake_word_1_threshold == pytest.approx(0.85) + + def test_set_sensitivity_1_updates_preferences(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_sensitivity_1(0.85) + assert sat.state.preferences.wake_word_1_sensitivity == pytest.approx(0.85) + + def test_set_sensitivity_2_updates_threshold(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_sensitivity_2(0.6) + assert sat.state.wake_word_2_threshold == pytest.approx(0.6) + + def test_set_sensitivity_2_updates_preferences(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_sensitivity_2(0.6) + assert sat.state.preferences.wake_word_2_sensitivity == pytest.approx(0.6) + + def test_set_stop_sensitivity_updates_threshold(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_stop_sensitivity(0.5) + assert sat.state.stop_word_threshold == pytest.approx(0.5) + + def test_set_stop_sensitivity_updates_preferences(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_stop_sensitivity(0.5) + assert sat.state.preferences.stop_word_sensitivity == pytest.approx(0.5) + + def test_sensitivity_saves_preferences(self, tmp_path): + sat = make_satellite(tmp_path) + sat._set_sensitivity_1(0.9) + assert sat.state.preferences_path.exists() + + +# --------------------------------------------------------------------------- +# handle_audio() +# --------------------------------------------------------------------------- + + +class TestHandleAudio: + def test_does_not_send_when_not_streaming(self, tmp_path): + sat = make_satellite(tmp_path) + sat._is_streaming_audio = False + sat.handle_audio(b"\x00" * 320) + sat._writelines.assert_not_called() + + def test_does_not_send_when_muted(self, tmp_path): + sat = make_satellite(tmp_path) + sat._is_streaming_audio = True + sat.state.muted = True + sat.handle_audio(b"\x00" * 320) + sat._writelines.assert_not_called() + + def test_sends_when_streaming_and_not_muted(self, tmp_path): + sat = make_satellite(tmp_path) + sat._is_streaming_audio = True + sat.state.muted = False + sat._loop = None + sat.handle_audio(b"\x00" * 320) + sat._writelines.assert_called() + + +# --------------------------------------------------------------------------- +# play_tts() +# --------------------------------------------------------------------------- + + +class TestPlayTts: + def test_does_not_play_when_no_url(self, tmp_path): + sat = make_satellite(tmp_path) + sat._tts_url = None + sat.play_tts() + sat.state.tts_player.play.assert_not_called() + + def test_does_not_play_when_already_played(self, tmp_path): + sat = make_satellite(tmp_path) + sat._tts_url = "http://example.com/tts.mp3" + sat._tts_played = True + sat.play_tts() + sat.state.tts_player.play.assert_not_called() + + def test_plays_tts_url(self, tmp_path): + sat = make_satellite(tmp_path) + sat._tts_url = "http://example.com/tts.mp3" + sat._tts_played = False + sat.play_tts() + sat.state.tts_player.play.assert_called_once() + args, _ = sat.state.tts_player.play.call_args + assert args[0] == "http://example.com/tts.mp3" + + def test_sets_tts_played_flag(self, tmp_path): + sat = make_satellite(tmp_path) + sat._tts_url = "http://example.com/tts.mp3" + sat._tts_played = False + sat.play_tts() + assert sat._tts_played is True + + def test_adds_stop_word_to_active_wake_words(self, tmp_path): + sat = make_satellite(tmp_path) + sat._tts_url = "http://example.com/tts.mp3" + sat._tts_played = False + sat.play_tts() + assert sat.state.stop_word.id in sat.state.active_wake_words + + +# --------------------------------------------------------------------------- +# stop() +# --------------------------------------------------------------------------- + + +class TestStop: + def test_stop_clears_pipeline_active(self, tmp_path): + sat = make_satellite(tmp_path) + sat._pipeline_active = True + sat.stop() + assert sat._pipeline_active is False + + def test_stop_discards_stop_word_from_active(self, tmp_path): + sat = make_satellite(tmp_path) + sat.state.active_wake_words.add(sat.state.stop_word.id) + sat.stop() + assert sat.state.stop_word.id not in sat.state.active_wake_words + + def test_stop_calls_tts_player_stop(self, tmp_path): + sat = make_satellite(tmp_path) + sat._timer_finished = False + sat.stop() + sat.state.tts_player.stop.assert_called() + + +# --------------------------------------------------------------------------- +# duck() / unduck() +# --------------------------------------------------------------------------- + + +class TestDuckUnduck: + def test_duck_calls_music_player_duck(self, tmp_path): + sat = make_satellite(tmp_path) + sat.duck() + sat.state.music_player.duck.assert_called_once() + + def test_unduck_calls_music_player_unduck(self, tmp_path): + sat = make_satellite(tmp_path) + sat.unduck() + sat.state.music_player.unduck.assert_called_once() + + +# --------------------------------------------------------------------------- +# connection_lost() +# --------------------------------------------------------------------------- + + +class TestConnectionLost: + def test_connection_lost_clears_connected_flag(self, tmp_path): + sat = make_satellite(tmp_path) + sat.state.connected = True + sat.connection_lost(None) + assert sat.state.connected is False + + def test_connection_lost_clears_satellite_reference(self, tmp_path): + sat = make_satellite(tmp_path) + sat.connection_lost(None) + assert sat.state.satellite is None + + def test_connection_lost_stops_streaming(self, tmp_path): + sat = make_satellite(tmp_path) + sat._is_streaming_audio = True + sat.connection_lost(None) + assert sat._is_streaming_audio is False + + def test_connection_lost_clears_pipeline_active(self, tmp_path): + sat = make_satellite(tmp_path) + sat._pipeline_active = True + sat.connection_lost(None) + assert sat._pipeline_active is False + + def test_connection_lost_stops_music_player(self, tmp_path): + sat = make_satellite(tmp_path) + sat.connection_lost(None) + sat.state.music_player.stop.assert_called() + + def test_connection_lost_stops_tts_player(self, tmp_path): + sat = make_satellite(tmp_path) + sat.connection_lost(None) + sat.state.tts_player.stop.assert_called() \ No newline at end of file diff --git a/tests/unit/test_wake_word.py b/tests/unit/test_wake_word.py new file mode 100644 index 00000000..94be3ce2 --- /dev/null +++ b/tests/unit/test_wake_word.py @@ -0,0 +1,271 @@ +"""Unit tests for wake_word.py — find_available_wake_words, load_wake_models, load_stop_model.""" + +import json +import pytest +from pathlib import Path +from unittest.mock import MagicMock, patch + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_micro_json(wake_word: str = "okay nabu", probability_cutoff: float = 0.7) -> dict: + return { + "type": "micro", + "wake_word": wake_word, + "trained_languages": ["en"], + "micro": {"probability_cutoff": probability_cutoff}, + } + + +def make_oww_json(wake_word: str = "hey jarvis", model_file: str = "hey_jarvis.tflite") -> dict: + return { + "type": "openWakeWord", + "wake_word": wake_word, + "trained_languages": ["en"], + "model": model_file, + "openWakeWord": {"probability_cutoff": 0.5}, + } + + +def write_json(path: Path, data: dict) -> None: + path.write_text(json.dumps(data), encoding="utf-8") + + +# --------------------------------------------------------------------------- +# find_available_wake_words() +# --------------------------------------------------------------------------- + + +class TestFindAvailableWakeWords: + def test_finds_micro_wake_word(self, tmp_path): + write_json(tmp_path / "okay_nabu.json", make_micro_json()) + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([tmp_path], stop_model_id="stop") + assert "okay_nabu" in result + + def test_skips_stop_model(self, tmp_path): + write_json(tmp_path / "stop.json", make_micro_json("stop")) + write_json(tmp_path / "okay_nabu.json", make_micro_json()) + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([tmp_path], stop_model_id="stop") + assert "stop" not in result + assert "okay_nabu" in result + + def test_returns_empty_for_empty_directory(self, tmp_path): + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([tmp_path], stop_model_id="stop") + assert result == {} + + def test_returns_empty_for_nonexistent_directory(self, tmp_path): + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([tmp_path / "does_not_exist"], stop_model_id="stop") + assert result == {} + + def test_micro_wake_word_path_is_config_path(self, tmp_path): + write_json(tmp_path / "okay_nabu.json", make_micro_json()) + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([tmp_path], stop_model_id="stop") + assert result["okay_nabu"].wake_word_path == tmp_path / "okay_nabu.json" + + def test_oww_wake_word_path_is_model_file(self, tmp_path): + write_json(tmp_path / "hey_jarvis.json", make_oww_json(model_file="hey_jarvis.tflite")) + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([tmp_path], stop_model_id="stop") + assert result["hey_jarvis"].wake_word_path == tmp_path / "hey_jarvis.tflite" + + def test_probability_cutoff_loaded_from_config(self, tmp_path): + write_json(tmp_path / "okay_nabu.json", make_micro_json(probability_cutoff=0.85)) + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([tmp_path], stop_model_id="stop") + assert result["okay_nabu"].probability_cutoff == pytest.approx(0.85) + + def test_default_probability_cutoff_when_missing(self, tmp_path): + data = {"type": "micro", "wake_word": "okay nabu", "trained_languages": ["en"]} + write_json(tmp_path / "okay_nabu.json", data) + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([tmp_path], stop_model_id="stop") + assert result["okay_nabu"].probability_cutoff == pytest.approx(0.7) + + def test_trained_languages_stored(self, tmp_path): + data = make_micro_json() + data["trained_languages"] = ["en", "de"] + write_json(tmp_path / "okay_nabu.json", data) + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([tmp_path], stop_model_id="stop") + assert result["okay_nabu"].trained_languages == ["en", "de"] + + def test_searches_multiple_directories(self, tmp_path): + dir1 = tmp_path / "dir1" + dir2 = tmp_path / "dir2" + dir1.mkdir() + dir2.mkdir() + write_json(dir1 / "okay_nabu.json", make_micro_json()) + write_json(dir2 / "hey_jarvis.json", make_oww_json()) + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([dir1, dir2], stop_model_id="stop") + assert "okay_nabu" in result + assert "hey_jarvis" in result + + def test_wake_word_text_stored(self, tmp_path): + write_json(tmp_path / "okay_nabu.json", make_micro_json(wake_word="okay nabu")) + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([tmp_path], stop_model_id="stop") + assert result["okay_nabu"].wake_word == "okay nabu" + + def test_type_stored_correctly(self, tmp_path): + from linux_voice_assistant.models import WakeWordType + write_json(tmp_path / "okay_nabu.json", make_micro_json()) + from linux_voice_assistant.wake_word import find_available_wake_words + result = find_available_wake_words([tmp_path], stop_model_id="stop") + assert result["okay_nabu"].type == WakeWordType.MICRO_WAKE_WORD + + +# --------------------------------------------------------------------------- +# load_wake_models() +# --------------------------------------------------------------------------- + + +class TestLoadWakeModels: + def _make_available(self, tmp_path, model_id="okay_nabu"): + from linux_voice_assistant.models import AvailableWakeWord, WakeWordType + config_path = tmp_path / f"{model_id}.json" + write_json(config_path, make_micro_json()) + mock_model = MagicMock() + mock_model.id = model_id + available = MagicMock(spec=AvailableWakeWord) + available.id = model_id + available.type = WakeWordType.MICRO_WAKE_WORD + available.wake_word = "okay nabu" + available.trained_languages = ["en"] + available.wake_word_path = config_path + available.probability_cutoff = 0.7 + available.load.return_value = mock_model + return available, mock_model + + def test_loads_requested_wake_word(self, tmp_path): + available, mock_model = self._make_available(tmp_path) + from linux_voice_assistant.wake_word import load_wake_models + models, active, fallback = load_wake_models( + {"okay_nabu": available}, ["okay_nabu"], "okay_nabu" + ) + assert "okay_nabu" in models + assert "okay_nabu" in active + + def test_falls_back_to_default_when_no_active(self, tmp_path): + available, _ = self._make_available(tmp_path) + from linux_voice_assistant.wake_word import load_wake_models + models, active, fallback = load_wake_models( + {"okay_nabu": available}, [], "okay_nabu" + ) + assert "okay_nabu" in models + assert fallback is True + + def test_fallback_used_is_false_when_requested_loaded(self, tmp_path): + available, _ = self._make_available(tmp_path) + from linux_voice_assistant.wake_word import load_wake_models + _, _, fallback = load_wake_models( + {"okay_nabu": available}, ["okay_nabu"], "okay_nabu" + ) + assert fallback is False + + def test_skips_unknown_wake_word_id(self, tmp_path): + available, _ = self._make_available(tmp_path) + from linux_voice_assistant.wake_word import load_wake_models + models, active, _ = load_wake_models( + {"okay_nabu": available}, ["unknown_word"], "okay_nabu" + ) + assert "unknown_word" not in models + + def test_falls_back_to_okay_nabu_when_default_missing(self, tmp_path): + available, _ = self._make_available(tmp_path, "okay_nabu") + from linux_voice_assistant.wake_word import load_wake_models + models, active, _ = load_wake_models( + {"okay_nabu": available}, [], "nonexistent_default" + ) + assert "okay_nabu" in models + + def test_raises_when_no_wake_words_available(self): + from linux_voice_assistant.wake_word import load_wake_models + with pytest.raises(RuntimeError, match="No wake word models available"): + load_wake_models({}, [], "okay_nabu") + + def test_loads_multiple_requested_models(self, tmp_path): + available1, _ = self._make_available(tmp_path, "okay_nabu") + available2, _ = self._make_available(tmp_path, "hey_jarvis") + from linux_voice_assistant.wake_word import load_wake_models + models, active, _ = load_wake_models( + {"okay_nabu": available1, "hey_jarvis": available2}, + ["okay_nabu", "hey_jarvis"], + "okay_nabu", + ) + assert "okay_nabu" in models + assert "hey_jarvis" in models + + def test_active_set_matches_loaded_models(self, tmp_path): + available, _ = self._make_available(tmp_path) + from linux_voice_assistant.wake_word import load_wake_models + models, active, _ = load_wake_models( + {"okay_nabu": available}, ["okay_nabu"], "okay_nabu" + ) + assert set(models.keys()) == active + + def test_load_called_once_per_model(self, tmp_path): + available, _ = self._make_available(tmp_path) + from linux_voice_assistant.wake_word import load_wake_models + load_wake_models({"okay_nabu": available}, ["okay_nabu"], "okay_nabu") + available.load.assert_called_once() + + +# --------------------------------------------------------------------------- +# load_stop_model() +# --------------------------------------------------------------------------- + + +class TestLoadStopModel: + def test_returns_model_when_found(self, tmp_path): + write_json(tmp_path / "stop.json", make_micro_json("stop")) + mock_model = MagicMock() + with patch("linux_voice_assistant.wake_word.MicroWakeWord.from_config", return_value=mock_model): + from linux_voice_assistant.wake_word import load_stop_model + result = load_stop_model([tmp_path], "stop") + assert result is mock_model + + def test_returns_none_when_not_found(self, tmp_path): + from linux_voice_assistant.wake_word import load_stop_model + result = load_stop_model([tmp_path], "stop") + assert result is None + + def test_searches_multiple_directories(self, tmp_path): + dir1 = tmp_path / "dir1" + dir2 = tmp_path / "dir2" + dir1.mkdir() + dir2.mkdir() + write_json(dir2 / "stop.json", make_micro_json("stop")) + mock_model = MagicMock() + with patch("linux_voice_assistant.wake_word.MicroWakeWord.from_config", return_value=mock_model): + from linux_voice_assistant.wake_word import load_stop_model + result = load_stop_model([dir1, dir2], "stop") + assert result is mock_model + + def test_returns_none_when_load_fails(self, tmp_path): + write_json(tmp_path / "stop.json", make_micro_json("stop")) + with patch("linux_voice_assistant.wake_word.MicroWakeWord.from_config", side_effect=Exception("load error")): + from linux_voice_assistant.wake_word import load_stop_model + result = load_stop_model([tmp_path], "stop") + assert result is None + + def test_stops_at_first_found(self, tmp_path): + dir1 = tmp_path / "dir1" + dir2 = tmp_path / "dir2" + dir1.mkdir() + dir2.mkdir() + write_json(dir1 / "stop.json", make_micro_json("stop")) + write_json(dir2 / "stop.json", make_micro_json("stop")) + mock_model = MagicMock() + with patch("linux_voice_assistant.wake_word.MicroWakeWord.from_config", return_value=mock_model) as mock_load: + from linux_voice_assistant.wake_word import load_stop_model + load_stop_model([dir1, dir2], "stop") + mock_load.assert_called_once() \ No newline at end of file From c53f1a7e71472f4eaf5dc0a642382371800a6e35 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 26 Apr 2026 04:43:35 +0000 Subject: [PATCH 10/10] Update READMEs --- README.md | 2 +- tests/readme.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 086488a7..28750a6e 100644 --- a/README.md +++ b/README.md @@ -153,7 +153,7 @@ Auto-fix formatting issues (Black + isort): Run the test suite: ``` sh -./script/test +./script/tests ``` ## License diff --git a/tests/readme.md b/tests/readme.md index c9cc73a4..8f0fe397 100644 --- a/tests/readme.md +++ b/tests/readme.md @@ -1,7 +1,7 @@ # Automated testing From the root of LVA run:- -```pytest tests/unit``` +```./script/tests``` # Manual testing