Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions heidi_engine/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,27 +473,39 @@ def load_pricing_config() -> Dict[str, Dict[str, float]]:
- Falls back to DEFAULT_PRICING
- Allows user to customize pricing per model

BOLT OPTIMIZATION:
Uses a thread-safe module-level cache (5.0s TTL) to avoid
redundant disk I/O and JSON parsing on high-frequency calls.

TUNABLE:
- Create pricing.json to override default prices
- Format: {"model_name": {"input": 0.5, "output": 1.5}}
- Prices are per 1M tokens
"""
pricing = DEFAULT_PRICING.copy()
global _pricing_cache, _pricing_check_ts
with _pricing_lock:
now = time.monotonic()
if _pricing_cache and (now - _pricing_check_ts) < 5.0:
return _pricing_cache.copy()

# Check for pricing config file
pricing_file = (
Path(PRICING_CONFIG_PATH) if PRICING_CONFIG_PATH else get_run_dir() / "pricing.json"
)
pricing = DEFAULT_PRICING.copy()

if pricing_file.exists():
try:
with open(pricing_file) as f:
custom = json.load(f)
pricing.update(custom)
except Exception as e:
print(f"[WARN] Failed to load pricing config: {e}", file=sys.stderr)
# Check for pricing config file
pricing_file = (
Path(PRICING_CONFIG_PATH) if PRICING_CONFIG_PATH else get_run_dir() / "pricing.json"
)

if pricing_file.exists():
try:
with open(pricing_file) as f:
custom = json.load(f)
pricing.update(custom)
except Exception as e:
print(f"[WARN] Failed to load pricing config: {e}", file=sys.stderr)

return pricing
_pricing_cache = pricing
_pricing_check_ts = now
return pricing.copy()
Comment on lines +485 to +508
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of load_pricing_config uses shallow copies (.copy()) for a nested dictionary structure (Dict[str, Dict[str, float]]). This is problematic because the inner dictionaries (e.g., {"input": 0.15, "output": 0.60}) remain shared between the cache, the DEFAULT_PRICING constant, and any caller that receives the returned value. If a caller or a subsequent update modifies an inner dictionary, it will inadvertently corrupt the global state and the cache.

Additionally, the cache is module-level while the pricing file path can be run-specific (via get_run_dir()), so a result cached for one run's path may briefly be served for another; the 5.0s TTL bounds that staleness but does not eliminate it — deepcopy does not address this. What copy.deepcopy does fix is the aliasing problem: it fully isolates the cached structure from DEFAULT_PRICING and from every caller's returned value.

    global _pricing_cache, _pricing_check_ts
    with _pricing_lock:
        now = time.monotonic()
        if _pricing_cache and (now - _pricing_check_ts) < 5.0:
            return copy.deepcopy(_pricing_cache)

        # Use deepcopy to ensure isolation from the module-level constant
        pricing = copy.deepcopy(DEFAULT_PRICING)

        # Check for pricing config file
        pricing_file = (
            Path(PRICING_CONFIG_PATH) if PRICING_CONFIG_PATH else get_run_dir() / "pricing.json"
        )

        if pricing_file.exists():
            try:
                with open(pricing_file) as f:
                    custom = json.load(f)
                    pricing.update(custom)
            except Exception as e:
                print(f"[WARN] Failed to load pricing config: {e}", file=sys.stderr)

        _pricing_cache = pricing
        _pricing_check_ts = now
        return copy.deepcopy(pricing)



def estimate_cost(input_tokens: int, output_tokens: int, model: str) -> float:
Expand Down Expand Up @@ -732,11 +744,6 @@ def get_state(run_id: Optional[str] = None) -> Dict[str, Any]:
"usage": get_default_usage(),
}

# BOLT OPTIMIZATION: Check thread-safe state cache
cached = _state_cache.get(target_run_id, state_file)
if cached:
return cached

try:
with open(state_file) as f:
state = json.load(f)
Expand Down Expand Up @@ -1372,6 +1379,10 @@ def stage_context(stage: str, round_num: int, message: str, **kwargs):
_gpu_check_ts = 0.0
_gpu_lock = threading.Lock()

_pricing_cache: Dict[str, Dict[str, float]] = {}
_pricing_check_ts = 0.0
_pricing_lock = threading.Lock()

_event_ts_cache: Dict[str, str] = {}
_event_ts_check_ts: Dict[str, float] = {} # run_id -> ts
_event_lock = threading.Lock()
Expand Down