diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index a433e48d..742d9675 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -37,6 +37,7 @@ from aleph.sdk.client.services.dns import DNS from aleph.sdk.client.services.instance import Instance from aleph.sdk.client.services.port_forwarder import PortForwarder +from aleph.sdk.client.services.pricing import Pricing from aleph.sdk.client.services.scheduler import Scheduler from ..conf import settings @@ -135,7 +136,7 @@ async def __aenter__(self): self.crn = Crn(self) self.scheduler = Scheduler(self) self.instance = Instance(self) - + self.pricing = Pricing(self) return self async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/src/aleph/sdk/client/services/crn.py b/src/aleph/sdk/client/services/crn.py index 3317644a..e8d57c8c 100644 --- a/src/aleph/sdk/client/services/crn.py +++ b/src/aleph/sdk/client/services/crn.py @@ -1,18 +1,187 @@ -from typing import TYPE_CHECKING, Dict, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union import aiohttp from aiohttp.client_exceptions import ClientResponseError from aleph_message.models import ItemHash +from pydantic import BaseModel from aleph.sdk.conf import settings from aleph.sdk.exceptions import MethodNotAvailableOnCRN, VmNotFoundOnHost -from aleph.sdk.types import CrnExecutionV1, CrnExecutionV2, CrnV1List, CrnV2List -from aleph.sdk.utils import sanitize_url +from aleph.sdk.types import ( +    CrnExecutionV1, +    CrnExecutionV2, +    CrnV1List, +    CrnV2List, +    DictLikeModel, +) +from aleph.sdk.utils import extract_valid_eth_address, sanitize_url if TYPE_CHECKING: from aleph.sdk.client.http import AlephHttpClient +class GPU(BaseModel): +    vendor: str +    model: str +    device_name: str +    device_class: str +    pci_host: str +    compatible: bool + + +class NetworkGPUS(BaseModel): +    total_gpu_count: int +    available_gpu_count: int +    available_gpu_list: dict[str, List[GPU]]  # str = node_url +    used_gpu_list: dict[str, List[GPU]]  # str = node_url + + +class CRN(DictLikeModel): +    # Works like a dict, but lets us type the fields we need and add logic on top + +    # Typed fields to simplify searching +    hash: str +    name: str +    address: str + +    gpu_support: Optional[bool] = False +    confidential_support: Optional[bool] = False +    qemu_support: Optional[bool] = False + +    version: Optional[str] = "0.0.0" +    payment_receiver_address: Optional[str] = None  # Can be None if not configured + + +class CrnList(DictLikeModel): +    crns: list[CRN] = [] + +    @classmethod +    def from_api(cls, payload: dict) -> "CrnList": +        raw_list = payload.get("crns", []) +        crn_list = [ +            CRN.model_validate(item) if not isinstance(item, CRN) else item +            for item in raw_list +        ] +        return cls(crns=crn_list) + +    def find_gpu_on_network(self) -> NetworkGPUS: +        gpu_count: int = 0 +        available_gpu_count: int = 0 + +        compatible_gpu: Dict[str, List[GPU]] = {} +        available_compatible_gpu: Dict[str, List[GPU]] = {} + +        for crn_ in self.crns: +            if not crn_.gpu_support: +                continue + +            # Extract used GPUs, grouped per node address +            for gpu in crn_.get("compatible_gpus", []): +                node_gpus = compatible_gpu.setdefault(crn_.address, []) +                node_gpus.append(GPU.model_validate(gpu)) +                gpu_count += 1 + +            # Extract available GPUs +            for gpu in crn_.get("compatible_available_gpus", []): +                node_available = available_compatible_gpu.setdefault(crn_.address, []) +                node_available.append(GPU.model_validate(gpu)) +                gpu_count += 1 +                available_gpu_count += 1 + +        return NetworkGPUS( +            total_gpu_count=gpu_count, +            available_gpu_count=available_gpu_count, +            
used_gpu_list=compatible_gpu, + available_gpu_list=available_compatible_gpu, + ) + + def filter_crn( + self, + latest_crn_version: bool = False, + ipv6: bool = False, + stream_address: bool = False, + confidential: bool = False, + gpu: bool = False, + ) -> list[CRN]: + """Filter compute resource node list, unfiltered by default. + Args: + latest_crn_version (bool): Filter by latest crn version. + ipv6 (bool): Filter invalid IPv6 configuration. + stream_address (bool): Filter invalid payment receiver address. + confidential (bool): Filter by confidential computing support. + gpu (bool): Filter by GPU support. + Returns: + list[CRN]: List of compute resource nodes. (if no filter applied, return all) + """ + # current_crn_version = await fetch_latest_crn_version() + # Relax current filter to allow use aleph-vm versions since 1.5.1. + # TODO: Allow to specify that option on settings aggregate on maybe on GitHub + current_crn_version = "1.5.1" + + filtered_crn: list[CRN] = [] + for crn_ in self.crns: + # Check crn version + if latest_crn_version and (crn_.version or "0.0.0") < current_crn_version: + continue + + # Filter with ipv6 check + if ipv6: + ipv6_check = crn_.get("ipv6_check") + if not ipv6_check or not all(ipv6_check.values()): + continue + + if stream_address and not extract_valid_eth_address( + crn_.payment_receiver_address or "" + ): + continue + + # Confidential Filter + if confidential and not crn_.confidential_support: + continue + + # Filter with GPU / Available GPU + available_gpu = crn_.get("compatible_available_gpus") + if gpu and (not crn_.gpu_support or not available_gpu): + continue + + filtered_crn.append(crn_) + return filtered_crn + + # Find CRN by address + def find_crn_by_address(self, address: str) -> Optional[CRN]: + for crn_ in self.crns: + if crn_.address == sanitize_url(address): + return crn_ + return None + + # Find CRN by hash + def find_crn_by_hash(self, crn_hash: str) -> Optional[CRN]: + for crn_ in self.crns: + if crn_.hash == crn_hash: + return crn_ + return None + + def find_crn( + self, + address: Optional[str] = None, + crn_hash: Optional[str] = None, + ) -> Optional[CRN]: + """Find CRN by address or hash (both optional, address priority) + + Args: + address (Optional[str], optional): url of the node. Defaults to None. + crn_hash (Optional[str], optional): hash of the nodes. Defaults to None. 
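A quick usage sketch for the new `CrnList` filtering and lookup helpers (illustrative only, not part of the changeset; it assumes the default API server from `aleph.sdk.conf.settings`, and the node URL and hash values are hypothetical):

```python
import asyncio

from aleph.sdk.client.http import AlephHttpClient


async def pick_confidential_gpu_crns():
    # Assumes the default API server configured in aleph.sdk.conf.settings
    async with AlephHttpClient() as client:
        crn_list = await client.crn.get_crns_list()  # now returns a CrnList

        # Keep only up-to-date nodes with a valid payment receiver address
        # that advertise confidential computing and at least one free GPU.
        candidates = crn_list.filter_crn(
            latest_crn_version=True,
            stream_address=True,
            confidential=True,
            gpu=True,
        )
        for crn in candidates:
            print(crn.name, crn.address, crn.payment_receiver_address)

        # Direct lookups by URL or node hash (hypothetical values)
        by_url = crn_list.find_crn(address="https://gpu-crn.example.org")
        by_hash = crn_list.find_crn(crn_hash="abc123...")
        return candidates, by_url, by_hash


if __name__ == "__main__":
    asyncio.run(pick_confidential_gpu_crns())
```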
+ + Returns: + Optional[CRN]: CRN object or None if not found + """ + if address: + return self.find_crn_by_address(address) + if crn_hash: + return self.find_crn_by_hash(crn_hash) + return None + + class Crn: """ This services allow interact with CRNS API @@ -45,7 +214,7 @@ async def get_last_crn_version(self): data = await resp.json() return data.get("tag_name") - async def get_crns_list(self, only_active: bool = True) -> dict: + async def get_crns_list(self, only_active: bool = True) -> CrnList: """ Query a persistent VM running on aleph.im to retrieve list of CRNs: https://crns-list.aleph.sh/crns.json @@ -72,7 +241,7 @@ async def get_crns_list(self, only_active: bool = True) -> dict: sanitize_url(settings.CRN_LIST_URL), params=params ) as resp: resp.raise_for_status() - return await resp.json() + return CrnList.from_api(await resp.json()) async def get_active_vms_v2(self, crn_address: str) -> CrnV2List: endpoint = "/v2/about/executions/list" @@ -136,3 +305,11 @@ async def update_instance_config(self, crn_address: str, item_hash: ItemHash): async with session.post(full_url) as resp: resp.raise_for_status() return await resp.json() + + # Gpu Functions Helper + async def fetch_gpu_on_network( + self, + only_active: bool = True, + ) -> NetworkGPUS: + crn_list = await self.get_crns_list(only_active) + return crn_list.find_gpu_on_network() diff --git a/src/aleph/sdk/client/services/pricing.py b/src/aleph/sdk/client/services/pricing.py new file mode 100644 index 00000000..acf7c214 --- /dev/null +++ b/src/aleph/sdk/client/services/pricing.py @@ -0,0 +1,235 @@ +import logging +import math +from enum import Enum +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +from aleph.sdk.client.services.base import BaseService + +if TYPE_CHECKING: + pass + +from decimal import Decimal + +from pydantic import BaseModel, RootModel + +logger = logging.getLogger(__name__) + + +class PricingEntity(str, Enum): + STORAGE = "storage" + WEB3_HOSTING = "web3_hosting" + PROGRAM = "program" + PROGRAM_PERSISTENT = "program_persistent" + INSTANCE = "instance" + INSTANCE_CONFIDENTIAL = "instance_confidential" + INSTANCE_GPU_STANDARD = "instance_gpu_standard" + INSTANCE_GPU_PREMIUM = "instance_gpu_premium" + + +class GroupEntity(str, Enum): + STORAGE = "storage" + WEBSITE = "website" + PROGRAM = "program" + INSTANCE = "instance" + CONFIDENTIAL = "confidential" + GPU = "gpu" + ALL = "all" + + +class Price(BaseModel): + payg: Optional[Decimal] = None + holding: Optional[Decimal] = None + fixed: Optional[Decimal] = None + + +class ComputeUnit(BaseModel): + vcpus: int + memory_mib: int + disk_mib: int + + +class TierComputedSpec(ComputeUnit): + ... 
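Similarly, a minimal sketch of the new `fetch_gpu_on_network()` helper, which wraps `get_crns_list()` and `CrnList.find_gpu_on_network()` (illustrative; assumes the default API server):

```python
import asyncio

from aleph.sdk.client.http import AlephHttpClient


async def summarize_network_gpus():
    async with AlephHttpClient() as client:
        gpus = await client.crn.fetch_gpu_on_network(only_active=True)

        print(f"GPUs on the network: {gpus.total_gpu_count}")
        print(f"GPUs currently available: {gpus.available_gpu_count}")
        for node_url, node_gpus in gpus.available_gpu_list.items():
            models = ", ".join(gpu.model for gpu in node_gpus)
            print(f"  {node_url}: {models}")


if __name__ == "__main__":
    asyncio.run(summarize_network_gpus())
```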
+ gpu_model: Optional[str] + vram: Optional[int] + + +class Tier(BaseModel): + id: str + compute_units: int + vram: Optional[int] = None + model: Optional[str] = None + + def extract_tier_id(self) -> str: + return self.id.split("-", 1)[-1] + + +class PricingPerEntity(BaseModel): + price: Dict[str, Union[Price, Decimal]] + compute_unit: Optional[ComputeUnit] = None + tiers: Optional[List[Tier]] = None + + def _get_nb_compute_units( + self, + vcpus: int = 1, + memory_mib: int = 2048, + ) -> Optional[int]: + if self.compute_unit: + memory = math.ceil(memory_mib / self.compute_unit.memory_mib) + nb_compute = vcpus if vcpus >= memory else memory + return nb_compute + return None + + def get_closest_tier( + self, + vcpus: Optional[int] = None, + memory_mib: Optional[int] = None, + compute_unit: Optional[int] = None, + ): + """Get Closest tier for Program / Instance""" + + # We Calculate Compute Unit requested based on vcpus and memory + computed_cu = None + if vcpus is not None and memory_mib is not None: + computed_cu = self._get_nb_compute_units(vcpus=vcpus, memory_mib=memory_mib) + elif vcpus is not None and self.compute_unit is not None: + computed_cu = self._get_nb_compute_units( + vcpus=vcpus, memory_mib=self.compute_unit.memory_mib + ) + elif memory_mib is not None and self.compute_unit is not None: + computed_cu = self._get_nb_compute_units( + vcpus=self.compute_unit.vcpus, memory_mib=memory_mib + ) + + # Case where Vcpus or memory is given but also a number of CU (case on aleph-client) + cu: Optional[int] = None + if computed_cu is not None and compute_unit is not None: + if computed_cu != compute_unit: + logger.warning( + f"Mismatch in compute units: from CPU/RAM={computed_cu}, given={compute_unit}. " + f"Choosing {max(computed_cu, compute_unit)}." 
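A worked example of the compute-unit rounding and closest-tier selection implemented here, built from the instance pricing values used in this PR's test fixture (1 vCPU / 2048 MiB RAM / 20480 MiB disk per compute unit); the concrete numbers are illustrative only:

```python
from decimal import Decimal

from aleph.sdk.client.services.pricing import (
    ComputeUnit,
    Price,
    PricingPerEntity,
    Tier,
)

# Instance pricing as defined in the aggregate fixture: tiers of 1, 2, 4, 6, 8
# and 12 compute units, each worth 1 vCPU / 2048 MiB RAM / 20480 MiB disk.
instance_pricing = PricingPerEntity(
    price={
        "storage": Price(payg=Decimal("0.000000977"), holding=Decimal("0.05")),
        "compute_unit": Price(payg=Decimal("0.055"), holding=Decimal("1000")),
    },
    compute_unit=ComputeUnit(vcpus=1, memory_mib=2048, disk_mib=20480),
    tiers=[
        Tier(id=f"tier-{i}", compute_units=cu)
        for i, cu in enumerate((1, 2, 4, 6, 8, 12), start=1)
    ],
)

# 4 vCPUs / 8 GiB RAM -> ceil(8192 / 2048) = 4 memory units, max(4, 4) = 4
# compute units, which lands on "tier-3" (4 compute units).
tier = instance_pricing.get_closest_tier(vcpus=4, memory_mib=8192)
assert tier is not None and tier.id == "tier-3"

# The tier expands back into concrete specs: 4 vCPUs, 8192 MiB RAM, 81920 MiB disk.
specs = instance_pricing.get_services_specs(tier)
print(specs.vcpus, specs.memory_mib, specs.disk_mib)
```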
+ ) + cu = max(computed_cu, compute_unit) # We trust the bigger trier + else: + cu = compute_unit if compute_unit is not None else computed_cu + + # now tier found + if cu is None: + return None + + # With CU available, choose the closest one + candidates = self.tiers + if candidates is None: + return None + + best_tier = min( + candidates, + key=lambda t: (abs(t.compute_units - cu), -t.compute_units), + ) + return best_tier + + def get_services_specs( + self, + tier: Tier, + ) -> TierComputedSpec: + """ + Calculate ammount of vram / cpu / disk | + gpu model / vram if it GPU instance + """ + if self.compute_unit is None: + raise ValueError("ComputeUnit is required to get service specs") + + cpu = tier.compute_units * self.compute_unit.vcpus + memory_mib = tier.compute_units * self.compute_unit.memory_mib + disk = ( + tier.compute_units * self.compute_unit.disk_mib + ) # Min value disk can be increased + + # Gpu Specs + gpu = None + vram = None + if tier.model and tier.vram: + gpu = tier.model + vram = tier.vram + + return TierComputedSpec( + vcpus=cpu, + memory_mib=memory_mib, + disk_mib=disk, + gpu_model=gpu, + vram=vram, + ) + + +class PricingModel(RootModel[Dict[PricingEntity, PricingPerEntity]]): + def __iter__(self): + return iter(self.root) + + def __getitem__(self, item): + return self.root[item] + + +PRICING_GROUPS: dict[str, list[PricingEntity]] = { + GroupEntity.STORAGE: [PricingEntity.STORAGE], + GroupEntity.WEBSITE: [PricingEntity.WEB3_HOSTING], + GroupEntity.PROGRAM: [PricingEntity.PROGRAM, PricingEntity.PROGRAM_PERSISTENT], + GroupEntity.INSTANCE: [PricingEntity.INSTANCE], + GroupEntity.CONFIDENTIAL: [PricingEntity.INSTANCE_CONFIDENTIAL], + GroupEntity.GPU: [ + PricingEntity.INSTANCE_GPU_STANDARD, + PricingEntity.INSTANCE_GPU_PREMIUM, + ], + GroupEntity.ALL: list(PricingEntity), +} + +PAYG_GROUP: list[PricingEntity] = [ + PricingEntity.INSTANCE, + PricingEntity.INSTANCE_CONFIDENTIAL, + PricingEntity.INSTANCE_GPU_STANDARD, + PricingEntity.INSTANCE_GPU_PREMIUM, +] + + +class Pricing(BaseService[PricingModel]): + """ + This Service handle logic around Pricing + """ + + aggregate_key = "pricing" + model_cls = PricingModel + + def __init__(self, client): + super().__init__(client=client) + + # Config from aggregate + async def get_pricing_aggregate( + self, + ) -> PricingModel: + result = await self.get_config( + address="0xFba561a84A537fCaa567bb7A2257e7142701ae2A" + ) + return result.data[0] + + async def get_pricing_for_services( + self, services: List[PricingEntity], pricing_info: Optional[PricingModel] = None + ) -> Dict[PricingEntity, PricingPerEntity]: + """ + Get pricing information for requested services + + Args: + services: List of pricing entities to get information for + pricing_info: Optional pre-fetched pricing aggregate + + Returns: + Dictionary with pricing information for requested services + """ + if ( + not pricing_info + ): # Avoid reloading aggregate info if there is already fetched + pricing_info = await self.get_pricing_aggregate() + + result = {} + for service in services: + if service in pricing_info: + result[service] = pricing_info[service] + + return result diff --git a/src/aleph/sdk/client/services/scheduler.py b/src/aleph/sdk/client/services/scheduler.py index 765ee2bd..0282847a 100644 --- a/src/aleph/sdk/client/services/scheduler.py +++ b/src/aleph/sdk/client/services/scheduler.py @@ -27,7 +27,6 @@ async def get_plan(self) -> SchedulerPlan: async with session.get(url) as resp: resp.raise_for_status() raw = await resp.json() - return 
SchedulerPlan.model_validate(raw) async def get_nodes(self) -> SchedulerNodes: diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index 6c1ae561..3ef8e2c7 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -1,10 +1,27 @@ from abc import abstractmethod from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Protocol, TypeVar, Union +from typing import ( + Any, + Dict, + Iterator, + List, + Literal, + Optional, + Protocol, + TypeVar, + Union, +) from aleph_message.models import ItemHash -from pydantic import BaseModel, Field, RootModel, TypeAdapter, field_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + RootModel, + TypeAdapter, + field_validator, +) __all__ = ("StorageEnum", "Account", "AccountFromPrivateKey", "GenericMessage") @@ -289,3 +306,36 @@ class Ports(BaseModel): AllForwarders = RootModel[Dict[ItemHash, Ports]] + + +class DictLikeModel(BaseModel): + """ + Base class: behaves like a dict while still being a Pydantic model. + """ + + # allow extra fields + validate on assignment + model_config = ConfigDict(extra="allow", validate_assignment=True) + + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + def __setitem__(self, key: str, value: Any) -> None: + setattr(self, key, value) + + def __iter__(self) -> Iterator[str]: + return iter(self.model_dump().keys()) + + def __contains__(self, key: str) -> bool: + return hasattr(self, key) + + def keys(self): + return self.model_dump().keys() + + def values(self): + return self.model_dump().values() + + def items(self): + return self.model_dump().items() + + def get(self, key: str, default=None): + return getattr(self, key, default) diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index 19a3aa57..94bc3bb9 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -6,6 +6,7 @@ import json import logging import os +import re import subprocess from datetime import date, datetime, time from decimal import Context, Decimal, InvalidOperation @@ -613,3 +614,12 @@ def sanitize_url(url: str) -> str: url = f"https://{url}" return url + + +def extract_valid_eth_address(address: str) -> str: + if address: + pattern = r"0x[a-fA-F0-9]{40}" + match = re.search(pattern, address) + if match: + return match.group(0) + return "" diff --git a/tests/unit/services/pricing_aggregate.json b/tests/unit/services/pricing_aggregate.json new file mode 100644 index 00000000..2da0dbb8 --- /dev/null +++ b/tests/unit/services/pricing_aggregate.json @@ -0,0 +1,273 @@ +{ + "address": "0xFba561a84A537fCaa567bb7A2257e7142701ae2A", + "data": { + "pricing": { + "program": { + "price": { + "storage": { + "payg": "0.000000977", + "holding": "0.05" + }, + "compute_unit": { + "payg": "0.011", + "holding": "200" + } + }, + "tiers": [ + { + "id": "tier-1", + "compute_units": 1 + }, + { + "id": "tier-2", + "compute_units": 2 + }, + { + "id": "tier-3", + "compute_units": 4 + }, + { + "id": "tier-4", + "compute_units": 6 + }, + { + "id": "tier-5", + "compute_units": 8 + }, + { + "id": "tier-6", + "compute_units": 12 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 2048, + "memory_mib": 2048 + } + }, + "storage": { + "price": { + "storage": { + "holding": "0.333333333" + } + } + }, + "instance": { + "price": { + "storage": { + "payg": "0.000000977", + "holding": "0.05" + }, + "compute_unit": { + "payg": "0.055", + "holding": "1000" + } + }, + "tiers": [ + { + "id": "tier-1", + "compute_units": 1 + }, + { + "id": "tier-2", + "compute_units": 
2 + }, + { + "id": "tier-3", + "compute_units": 4 + }, + { + "id": "tier-4", + "compute_units": 6 + }, + { + "id": "tier-5", + "compute_units": 8 + }, + { + "id": "tier-6", + "compute_units": 12 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 20480, + "memory_mib": 2048 + } + }, + "web3_hosting": { + "price": { + "fixed": 50, + "storage": { + "holding": "0.333333333" + } + } + }, + "program_persistent": { + "price": { + "storage": { + "payg": "0.000000977", + "holding": "0.05" + }, + "compute_unit": { + "payg": "0.055", + "holding": "1000" + } + }, + "tiers": [ + { + "id": "tier-1", + "compute_units": 1 + }, + { + "id": "tier-2", + "compute_units": 2 + }, + { + "id": "tier-3", + "compute_units": 4 + }, + { + "id": "tier-4", + "compute_units": 6 + }, + { + "id": "tier-5", + "compute_units": 8 + }, + { + "id": "tier-6", + "compute_units": 12 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 20480, + "memory_mib": 2048 + } + }, + "instance_gpu_premium": { + "price": { + "storage": { + "payg": "0.000000977" + }, + "compute_unit": { + "payg": "0.56" + } + }, + "tiers": [ + { + "id": "tier-1", + "vram": 81920, + "model": "A100", + "compute_units": 16 + }, + { + "id": "tier-2", + "vram": 81920, + "model": "H100", + "compute_units": 24 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 61440, + "memory_mib": 6144 + } + }, + "instance_confidential": { + "price": { + "storage": { + "payg": "0.000000977", + "holding": "0.05" + }, + "compute_unit": { + "payg": "0.11", + "holding": "2000" + } + }, + "tiers": [ + { + "id": "tier-1", + "compute_units": 1 + }, + { + "id": "tier-2", + "compute_units": 2 + }, + { + "id": "tier-3", + "compute_units": 4 + }, + { + "id": "tier-4", + "compute_units": 6 + }, + { + "id": "tier-5", + "compute_units": 8 + }, + { + "id": "tier-6", + "compute_units": 12 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 20480, + "memory_mib": 2048 + } + }, + "instance_gpu_standard": { + "price": { + "storage": { + "payg": "0.000000977" + }, + "compute_unit": { + "payg": "0.28" + } + }, + "tiers": [ + { + "id": "tier-1", + "vram": 20480, + "model": "RTX 4000 ADA", + "compute_units": 3 + }, + { + "id": "tier-2", + "vram": 24576, + "model": "RTX 3090", + "compute_units": 4 + }, + { + "id": "tier-3", + "vram": 24576, + "model": "RTX 4090", + "compute_units": 6 + }, + { + "id": "tier-3", + "vram": 32768, + "model": "RTX 5090", + "compute_units": 8 + }, + { + "id": "tier-4", + "vram": 49152, + "model": "L40S", + "compute_units": 12 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 61440, + "memory_mib": 6144 + } + } + } + }, + "info": { + + } +} \ No newline at end of file diff --git a/tests/unit/services/test_pricing.py b/tests/unit/services/test_pricing.py new file mode 100644 index 00000000..ab6f7981 --- /dev/null +++ b/tests/unit/services/test_pricing.py @@ -0,0 +1,212 @@ +import json +from decimal import Decimal +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from aleph.sdk.client.http import AlephHttpClient +from aleph.sdk.client.services.pricing import ( + PAYG_GROUP, + PRICING_GROUPS, + GroupEntity, + Price, + Pricing, + PricingEntity, + PricingModel, + PricingPerEntity, +) + + +@pytest.fixture +def pricing_aggregate(): + """Load the pricing aggregate JSON file for testing.""" + json_path = Path(__file__).parent / "pricing_aggregate.json" + with open(json_path, "r") as f: + data = json.load(f) + return data + + +@pytest.fixture +def mock_client(pricing_aggregate): + """Create a real client with mocked 
HTTP responses.""" + # Create a mock response for the http session get method + mock_response = AsyncMock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = pricing_aggregate + + # Create an async context manager for the mock response + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_response + + # Create a mock HTTP session + mock_session = AsyncMock() + mock_session.get = MagicMock(return_value=mock_context) + + client = AlephHttpClient(api_server="http://localhost") + client._http_session = mock_session + + return client + + +@pytest.mark.asyncio +async def test_get_pricing_aggregate(mock_client): + """Test fetching the pricing aggregate data.""" + pricing_service = Pricing(mock_client) + result = await pricing_service.get_pricing_aggregate() + + # Check the result is a PricingModel + assert isinstance(result, PricingModel) + + assert PricingEntity.STORAGE in result + assert PricingEntity.PROGRAM in result + assert PricingEntity.INSTANCE in result + + storage_entity = result[PricingEntity.STORAGE] + assert isinstance(storage_entity, PricingPerEntity) + assert "storage" in storage_entity.price + storage_price = storage_entity.price["storage"] + assert isinstance(storage_price, Price) # Add type assertion for mypy + assert storage_price.holding == Decimal("0.333333333") + + # Check program entity has correct compute unit details + program_entity = result[PricingEntity.PROGRAM] + assert isinstance(program_entity, PricingPerEntity) + assert program_entity.compute_unit is not None # Ensure compute_unit is not None + assert program_entity.compute_unit.vcpus == 1 + assert program_entity.compute_unit.memory_mib == 2048 + assert program_entity.compute_unit.disk_mib == 2048 + + # Check tiers in instance entity + instance_entity = result[PricingEntity.INSTANCE] + assert instance_entity.tiers is not None # Ensure tiers is not None + assert len(instance_entity.tiers) == 6 + assert instance_entity.tiers[0].id == "tier-1" + assert instance_entity.tiers[0].compute_units == 1 + + +@pytest.mark.asyncio +async def test_get_pricing_for_services(mock_client): + """Test fetching pricing for specific services.""" + pricing_service = Pricing(mock_client) + + # Test Case 1: Get pricing for storage and program services + services = [PricingEntity.STORAGE, PricingEntity.PROGRAM] + result = await pricing_service.get_pricing_for_services(services) + + # Check the result contains only the requested entities + assert len(result) == 2 + assert PricingEntity.STORAGE in result + assert PricingEntity.PROGRAM in result + assert PricingEntity.INSTANCE not in result + + # Verify specific pricing data + storage_price = result[PricingEntity.STORAGE].price["storage"] + assert isinstance(storage_price, Price) # Ensure it's a Price object + assert storage_price.holding == Decimal("0.333333333") + + compute_price = result[PricingEntity.PROGRAM].price["compute_unit"] + assert isinstance(compute_price, Price) # Ensure it's a Price object + assert compute_price.payg == Decimal("0.011") + assert compute_price.holding == Decimal("200") + + # Test Case 2: Using pre-fetched pricing aggregate + pricing_info = await pricing_service.get_pricing_aggregate() + result2 = await pricing_service.get_pricing_for_services(services, pricing_info) + + # Results should be the same + assert result[PricingEntity.STORAGE].price == result2[PricingEntity.STORAGE].price + assert result[PricingEntity.PROGRAM].price == result2[PricingEntity.PROGRAM].price + + # Test Case 3: Empty services list 
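For reference, an end-to-end sketch of how the `Pricing` service exercised by these tests can be used against the published aggregate (assumes the default API server; the GPU group and the 4-compute-unit request are illustrative choices):

```python
import asyncio

from aleph.sdk.client.http import AlephHttpClient
from aleph.sdk.client.services.pricing import (
    PRICING_GROUPS,
    GroupEntity,
    PricingEntity,
)


async def show_gpu_pricing():
    async with AlephHttpClient() as client:
        pricing = await client.pricing.get_pricing_aggregate()

        # Restrict the aggregate to the GPU-related entities
        gpu_pricing = await client.pricing.get_pricing_for_services(
            PRICING_GROUPS[GroupEntity.GPU], pricing_info=pricing
        )

        standard = gpu_pricing[PricingEntity.INSTANCE_GPU_STANDARD]
        tier = standard.get_closest_tier(compute_unit=4)
        if tier is not None:
            specs = standard.get_services_specs(tier)
            print(tier.id, specs.vcpus, specs.memory_mib, specs.gpu_model, specs.vram)


if __name__ == "__main__":
    asyncio.run(show_gpu_pricing())
```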
+ empty_result = await pricing_service.get_pricing_for_services([]) + assert isinstance(empty_result, dict) + assert len(empty_result) == 0 + + # Test Case 4: Web3 hosting service + web3_result = await pricing_service.get_pricing_for_services( + [PricingEntity.WEB3_HOSTING] + ) + assert len(web3_result) == 1 + assert PricingEntity.WEB3_HOSTING in web3_result + assert web3_result[PricingEntity.WEB3_HOSTING].price["fixed"] == Decimal("50") + + # Test Case 5: GPU services have specific properties + gpu_services = [ + PricingEntity.INSTANCE_GPU_STANDARD, + PricingEntity.INSTANCE_GPU_PREMIUM, + ] + gpu_result = await pricing_service.get_pricing_for_services(gpu_services) + assert len(gpu_result) == 2 + # Check GPU models are present + standard_tiers = gpu_result[PricingEntity.INSTANCE_GPU_STANDARD].tiers + premium_tiers = gpu_result[PricingEntity.INSTANCE_GPU_PREMIUM].tiers + assert standard_tiers is not None + assert premium_tiers is not None + assert standard_tiers[0].model == "RTX 4000 ADA" + assert premium_tiers[1].model == "H100" + + +@pytest.mark.asyncio +async def test_get_pricing_for_gpu_services(mock_client): + """Test fetching pricing for GPU services.""" + pricing_service = Pricing(mock_client) + + # Test with GPU services + gpu_services = [ + PricingEntity.INSTANCE_GPU_STANDARD, + PricingEntity.INSTANCE_GPU_PREMIUM, + ] + result = await pricing_service.get_pricing_for_services(gpu_services) + + # Check that both GPU services are returned + assert len(result) == 2 + assert PricingEntity.INSTANCE_GPU_STANDARD in result + assert PricingEntity.INSTANCE_GPU_PREMIUM in result + + # Verify GPU standard pricing and details + gpu_standard = result[PricingEntity.INSTANCE_GPU_STANDARD] + compute_unit_price = gpu_standard.price["compute_unit"] + assert isinstance(compute_unit_price, Price) + assert compute_unit_price.payg == Decimal("0.28") + + standard_tiers = gpu_standard.tiers + assert standard_tiers is not None + assert len(standard_tiers) == 5 + assert standard_tiers[0].model == "RTX 4000 ADA" + assert standard_tiers[0].vram == 20480 + + # Verify GPU premium pricing and details + gpu_premium = result[PricingEntity.INSTANCE_GPU_PREMIUM] + premium_compute_price = gpu_premium.price["compute_unit"] + assert isinstance(premium_compute_price, Price) + assert premium_compute_price.payg == Decimal("0.56") + + premium_tiers = gpu_premium.tiers + assert premium_tiers is not None + assert len(premium_tiers) == 2 + assert premium_tiers[1].model == "H100" + assert premium_tiers[1].vram == 81920 + + +@pytest.mark.asyncio +async def test_pricing_groups(): + """Test the pricing groups constants.""" + # Check that all pricing entities are covered in PRICING_GROUPS + all_entities = set() + for group_entities in PRICING_GROUPS.values(): + for entity in group_entities: + all_entities.add(entity) + + # All PricingEntity values should be in some group + for entity in PricingEntity: + assert entity in all_entities + + # Check ALL group contains all entities + assert set(PRICING_GROUPS[GroupEntity.ALL]) == set(PricingEntity) + + # Check PAYG_GROUP contains expected entities + assert PricingEntity.INSTANCE in PAYG_GROUP + assert PricingEntity.INSTANCE_CONFIDENTIAL in PAYG_GROUP + assert PricingEntity.INSTANCE_GPU_STANDARD in PAYG_GROUP + assert PricingEntity.INSTANCE_GPU_PREMIUM in PAYG_GROUP
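Finally, a small sketch of the two supporting pieces added in this diff, `DictLikeModel` and `extract_valid_eth_address()`; the `Node` subclass and the addresses below are hypothetical:

```python
from aleph.sdk.types import DictLikeModel
from aleph.sdk.utils import extract_valid_eth_address


class Node(DictLikeModel):
    # Typed field; extra keys from an API payload are kept thanks to extra="allow"
    name: str


node = Node.model_validate({"name": "crn-1", "score": 0.93})
assert node["name"] == node.name == "crn-1"  # dict-style and attribute access
assert node.get("score") == 0.93  # extra field preserved
assert "score" in node and "missing" not in node

# extract_valid_eth_address() pulls the first well-formed 0x address out of a
# free-form string and returns "" when none is found (hypothetical address below).
assert (
    extract_valid_eth_address(
        "reward address: 0x0000000000000000000000000000000000000001"
    )
    == "0x0000000000000000000000000000000000000001"
)
assert extract_valid_eth_address("not an address") == ""
```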