diff --git a/allways/contract_client.py b/allways/contract_client.py index ae8eb48..77816e2 100644 --- a/allways/contract_client.py +++ b/allways/contract_client.py @@ -56,6 +56,22 @@ from allways.classes import Swap, SwapStatus from allways.constants import CONTRACT_ADDRESS, MIN_BALANCE_FOR_TX_RAO +from allways.utils.scale import ( + ACCOUNT_ID_BYTES, + U32_BYTES, + U64_BYTES, + U128_BYTES, + compact_encode_len, + decode_account_id, + decode_string, + decode_u32, + decode_u64, + decode_u128, + encode_bytes, + encode_str, + encode_u128, + strip_hex_prefix, +) # ========================================================================= # Contract selectors (from metadata — deterministic per contract build) @@ -208,23 +224,6 @@ _EXTRINSIC_NOT_FOUND = tuple(t for t in [ExtrinsicNotFound, AsyncExtrinsicNotFound] if t is not None) -def compact_encode_len(length: int) -> bytes: - """SCALE compact-encode a length prefix. Shared by contract client and axon handlers.""" - if length < 64: - return bytes([length << 2]) - elif length < 16384: - return bytes([((length << 2) | 1) & 0xFF, length >> 6]) - else: - return bytes( - [ - ((length << 2) | 2) & 0xFF, - (length >> 6) & 0xFF, - (length >> 14) & 0xFF, - (length >> 22) & 0xFF, - ] - ) - - # ContractExecResult byte layout offsets (after gas prefix) _GAS_PREFIX_BYTES = 16 # Skip gas consumed/required _RESULT_OK_OFFSET = 10 # Byte indicating Ok(0x00) vs Err in Result @@ -366,7 +365,7 @@ def raw_contract_read( if not result.get('result'): return None - raw = bytes.fromhex(result['result'].replace('0x', '')) + raw = bytes.fromhex(strip_hex_prefix(result['result'])) if len(raw) < 32: return None @@ -376,7 +375,7 @@ def raw_contract_read( if len(r) < _DATA_COMPACT_OFFSET or r[_RESULT_OK_OFFSET] != 0x00: return None - flags = struct.unpack_from(' bytes: return struct.pack('B', int(value)) elif type_tag == 'hash': if isinstance(value, str): - return bytes.fromhex(value.replace('0x', '')) - return bytes(value)[:32].ljust(32, b'\x00') + return bytes.fromhex(strip_hex_prefix(value)) + return bytes(value)[:ACCOUNT_ID_BYTES].ljust(ACCOUNT_ID_BYTES, b'\x00') elif type_tag == 'bytes': data = value if isinstance(value, (bytes, bytearray)) else value.encode('utf-8') - return compact_encode_len(len(data)) + data + return encode_bytes(data) elif type_tag == 'u32': return struct.pack('> 64) + return encode_u128(int(value)) elif type_tag == 'bool': return b'\x01' if value else b'\x00' elif type_tag == 'AccountId': @@ -524,8 +522,7 @@ def encode_value(self, value, type_tag: str) -> bytes: return bytes.fromhex(self.subtensor.substrate.ss58_decode(value)) return bytes(value) elif type_tag == 'str': - data = value.encode('utf-8') if isinstance(value, str) else value - return compact_encode_len(len(data)) + data + return encode_str(value) if isinstance(value, str) else encode_bytes(value) elif type_tag == 'vec_u64': items = list(value) encoded = compact_encode_len(len(items)) @@ -535,21 +532,19 @@ def encode_value(self, value, type_tag: str) -> bytes: raise ValueError(f'Unsupported type: {type_tag}') def extract_u32(self, data: bytes) -> Optional[int]: - if not data or len(data) < 4: + if not data or len(data) < U32_BYTES: return None - return struct.unpack_from(' Optional[int]: - if not data or len(data) < 8: + if not data or len(data) < U64_BYTES: return None - return struct.unpack_from(' Optional[int]: - if not data or len(data) < 16: + if not data or len(data) < U128_BYTES: return None - low = struct.unpack_from(' Optional[bool]: if not data: @@ -557,86 +552,38 @@ def extract_bool(self, data: bytes) -> Optional[bool]: return data[0] != 0 def extract_account_id(self, data: bytes) -> Optional[str]: - if not data or len(data) < 32: + if not data or len(data) < ACCOUNT_ID_BYTES: return None - return self.subtensor.substrate.ss58_encode(data[:32].hex()) - - def decode_string(self, data: bytes, offset: int) -> Tuple[str, int]: - """Decode a SCALE compact-prefixed string. Returns (string, new_offset).""" - if offset >= len(data): - return '', offset - first = data[offset] - mode = first & 0x03 - if mode == 0: - str_len = first >> 2 - offset += 1 - elif mode == 1: - if offset + 1 >= len(data): - return '', offset - str_len = (data[offset] | (data[offset + 1] << 8)) >> 2 - offset += 2 - else: - if offset + 3 >= len(data): - return '', offset - str_len = ( - data[offset] | (data[offset + 1] << 8) | (data[offset + 2] << 16) | (data[offset + 3] << 24) - ) >> 2 - offset += 4 - if offset + str_len > len(data): - return '', offset - s = data[offset : offset + str_len].decode('utf-8', errors='replace') - return s, offset + str_len + return decode_account_id(data, 0)[0] def decode_swap_data(self, data: bytes, offset: int = 0) -> Optional[Swap]: """Decode a SwapData struct from raw SCALE bytes.""" try: o = offset - - swap_id = struct.unpack_from(' Tuple[int, int]: if data is None or len(data) < 5: return (0, 0) strike_count = data[0] - last_expired = struct.unpack_from(' int: @@ -912,18 +859,9 @@ def get_reservation_data(self, miner_hotkey: str) -> Optional[Tuple[int, int, in if data[0] != 0x01: return None o = 1 - # 3 x u128 - tao_lo = struct.unpack_from(' str: + """Remove a leading ``0x`` from a hex string if present.""" + return s[2:] if s.startswith('0x') else s + + +def compact_encode_len(length: int) -> bytes: + """SCALE compact-encode a length prefix.""" + if length < 64: + return bytes([length << 2]) + if length < 16384: + return bytes([((length << 2) | 1) & 0xFF, length >> 6]) + return bytes( + [ + ((length << 2) | 2) & 0xFF, + (length >> 6) & 0xFF, + (length >> 14) & 0xFF, + (length >> 22) & 0xFF, + ] + ) + + +def encode_bytes(data: bytes) -> bytes: + """SCALE-encode raw bytes as compact length prefix + bytes.""" + return compact_encode_len(len(data)) + data + + +def encode_str(s: str) -> bytes: + """SCALE-encode a UTF-8 string as compact length prefix + bytes.""" + return encode_bytes(s.encode('utf-8')) + + +def encode_u128(value: int) -> bytes: + """SCALE-encode a u128 as 16 little-endian bytes.""" + return value.to_bytes(U128_BYTES, 'little') + + +# ─── Streaming decoders ──────────────────────────────────────────────────── + + +def decode_u32(data: bytes, offset: int) -> Tuple[int, int]: + return struct.unpack_from(' Tuple[int, int]: + return struct.unpack_from(' Tuple[int, int]: + lo = struct.unpack_from(' Tuple[bool, int]: + return data[offset] != 0, offset + 1 + + +def decode_account_id(data: bytes, offset: int) -> Tuple[str, int]: + raw = data[offset : offset + ACCOUNT_ID_BYTES] + return ss58_encode(raw, SS58_PREFIX), offset + ACCOUNT_ID_BYTES + + +def decode_string(data: bytes, offset: int) -> Tuple[str, int]: + """SCALE-decode a compact-length-prefixed UTF-8 string. + + Returns ``('', offset)`` on truncated or out-of-bounds input so callers + streaming composite structs degrade cleanly instead of raising. + """ + if offset >= len(data): + return '', offset + first = data[offset] + mode = first & 0x03 + if mode == 0: + str_len = first >> 2 + offset += 1 + elif mode == 1: + if offset + 1 >= len(data): + return '', offset + str_len = (data[offset] | (data[offset + 1] << 8)) >> 2 + offset += 2 + else: + if offset + 3 >= len(data): + return '', offset + str_len = (data[offset] | (data[offset + 1] << 8) | (data[offset + 2] << 16) | (data[offset + 3] << 24)) >> 2 + offset += 4 + if offset + str_len > len(data): + return '', offset + s = data[offset : offset + str_len].decode('utf-8', errors='replace') + return s, offset + str_len diff --git a/allways/validator/axon_handlers.py b/allways/validator/axon_handlers.py index 9817088..46b8050 100644 --- a/allways/validator/axon_handlers.py +++ b/allways/validator/axon_handlers.py @@ -18,8 +18,9 @@ from allways.classes import MinerPair from allways.commitments import read_miner_commitment from allways.constants import RESERVATION_COOLDOWN_BLOCKS -from allways.contract_client import AllwaysContractClient, ContractError, compact_encode_len, is_contract_rejection +from allways.contract_client import AllwaysContractClient, ContractError, is_contract_rejection from allways.synapses import MinerActivateSynapse, SwapConfirmSynapse, SwapReserveSynapse +from allways.utils.scale import encode_bytes, encode_str, encode_u128 from allways.validator.state_store import PendingConfirm if TYPE_CHECKING: @@ -46,36 +47,23 @@ def scale_encode_reserve_hash_input( &(miner, user_from_address, from_chain, to_chain, tao_amount, from_amount, to_amount) ). """ - src_bytes = from_chain.encode('utf-8') - dst_bytes = to_chain.encode('utf-8') return ( - miner_bytes # AccountId: 32 bytes raw - + compact_encode_len(len(from_addr_bytes)) - + from_addr_bytes # String: compact length + UTF-8 bytes - + compact_encode_len(len(src_bytes)) - + src_bytes # String: compact length + UTF-8 bytes - + compact_encode_len(len(dst_bytes)) - + dst_bytes # String: compact length + UTF-8 bytes - + tao_amount.to_bytes(16, 'little') # u128: 16 bytes LE - + from_amount.to_bytes(16, 'little') # u128: 16 bytes LE - + to_amount.to_bytes(16, 'little') # u128: 16 bytes LE + miner_bytes + + encode_bytes(from_addr_bytes) + + encode_str(from_chain) + + encode_str(to_chain) + + encode_u128(tao_amount) + + encode_u128(from_amount) + + encode_u128(to_amount) ) -def scale_encode_extend_hash_input( - miner_bytes: bytes, - from_tx_hash: str, -) -> bytes: +def scale_encode_extend_hash_input(miner_bytes: bytes, from_tx_hash: str) -> bytes: """SCALE-encode the extend hash input tuple: (AccountId, &str). Matches ink::env::hash_encoded::(&(miner, from_tx_hash)). """ - tx_bytes = from_tx_hash.encode('utf-8') - return ( - miner_bytes # AccountId: 32 bytes raw - + compact_encode_len(len(tx_bytes)) - + tx_bytes # &str (SCALE: compact length + bytes) - ) + return miner_bytes + encode_str(from_tx_hash) def scale_encode_initiate_hash_input( @@ -102,22 +90,17 @@ def scale_encode_initiate_hash_input( consensus on the full swap shape — the quorum-reaching vote cannot substitute any of these fields without invalidating the hash. """ - - def encode_str(s: str) -> bytes: - raw = s.encode('utf-8') - return compact_encode_len(len(raw)) + raw - return ( - miner_bytes # AccountId: 32 bytes raw + miner_bytes + encode_str(from_tx_hash) + encode_str(from_chain) + encode_str(to_chain) + encode_str(miner_from_address) + encode_str(miner_to_address) + encode_str(rate) - + tao_amount.to_bytes(16, 'little') # u128: 16 bytes LE - + from_amount.to_bytes(16, 'little') # u128: 16 bytes LE - + to_amount.to_bytes(16, 'little') # u128: 16 bytes LE + + encode_u128(tao_amount) + + encode_u128(from_amount) + + encode_u128(to_amount) ) diff --git a/allways/validator/event_watcher.py b/allways/validator/event_watcher.py index 7c2f206..d81b689 100644 --- a/allways/validator/event_watcher.py +++ b/allways/validator/event_watcher.py @@ -14,62 +14,24 @@ from __future__ import annotations import json -import struct from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple import bittensor as bt -from substrateinterface.utils.ss58 import ss58_encode from allways.constants import SCORING_WINDOW_BLOCKS +from allways.utils.scale import ( + decode_account_id, + decode_bool, + decode_string, + decode_u32, + decode_u64, + decode_u128, + strip_hex_prefix, +) from allways.validator.state_store import ValidatorStateStore -SS58_PREFIX = 42 - - -# ─── SCALE field decoders (ported from alw-utils watch_contract_events) ───── - - -def decode_u32(data: bytes, offset: int) -> Tuple[int, int]: - return struct.unpack_from(' Tuple[int, int]: - return struct.unpack_from(' Tuple[int, int]: - lo = struct.unpack_from(' Tuple[bool, int]: - return data[offset] != 0, offset + 1 - - -def decode_account_id(data: bytes, offset: int) -> Tuple[str, int]: - raw = data[offset : offset + 32] - return ss58_encode(raw, SS58_PREFIX), offset + 32 - - -def decode_string(data: bytes, offset: int) -> Tuple[str, int]: - first = data[offset] - mode = first & 0x03 - if mode == 0: - str_len = first >> 2 - offset += 1 - elif mode == 1: - str_len = (data[offset] | (data[offset + 1] << 8)) >> 2 - offset += 2 - else: - str_len = (data[offset] | (data[offset + 1] << 8) | (data[offset + 2] << 16) | (data[offset + 3] << 24)) >> 2 - offset += 4 - s = data[offset : offset + str_len].decode('utf-8', errors='replace') - return s, offset + str_len - - DATA_DECODERS = { 'u32': decode_u32, 'u64': decode_u64, @@ -81,11 +43,11 @@ def decode_string(data: bytes, offset: int) -> Tuple[str, int]: def topic_account_id(topic_bytes: bytes) -> str: - return ss58_encode(topic_bytes[:32], SS58_PREFIX) + return decode_account_id(topic_bytes, 0)[0] def topic_u64(topic_bytes: bytes) -> int: - return struct.unpack_from(' bool: @@ -196,9 +158,8 @@ def to_bytes(val: Any) -> bytes: if isinstance(val, bytes): return val if isinstance(val, str): - s = val.replace('0x', '') try: - return bytes.fromhex(s) + return bytes.fromhex(strip_hex_prefix(val)) except ValueError: return val.encode('utf-8') if isinstance(val, (list, tuple)): diff --git a/tests/test_scale.py b/tests/test_scale.py index fdfc026..a784000 100644 --- a/tests/test_scale.py +++ b/tests/test_scale.py @@ -4,7 +4,8 @@ from unittest.mock import MagicMock from allways.chain_providers.subtensor import SubtensorProvider -from allways.contract_client import AllwaysContractClient, compact_encode_len +from allways.contract_client import AllwaysContractClient +from allways.utils.scale import compact_encode_len, decode_string def make_client(): @@ -228,26 +229,25 @@ class TestDecodeString: def test_roundtrip_short(self): c = make_client() encoded = c.encode_value('hello', 'str') - s, offset = c.decode_string(encoded, 0) + s, offset = decode_string(encoded, 0) assert s == 'hello' assert offset == len(encoded) def test_roundtrip_empty(self): c = make_client() encoded = c.encode_value('', 'str') - s, offset = c.decode_string(encoded, 0) + s, offset = decode_string(encoded, 0) assert s == '' def test_roundtrip_medium(self): c = make_client() text = 'x' * 100 # Still in single-byte compact mode encoded = c.encode_value(text, 'str') - s, offset = c.decode_string(encoded, 0) + s, offset = decode_string(encoded, 0) assert s == text def test_offset_past_end(self): - c = make_client() - s, offset = c.decode_string(b'\x00', 10) + s, offset = decode_string(b'\x00', 10) assert s == '' def test_roundtrip_two_byte_compact(self): @@ -255,7 +255,7 @@ def test_roundtrip_two_byte_compact(self): # String of length 64+ triggers two-byte compact mode text = 'a' * 64 encoded = c.encode_value(text, 'str') - s, offset = c.decode_string(encoded, 0) + s, offset = decode_string(encoded, 0) assert s == text