diff --git a/tests/test_axon_handlers.py b/tests/test_axon_handlers.py index fc487ac..ae5143d 100644 --- a/tests/test_axon_handlers.py +++ b/tests/test_axon_handlers.py @@ -1,10 +1,10 @@ -"""Tests for validator axon_handlers.handle_swap_confirm. +"""Tests for allways.validator.axon_handlers. -Covers every rejection branch plus the queued-confirmation path. The -vote_initiate success path is not unit-tested here — it requires mocking -extrinsic submission and is exercised end-to-end in integration testing. -These tests focus on the validation layer, which is the security-critical -surface users and miners can reach directly via the axon. +Covers the easy-to-isolate pure helpers (hashing, SCALE encoders, direction +resolution, the synapse rejection helper, blacklist/priority coroutines) and +the validation layer of handle_swap_confirm — the security-critical surface +users and miners can reach directly via the axon. The vote_initiate success +path is exercised end-to-end in integration testing, not here. """ import asyncio @@ -15,7 +15,293 @@ from allways.classes import MinerPair from allways.contract_client import ContractError from allways.synapses import SwapConfirmSynapse -from allways.validator.axon_handlers import handle_swap_confirm +from allways.validator.axon_handlers import ( + blacklist_miner_activate, + blacklist_swap_confirm, + blacklist_swap_reserve, + handle_swap_confirm, + keccak256, + priority_miner_activate, + priority_swap_confirm, + priority_swap_reserve, + reject_synapse, + resolve_swap_direction, + scale_encode_extend_hash_input, + scale_encode_initiate_hash_input, + scale_encode_reserve_hash_input, +) + + +def _run(coro): + return asyncio.new_event_loop().run_until_complete(coro) + + +def _make_pair( + from_chain: str = 'btc', + to_chain: str = 'tao', + rate: float = 350.0, + counter_rate: float = 0.0, +) -> MinerPair: + return MinerPair( + uid=1, + hotkey='5Fminer', + from_chain=from_chain, + from_address='bc1qminer', + to_chain=to_chain, + to_address='5Fminer_dest', + rate=rate, + rate_str=str(rate), + counter_rate=counter_rate, + counter_rate_str=str(counter_rate) if counter_rate else '', + ) + + +class TestKeccak256: + def test_empty_input(self): + # Known Keccak-256 of empty string (ethereum convention) — guards + # against accidental swap to SHA3-256 (different IV/padding). + expected = bytes.fromhex('c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470') + assert keccak256(b'') == expected + + +class TestScaleEncodeReserveHashInput: + def test_structure_lengths(self): + miner_bytes = b'\x01' * 32 + from_addr = b'bc1qminer' + encoded = scale_encode_reserve_hash_input( + miner_bytes=miner_bytes, + from_addr_bytes=from_addr, + from_chain='btc', + to_chain='tao', + tao_amount=1_000, + from_amount=2_000, + to_amount=3_000, + ) + # Expect: 32 (AccountId) + 1+len(from_addr) + 1+3 (btc) + 1+3 (tao) + 16+16+16 (u128s) + expected_len = 32 + (1 + len(from_addr)) + (1 + 3) + (1 + 3) + 16 * 3 + assert len(encoded) == expected_len + + def test_miner_bytes_prefix(self): + miner_bytes = b'\xaa' * 32 + encoded = scale_encode_reserve_hash_input( + miner_bytes=miner_bytes, + from_addr_bytes=b'x', + from_chain='btc', + to_chain='tao', + tao_amount=0, + from_amount=0, + to_amount=0, + ) + assert encoded[:32] == miner_bytes + + def test_u128_suffix_little_endian(self): + encoded = scale_encode_reserve_hash_input( + miner_bytes=b'\x00' * 32, + from_addr_bytes=b'', + from_chain='', + to_chain='', + tao_amount=1, + from_amount=2, + to_amount=3, + ) + # Last 48 bytes = three u128s + assert encoded[-48:-32] == (1).to_bytes(16, 'little') + assert encoded[-32:-16] == (2).to_bytes(16, 'little') + assert encoded[-16:] == (3).to_bytes(16, 'little') + + +class TestScaleEncodeExtendHashInput: + def test_includes_miner_and_tx(self): + miner_bytes = b'\x02' * 32 + encoded = scale_encode_extend_hash_input(miner_bytes, 'deadbeef') + assert encoded[:32] == miner_bytes + assert encoded[-len(b'deadbeef'):] == b'deadbeef' + + def test_empty_tx_hash(self): + encoded = scale_encode_extend_hash_input(b'\x00' * 32, '') + # 32 + 1 (compact zero length) = 33 + assert len(encoded) == 33 + + +class TestScaleEncodeInitiateHashInput: + def test_contains_all_string_fields(self): + encoded = scale_encode_initiate_hash_input( + miner_bytes=b'\x03' * 32, + from_tx_hash='abcd', + from_chain='btc', + to_chain='tao', + miner_from_address='bc1qminer', + miner_to_address='5Fdest', + rate='350', + tao_amount=1, + from_amount=2, + to_amount=3, + ) + for fragment in (b'abcd', b'btc', b'tao', b'bc1qminer', b'5Fdest', b'350'): + assert fragment in encoded + + def test_amounts_are_final_48_bytes(self): + encoded = scale_encode_initiate_hash_input( + miner_bytes=b'\x00' * 32, + from_tx_hash='', + from_chain='', + to_chain='', + miner_from_address='', + miner_to_address='', + rate='', + tao_amount=10, + from_amount=20, + to_amount=30, + ) + assert encoded[-48:-32] == (10).to_bytes(16, 'little') + assert encoded[-32:-16] == (20).to_bytes(16, 'little') + assert encoded[-16:] == (30).to_bytes(16, 'little') + + +class TestResolveSwapDirection: + def test_canonical_direction_returns_from_address_as_deposit(self): + pair = _make_pair(from_chain='btc', to_chain='tao', rate=350.0) + result = resolve_swap_direction(pair, 'btc', 'tao') + assert result is not None + from_chain, to_chain, deposit, fulfillment, rate, rate_str = result + assert from_chain == 'btc' + assert to_chain == 'tao' + assert deposit == 'bc1qminer' + assert fulfillment == '5Fminer_dest' + assert rate == 350.0 + + def test_reverse_direction_swaps_addresses(self): + pair = _make_pair(from_chain='btc', to_chain='tao', rate=350.0, counter_rate=0.003) + result = resolve_swap_direction(pair, 'tao', 'btc') + assert result is not None + _, _, deposit, fulfillment, rate, _ = result + assert deposit == '5Fminer_dest' + assert fulfillment == 'bc1qminer' + assert rate == 0.003 + + def test_zero_rate_returns_none(self): + pair = _make_pair(rate=0.0) + assert resolve_swap_direction(pair, 'btc', 'tao') is None + + def test_negative_rate_returns_none(self): + pair = _make_pair(rate=-1.0) + assert resolve_swap_direction(pair, 'btc', 'tao') is None + + def test_empty_synapse_chains_fall_back_to_commitment(self): + pair = _make_pair(from_chain='btc', to_chain='tao', rate=350.0) + result = resolve_swap_direction(pair, '', '') + assert result is not None + assert result[0] == 'btc' + assert result[1] == 'tao' + + +class TestRejectSynapse: + def test_sets_accepted_false_and_reason(self): + synapse = MagicMock() + reject_synapse(synapse, 'bad input') + assert synapse.accepted is False + assert synapse.rejection_reason == 'bad input' + + def test_no_context_no_log_error(self): + synapse = MagicMock() + with patch('allways.validator.axon_handlers.bt.logging.debug') as dbg: + reject_synapse(synapse, 'why', context='') + dbg.assert_not_called() + + def test_with_context_logs_debug(self): + synapse = MagicMock() + with patch('allways.validator.axon_handlers.bt.logging.debug') as dbg: + reject_synapse(synapse, 'reason', context='SomeSynapse(x)') + dbg.assert_called_once_with('SomeSynapse(x): reason') + + +class TestBlacklistMinerActivate: + def _validator(self, hotkeys): + v = MagicMock() + v.metagraph.hotkeys = hotkeys + return v + + def test_missing_dendrite_blacklisted(self): + validator = self._validator(['5Fminer']) + synapse = MagicMock() + synapse.dendrite = None + blocked, reason = _run(blacklist_miner_activate(validator, synapse)) + assert blocked is True + assert 'dendrite' in reason.lower() or 'hotkey' in reason.lower() + + def test_missing_hotkey_blacklisted(self): + validator = self._validator(['5Fminer']) + synapse = MagicMock() + synapse.dendrite = MagicMock() + synapse.dendrite.hotkey = None + blocked, _ = _run(blacklist_miner_activate(validator, synapse)) + assert blocked is True + + def test_unregistered_hotkey_blacklisted(self): + validator = self._validator(['5Fminer']) + synapse = MagicMock() + synapse.dendrite.hotkey = '5Funknown' + blocked, reason = _run(blacklist_miner_activate(validator, synapse)) + assert blocked is True + assert 'unregistered' in reason.lower() + + def test_registered_hotkey_allowed(self): + validator = self._validator(['5Fminer']) + synapse = MagicMock() + synapse.dendrite.hotkey = '5Fminer' + blocked, _ = _run(blacklist_miner_activate(validator, synapse)) + assert blocked is False + + +class TestBlacklistSwapReserve: + def test_pass_through_any_hotkey(self): + # Pass-through by design — field checks happen later in handle_swap_reserve + validator = MagicMock() + synapse = MagicMock() + blocked, reason = _run(blacklist_swap_reserve(validator, synapse)) + assert blocked is False + assert reason == 'Passed' + + +class TestBlacklistSwapConfirm: + def test_pass_through_any_hotkey(self): + validator = MagicMock() + synapse = MagicMock() + blocked, reason = _run(blacklist_swap_confirm(validator, synapse)) + assert blocked is False + assert reason == 'Passed' + + +class TestPriorityFunctions: + def _validator(self, hotkeys, stakes): + v = MagicMock() + v.metagraph.hotkeys = hotkeys + v.metagraph.S = stakes + return v + + def test_miner_activate_returns_stake(self): + validator = self._validator(['5Fa', '5Fb'], [100.0, 250.0]) + synapse = MagicMock() + synapse.dendrite.hotkey = '5Fb' + assert _run(priority_miner_activate(validator, synapse)) == 250.0 + + def test_miner_activate_unknown_hotkey_returns_zero(self): + validator = self._validator(['5Fa'], [100.0]) + synapse = MagicMock() + synapse.dendrite.hotkey = '5Funknown' + assert _run(priority_miner_activate(validator, synapse)) == 0.0 + + def test_swap_reserve_flat_priority(self): + # User-facing synapses use a flat priority + assert _run(priority_swap_reserve(MagicMock(), MagicMock())) == 1.0 + + def test_swap_confirm_flat_priority(self): + assert _run(priority_swap_confirm(MagicMock(), MagicMock())) == 1.0 + + +# --------------------------------------------------------------------------- +# handle_swap_confirm: end-to-end validation layer +# --------------------------------------------------------------------------- def make_synapse( @@ -79,14 +365,14 @@ def make_validator( *, block: int = 1000, reserved_until: int = 2000, - reservation_data: tuple | None = (0, 345_000_000, 100_000, 345_000_000), + reservation_data: tuple | None = (345_000_000, 100_000, 345_000_000), providers: dict | None = None, ) -> MagicMock: """Build a Validator mock with default-happy contract/chain state. Individual tests override specific attributes to simulate each branch. reservation_data tuple mirrors the on-chain layout used by - handle_swap_confirm: (_, tao_amount, source_amount, dest_amount). + handle_swap_confirm: (tao_amount, source_amount, dest_amount). """ validator = MagicMock() validator.block = block @@ -299,7 +585,7 @@ def test_queued_entry_uses_reservation_amounts(self): """The contract-reserved amounts are authoritative. A queued entry must persist those, not any user-supplied value, so the later auto-initiate hashes match what the miner was reserved under.""" - validator = make_validator(reservation_data=(0, 777_000_000, 55_000, 999_000_000)) + validator = make_validator(reservation_data=(777_000_000, 55_000, 999_000_000)) validator.axon_chain_providers['btc'].verify_transaction.return_value = make_tx_info( confirmed=False, confirmations=1, diff --git a/tests/test_bitcoin_signing.py b/tests/test_bitcoin_signing.py index 14da5ea..daeee7d 100644 --- a/tests/test_bitcoin_signing.py +++ b/tests/test_bitcoin_signing.py @@ -12,7 +12,10 @@ ADDR_TYPE_P2WPKH, BitcoinProvider, detect_address_type, + to_mainnet_address, + to_mainnet_wif, ) +from allways.chains import CHAIN_BTC # Known test WIF (compressed) TEST_WIF = 'L1RrrnXkcKut5DEMwtDthjwRcTTwED36thyL1DebVrKuwvohjMNi' @@ -223,3 +226,59 @@ def test_regtest_address_converted_for_verification(self): # not crash, because no valid signature binds to it. result = provider.verify_from_proof('bcrt1qtestnettestaddresstestaddresstestaddr', TEST_MESSAGE, 'AAAA') assert result is False + + +class TestToMainnetWif: + def test_mainnet_wif_unchanged(self): + assert to_mainnet_wif(TEST_WIF) == TEST_WIF + + def test_testnet_wif_converted(self): + import base58 + decoded = base58.b58decode_check(TEST_WIF) + testnet_wif = base58.b58encode_check(bytes([0xEF]) + decoded[1:]).decode() + assert to_mainnet_wif(testnet_wif) == TEST_WIF + + +class TestToMainnetAddress: + def test_mainnet_address_unchanged(self): + addr = 'bc1q6tvmnmetj8vfz98vuetpvtuplqtj4uvvwjgxxc' + assert to_mainnet_address(addr) == addr + + def test_regtest_bech32_converted_to_bc_prefix(self): + converted = to_mainnet_address('bcrt1q6tvmnmetj8vfz98vuetpvtuplqtj4uvvtest9j') + assert converted.startswith('bc1') or converted.startswith('bcrt1') + + def test_unknown_prefix_unchanged(self): + assert to_mainnet_address('4unknownprefix') == '4unknownprefix' + + +class TestBitcoinProviderInit: + def test_rejects_invalid_mode(self): + import pytest + with patch.dict(os.environ, {'BTC_MODE': 'bogus'}, clear=False): + with pytest.raises(ValueError, match='BTC_MODE'): + BitcoinProvider() + + def test_node_mode_infers_testnet_from_port(self): + env = {'BTC_MODE': 'node', 'BTC_RPC_URL': 'http://localhost:18332'} + with patch.dict(os.environ, env, clear=True): + provider = BitcoinProvider() + assert provider.network == 'testnet' + assert provider.mode == 'node' + + def test_node_mode_defaults_to_mainnet(self): + env = {'BTC_MODE': 'node', 'BTC_RPC_URL': 'http://localhost:8332'} + with patch.dict(os.environ, env, clear=True): + provider = BitcoinProvider() + assert provider.network == 'mainnet' + + def test_lightweight_mode_defaults_to_mainnet(self): + with patch.dict(os.environ, {'BTC_MODE': 'lightweight'}, clear=True): + provider = BitcoinProvider() + assert provider.network == 'mainnet' + assert provider.rpc_url == '' + + def test_get_chain_returns_btc(self): + with patch.dict(os.environ, {'BTC_MODE': 'lightweight'}, clear=True): + provider = BitcoinProvider() + assert provider.get_chain() is CHAIN_BTC diff --git a/tests/test_chain_providers_base.py b/tests/test_chain_providers_base.py new file mode 100644 index 0000000..57b5bcd --- /dev/null +++ b/tests/test_chain_providers_base.py @@ -0,0 +1,109 @@ +"""Tests for ChainProvider.verify_transaction — shared post-fetch logic.""" + +from typing import Any, Optional, Tuple +from unittest.mock import MagicMock + +from allways.chain_providers.base import ChainProvider, TransactionInfo +from allways.chains import ChainDefinition + + +_TEST_CHAIN = ChainDefinition( + id='btc', + name='Bitcoin', + native_unit='sat', + decimals=8, + env_prefix='BTC', + min_confirmations=3, +) + + +class _FakeProvider(ChainProvider): + def __init__(self, tx: Optional[TransactionInfo] = None): + self._tx = tx + + def get_chain(self) -> ChainDefinition: + return _TEST_CHAIN + + def check_connection(self, **kwargs) -> None: + return None + + def fetch_matching_tx( + self, tx_hash: str, expected_recipient: str, expected_amount: int, block_hint: int = 0 + ) -> Optional[TransactionInfo]: + return self._tx + + def get_balance(self, address: str) -> int: + return 0 + + def is_valid_address(self, address: str) -> bool: + return True + + def sign_from_proof(self, address: str, message: str, key: Optional[Any] = None) -> str: + return '' + + def verify_from_proof(self, address: str, message: str, signature: str) -> bool: + return True + + def send_amount( + self, to_address: str, amount: int, from_address: Optional[str] = None + ) -> Optional[Tuple[str, int]]: + return None + + +def _tx(**overrides) -> TransactionInfo: + defaults = dict( + tx_hash='deadbeef', + confirmed=True, + sender='bc1qsender', + recipient='bc1qrec', + amount=1000, + confirmations=5, + ) + defaults.update(overrides) + return TransactionInfo(**defaults) + + +class TestVerifyTransaction: + def test_none_fetch_returns_none(self): + p = _FakeProvider(tx=None) + assert p.verify_transaction('tx', 'bc1qrec', 1000) is None + + def test_confirmed_passes(self): + p = _FakeProvider(tx=_tx()) + result = p.verify_transaction('tx', 'bc1qrec', 1000) + assert result is not None + assert result.confirmed + + def test_require_confirmed_rejects_unconfirmed(self): + p = _FakeProvider(tx=_tx(confirmed=False, confirmations=1)) + assert p.verify_transaction('tx', 'bc1qrec', 1000, require_confirmed=True) is None + + def test_require_confirmed_accepts_confirmed(self): + p = _FakeProvider(tx=_tx()) + result = p.verify_transaction('tx', 'bc1qrec', 1000, require_confirmed=True) + assert result is not None + + def test_expected_sender_match(self): + p = _FakeProvider(tx=_tx(sender='bc1qsender')) + result = p.verify_transaction('tx', 'bc1qrec', 1000, expected_sender='bc1qsender') + assert result is not None + + def test_expected_sender_mismatch_returns_none(self): + p = _FakeProvider(tx=_tx(sender='bc1qother')) + assert p.verify_transaction('tx', 'bc1qrec', 1000, expected_sender='bc1qsender') is None + + def test_expected_sender_empty_mismatch_rejected(self): + p = _FakeProvider(tx=_tx(sender='')) + assert p.verify_transaction('tx', 'bc1qrec', 1000, expected_sender='bc1qsender') is None + + def test_block_hint_passed_to_fetch(self): + p = _FakeProvider(tx=_tx()) + p.fetch_matching_tx = MagicMock(return_value=_tx()) + p.verify_transaction('tx', 'bc1qrec', 1000, block_hint=42) + _, kwargs = p.fetch_matching_tx.call_args + assert kwargs['block_hint'] == 42 + + def test_no_sender_check_accepts_empty_sender(self): + p = _FakeProvider(tx=_tx(sender='')) + result = p.verify_transaction('tx', 'bc1qrec', 1000) + assert result is not None diff --git a/tests/test_chain_providers_init.py b/tests/test_chain_providers_init.py new file mode 100644 index 0000000..c332785 --- /dev/null +++ b/tests/test_chain_providers_init.py @@ -0,0 +1,108 @@ +"""Tests for allways.chain_providers.create_chain_providers registry factory.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from allways.chain_providers import create_chain_providers + + +class _FakeProvider: + """Test provider that records constructor kwargs and controls check behavior.""" + + instances: list = [] + + def __init__(self, **kwargs): + self.kwargs = kwargs + self.checked_require_send = None + _FakeProvider.instances.append(self) + + def check_connection(self, require_send: bool = True) -> None: + self.checked_require_send = require_send + + +class _FailingProvider: + def __init__(self, **kwargs): + raise RuntimeError('init failure') + + +class _CheckFailingProvider: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def check_connection(self, require_send: bool = True) -> None: + raise RuntimeError('check failure') + + +@pytest.fixture(autouse=True) +def reset_fake(): + _FakeProvider.instances = [] + + +class TestRegistryInstantiation: + def test_instantiates_each_registered_provider(self): + reg = ( + ('chain-a', _FakeProvider, ()), + ('chain-b', _FakeProvider, ()), + ) + with patch('allways.chain_providers.PROVIDER_REGISTRY', reg): + providers = create_chain_providers() + assert set(providers.keys()) == {'chain-a', 'chain-b'} + + def test_forwards_only_declared_kwargs(self): + reg = (('chain-a', _FakeProvider, ('subtensor',)),) + with patch('allways.chain_providers.PROVIDER_REGISTRY', reg): + subtensor = MagicMock() + wallet = MagicMock() + providers = create_chain_providers(subtensor=subtensor, wallet=wallet) + assert providers['chain-a'].kwargs == {'subtensor': subtensor} + + def test_missing_kwarg_not_passed(self): + reg = (('chain-a', _FakeProvider, ('subtensor',)),) + with patch('allways.chain_providers.PROVIDER_REGISTRY', reg): + providers = create_chain_providers() + assert providers['chain-a'].kwargs == {} + + +class TestCheckConnection: + def test_check_true_invokes_check_connection(self): + reg = (('chain-a', _FakeProvider, ()),) + with patch('allways.chain_providers.PROVIDER_REGISTRY', reg): + providers = create_chain_providers(check=True) + assert providers['chain-a'].checked_require_send is True + + def test_require_send_false_propagates(self): + reg = (('chain-a', _FakeProvider, ()),) + with patch('allways.chain_providers.PROVIDER_REGISTRY', reg): + providers = create_chain_providers(check=True, require_send=False) + assert providers['chain-a'].checked_require_send is False + + def test_check_false_skips_check(self): + reg = (('chain-a', _FakeProvider, ()),) + with patch('allways.chain_providers.PROVIDER_REGISTRY', reg): + providers = create_chain_providers(check=False) + assert providers['chain-a'].checked_require_send is None + + +class TestFailureHandling: + def test_check_true_raises_on_init_failure(self): + reg = (('chain-a', _FailingProvider, ()),) + with patch('allways.chain_providers.PROVIDER_REGISTRY', reg): + with pytest.raises(RuntimeError, match='failed startup check'): + create_chain_providers(check=True) + + def test_check_true_raises_on_check_failure(self): + reg = (('chain-a', _CheckFailingProvider, ()),) + with patch('allways.chain_providers.PROVIDER_REGISTRY', reg): + with pytest.raises(RuntimeError, match='failed startup check'): + create_chain_providers(check=True) + + def test_check_false_swallows_init_failure(self): + reg = ( + ('chain-a', _FailingProvider, ()), + ('chain-b', _FakeProvider, ()), + ) + with patch('allways.chain_providers.PROVIDER_REGISTRY', reg): + providers = create_chain_providers(check=False) + assert 'chain-a' not in providers + assert 'chain-b' in providers diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..f36949f --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,110 @@ +"""Tests for allways.utils.config — argparse wiring + check_config.""" + +import argparse +import os +import tempfile +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from allways.utils.config import ( + add_args, + add_miner_args, + add_validator_args, + check_config, + config, +) + + +def _parse(add_fn, argv=None): + parser = argparse.ArgumentParser() + add_fn(None, parser) + return parser.parse_args(argv or []) + + +class TestAddArgs: + def test_override_netuid(self): + args = _parse(add_args, ['--netuid', '42']) + assert args.netuid == 42 + + def test_dont_save_events_flag(self): + args = _parse(add_args, ['--neuron.dont_save_events']) + assert getattr(args, 'neuron.dont_save_events') is True + + +class TestAddMinerArgs: + def test_override_miner_name(self): + args = _parse(add_miner_args, ['--neuron.name', 'myminer']) + assert getattr(args, 'neuron.name') == 'myminer' + + def test_poll_interval_override(self): + args = _parse(add_miner_args, ['--miner.poll_interval', '5']) + assert getattr(args, 'miner.poll_interval') == 5 + + +class TestAddValidatorArgs: + def test_disable_set_weights_flag(self): + args = _parse(add_validator_args, ['--neuron.disable_set_weights']) + assert getattr(args, 'neuron.disable_set_weights') is True + + def test_axon_off_flag(self): + args = _parse(add_validator_args, ['--axon_off']) + assert getattr(args, 'neuron.axon_off') is True + + def test_moving_average_alpha_override(self): + args = _parse(add_validator_args, ['--neuron.moving_average_alpha', '0.1']) + assert getattr(args, 'neuron.moving_average_alpha') == 0.1 + + +class TestCheckConfig: + def _make_config(self, tmp_dir: str, dont_save_events=True): + return SimpleNamespace( + logging=SimpleNamespace(logging_dir=tmp_dir), + wallet=SimpleNamespace(name='default', hotkey='default'), + netuid=7, + neuron=SimpleNamespace( + name='validator', + full_path='', + dont_save_events=dont_save_events, + events_retention_size=1024, + ), + ) + + def test_creates_full_path(self): + with tempfile.TemporaryDirectory() as tmp: + cfg = self._make_config(tmp) + with patch('bittensor.logging.check_config'): + check_config(None, cfg) + assert os.path.exists(cfg.neuron.full_path) + + def test_dont_save_events_skips_logger(self): + with tempfile.TemporaryDirectory() as tmp: + cfg = self._make_config(tmp, dont_save_events=True) + with patch('bittensor.logging.check_config'), \ + patch('allways.utils.config.setup_events_logger') as setup: + check_config(None, cfg) + setup.assert_not_called() + + def test_save_events_registers_logger(self): + with tempfile.TemporaryDirectory() as tmp: + cfg = self._make_config(tmp, dont_save_events=False) + logger = MagicMock() + logger.name = 'events' + with patch('bittensor.logging.check_config'), \ + patch('allways.utils.config.setup_events_logger', return_value=logger), \ + patch('bittensor.logging.register_primary_logger') as reg: + check_config(None, cfg) + reg.assert_called_once_with('events') + + +class TestConfig: + def test_builds_config_with_cls_add_args(self): + cls = MagicMock() + with patch('bittensor.Wallet.add_args'), \ + patch('bittensor.Subtensor.add_args'), \ + patch('bittensor.logging.add_args'), \ + patch('bittensor.Axon.add_args'), \ + patch('bittensor.Config', return_value='cfg') as bt_config: + result = config(cls) + assert result == 'cfg' + cls.add_args.assert_called_once() + bt_config.assert_called_once() diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 0000000..f9ad962 --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,46 @@ +"""Tests for allways.utils.logging.""" + +import os +import tempfile +from unittest.mock import patch + +from allways.utils import logging as events_logging +from allways.utils.logging import EVENTS_LEVEL_NUM, log_on_change, setup_events_logger + + +class TestSetupEventsLogger: + def test_creates_events_log_file(self): + with tempfile.TemporaryDirectory() as tmp: + logger = setup_events_logger(tmp, events_retention_size=1024) + logger.log(EVENTS_LEVEL_NUM, 'hello') + for h in logger.handlers: + h.flush() + assert os.path.exists(os.path.join(tmp, 'events.log')) + + +class TestLogOnChange: + def setup_method(self): + events_logging._last_seen.clear() + + def test_logs_first_time(self): + with patch('allways.utils.logging.bt.logging.info') as info: + log_on_change('k', 'v1', 'first') + info.assert_called_once_with('first') + + def test_suppresses_unchanged_value(self): + with patch('allways.utils.logging.bt.logging.info') as info: + log_on_change('k', 'v1', 'first') + log_on_change('k', 'v1', 'second') + assert info.call_count == 1 + + def test_logs_when_value_changes(self): + with patch('allways.utils.logging.bt.logging.info') as info: + log_on_change('k', 'v1', 'first') + log_on_change('k', 'v2', 'second') + assert info.call_count == 2 + + def test_independent_keys(self): + with patch('allways.utils.logging.bt.logging.info') as info: + log_on_change('a', 1, 'a-msg') + log_on_change('b', 1, 'b-msg') + assert info.call_count == 2 diff --git a/tests/test_misc.py b/tests/test_misc.py new file mode 100644 index 0000000..091181c --- /dev/null +++ b/tests/test_misc.py @@ -0,0 +1,78 @@ +"""Tests for allways.utils.misc.""" + +import time +from unittest.mock import MagicMock + +from allways.utils.misc import ttl_cache, ttl_get_block, ttl_hash_gen + + +class TestTtlHashGen: + def test_yields_same_value_within_window(self): + gen = ttl_hash_gen(seconds=3600) + a = next(gen) + b = next(gen) + assert a == b + + def test_short_ttl_eventually_changes(self): + gen = ttl_hash_gen(seconds=1) + first = next(gen) + time.sleep(1.1) + assert next(gen) != first + + +class TestTtlCache: + def test_caches_repeated_calls(self): + calls = [] + + @ttl_cache(maxsize=8, ttl=60) + def f(x): + calls.append(x) + return x * 2 + + assert f(3) == 6 + assert f(3) == 6 + assert calls == [3] + + def test_different_args_miss_cache(self): + calls = [] + + @ttl_cache(maxsize=8, ttl=60) + def f(x): + calls.append(x) + return x + 1 + + f(1) + f(2) + assert calls == [1, 2] + + def test_negative_ttl_uses_default(self): + # ttl <= 0 → substituted with 65536 internally + @ttl_cache(ttl=-1) + def f(x): + return x + + assert f(5) == 5 + assert f(5) == 5 + + def test_expired_entry_recomputed(self): + calls = [] + + @ttl_cache(maxsize=8, ttl=1) + def f(x): + calls.append(x) + return x + + f(7) + time.sleep(1.1) + f(7) + assert calls == [7, 7] + + +class TestTtlGetBlock: + def test_cached_within_ttl(self): + obj = MagicMock() + obj.subtensor.get_current_block.side_effect = [100, 200, 300] + first = ttl_get_block(obj) + second = ttl_get_block(obj) + # maxsize=1 cache keyed on (ttl_hash, self); same obj within window → hit + assert first == second diff --git a/tests/test_subtensor.py b/tests/test_subtensor.py new file mode 100644 index 0000000..edef30c --- /dev/null +++ b/tests/test_subtensor.py @@ -0,0 +1,31 @@ +"""Tests for SubtensorProvider basics (connection, cache).""" + +from unittest.mock import MagicMock + +import pytest + +from allways.chain_providers.subtensor import SubtensorProvider + + +class TestProviderBasics: + def test_get_chain_returns_tao(self): + from allways.chains import CHAIN_TAO + assert SubtensorProvider(MagicMock()).get_chain() is CHAIN_TAO + + def test_check_connection_success(self): + subtensor = MagicMock() + subtensor.get_current_block.return_value = 12345 + SubtensorProvider(subtensor).check_connection() + subtensor.get_current_block.assert_called_once() + + def test_check_connection_raises_on_failure(self): + subtensor = MagicMock() + subtensor.get_current_block.side_effect = RuntimeError('down') + with pytest.raises(ConnectionError, match='Cannot reach Subtensor'): + SubtensorProvider(subtensor).check_connection() + + def test_clear_cache(self): + provider = SubtensorProvider(MagicMock()) + provider.block_cache[1] = {'data': 'x'} + provider.clear_cache() + assert provider.block_cache == {} diff --git a/tests/test_swap_poller.py b/tests/test_swap_poller.py new file mode 100644 index 0000000..62ec3c7 --- /dev/null +++ b/tests/test_swap_poller.py @@ -0,0 +1,168 @@ +"""Tests for allways.miner.swap_poller.SwapPoller.""" + +from unittest.mock import MagicMock + +from allways.classes import Swap, SwapStatus +from allways.miner.swap_poller import SwapPoller + + +MINER_HK = 'miner-hk' +OTHER_HK = 'other-hk' + + +def make_swap( + swap_id: int, + miner_hotkey: str = MINER_HK, + status: SwapStatus = SwapStatus.ACTIVE, +) -> Swap: + return Swap( + id=swap_id, + user_hotkey='user', + miner_hotkey=miner_hotkey, + from_chain='btc', + to_chain='tao', + from_amount=1_000_000, + to_amount=345_000_000, + tao_amount=345_000_000, + user_from_address='bc1q-user', + user_to_address='5user', + miner_from_address='bc1q-miner', + rate='345', + status=status, + initiated_block=100, + timeout_block=500, + ) + + +def make_poller(next_id: int = 1, swaps_by_id=None): + client = MagicMock() + client.get_next_swap_id.return_value = next_id + swaps_by_id = swaps_by_id or {} + client.get_swap.side_effect = lambda sid: swaps_by_id.get(sid) + return SwapPoller(client, MINER_HK) + + +class TestPollInitialState: + def test_empty_contract(self): + poller = make_poller(next_id=1) + active, fulfilled = poller.poll() + assert active == [] + assert fulfilled == [] + assert poller.last_poll_ok is True + + def test_no_swaps_for_this_miner(self): + swaps = {1: make_swap(1, miner_hotkey=OTHER_HK)} + poller = make_poller(next_id=2, swaps_by_id=swaps) + active, fulfilled = poller.poll() + assert active == [] + assert fulfilled == [] + + +class TestPollDiscovery: + def test_finds_active_swap_assigned_to_miner(self): + swaps = {1: make_swap(1, status=SwapStatus.ACTIVE)} + poller = make_poller(next_id=2, swaps_by_id=swaps) + active, fulfilled = poller.poll() + assert len(active) == 1 + assert active[0].id == 1 + assert fulfilled == [] + + def test_finds_fulfilled_swap(self): + swaps = {2: make_swap(2, status=SwapStatus.FULFILLED)} + poller = make_poller(next_id=3, swaps_by_id=swaps) + active, fulfilled = poller.poll() + assert active == [] + assert len(fulfilled) == 1 + assert fulfilled[0].id == 2 + + def test_skips_completed_swap(self): + swaps = {1: make_swap(1, status=SwapStatus.COMPLETED)} + poller = make_poller(next_id=2, swaps_by_id=swaps) + active, fulfilled = poller.poll() + assert active == [] + assert fulfilled == [] + + def test_skips_timed_out_swap(self): + swaps = {1: make_swap(1, status=SwapStatus.TIMED_OUT)} + poller = make_poller(next_id=2, swaps_by_id=swaps) + active, fulfilled = poller.poll() + assert active == [] + + def test_skips_none_result(self): + poller = make_poller(next_id=2, swaps_by_id={}) + active, fulfilled = poller.poll() + assert active == [] + assert fulfilled == [] + + +class TestCursor: + def test_cursor_advances_after_scan(self): + swaps = {1: make_swap(1), 2: make_swap(2)} + poller = make_poller(next_id=3, swaps_by_id=swaps) + poller.poll() + assert poller.last_scanned_id == 2 + + def test_cursor_not_advanced_when_contract_empty(self): + poller = make_poller(next_id=1) + poller.poll() + assert poller.last_scanned_id == 0 + + def test_cursor_skips_already_scanned_ids(self): + swaps = {1: make_swap(1)} + poller = make_poller(next_id=2, swaps_by_id=swaps) + poller.last_scanned_id = 5 + poller.poll() + assert poller.client.get_swap.call_count == 0 + + +class TestRefreshActive: + def test_removes_resolved_swap(self): + active_swap = make_swap(1, status=SwapStatus.ACTIVE) + poller = make_poller(next_id=2, swaps_by_id={1: active_swap}) + poller.poll() + assert 1 in poller.active + + completed = make_swap(1, status=SwapStatus.COMPLETED) + poller.client.get_swap.side_effect = lambda sid: {1: completed}.get(sid) + poller.client.get_next_swap_id.return_value = 2 + active, fulfilled = poller.poll() + assert active == [] + assert 1 not in poller.active + + def test_removes_missing_swap(self): + active_swap = make_swap(1, status=SwapStatus.ACTIVE) + poller = make_poller(next_id=2, swaps_by_id={1: active_swap}) + poller.poll() + + poller.client.get_swap.side_effect = lambda sid: None + poller.client.get_next_swap_id.return_value = 2 + poller.poll() + assert 1 not in poller.active + + def test_updates_active_swap_state(self): + active_swap = make_swap(1, status=SwapStatus.ACTIVE) + poller = make_poller(next_id=2, swaps_by_id={1: active_swap}) + poller.poll() + + updated = make_swap(1, status=SwapStatus.FULFILLED) + poller.client.get_swap.side_effect = lambda sid: {1: updated}.get(sid) + poller.client.get_next_swap_id.return_value = 2 + active, fulfilled = poller.poll() + assert active == [] + assert len(fulfilled) == 1 + + +class TestErrorHandling: + def test_exception_in_inner_sets_flag(self): + poller = make_poller(next_id=2) + poller.client.get_next_swap_id.side_effect = RuntimeError('rpc down') + active, fulfilled = poller.poll() + assert active == [] + assert fulfilled == [] + assert poller.last_poll_ok is False + + def test_successful_poll_sets_flag_true(self): + poller = make_poller(next_id=1) + poller.last_poll_ok = False + poller.poll() + assert poller.last_poll_ok is True