Skip to content

Commit

Permalink
feature: Make CurrencyCollector generic over Currency (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
antonagestam authored Jun 20, 2023
1 parent f60b7d7 commit 17ee450
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 24 deletions.
4 changes: 2 additions & 2 deletions generate-currencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .registry import CurrencyCollector
from .registry import CurrencyRegistry
__currencies: Final = CurrencyCollector()
__currencies: Final = CurrencyCollector[Currency]()
"""
currency_template = """
Expand All @@ -30,7 +30,7 @@ class {code}Type(Currency):
registry_template = """\
registry: Final[CurrencyRegistry] = __currencies.finalize()
registry: Final[CurrencyRegistry[Currency]] = __currencies.finalize()
del __currencies
"""

Expand Down
2 changes: 1 addition & 1 deletion src/immoney/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def overdraft(
return Overdraft(Money(value, self))

@classmethod
def get_default_registry(cls) -> CurrencyRegistry:
def get_default_registry(cls) -> CurrencyRegistry[Currency]:
from .currencies import registry

return registry
Expand Down
30 changes: 19 additions & 11 deletions src/immoney/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from .currencies import registry as default_registry
from .registry import CurrencyRegistry

C = TypeVar("C", bound=Currency)


class MoneyDict(TypedDict):
subunits: int
Expand Down Expand Up @@ -82,7 +84,9 @@ def schema(currency_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:

@staticmethod
@abc.abstractmethod
def validator_from_registry(registry: CurrencyRegistry) -> GeneralValidatorFunction:
def validator_from_registry(
registry: CurrencyRegistry[Currency],
) -> GeneralValidatorFunction:
...

@staticmethod
Expand Down Expand Up @@ -118,11 +122,13 @@ def schema(currency_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
)

@staticmethod
def validator_from_registry(registry: CurrencyRegistry) -> GeneralValidatorFunction:
def validator_from_registry(
registry: CurrencyRegistry[C],
) -> GeneralValidatorFunction:
def validate_money(
value: MoneyDict | Money[Currency],
*args: object,
_registry: CurrencyRegistry = registry,
_registry: CurrencyRegistry[C] = registry,
) -> Money[Currency]:
if isinstance(value, Money):
if value.currency.code not in _registry:
Expand Down Expand Up @@ -192,11 +198,13 @@ def schema(currency_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
)

@staticmethod
def validator_from_registry(registry: CurrencyRegistry) -> GeneralValidatorFunction:
def validator_from_registry(
registry: CurrencyRegistry[C],
) -> GeneralValidatorFunction:
def validate_subunit_fraction(
value: SubunitFractionDict | SubunitFraction[Currency],
*args: object,
_registry: CurrencyRegistry = registry,
_registry: CurrencyRegistry[C] = registry,
) -> SubunitFraction[Currency]:
if isinstance(value, SubunitFraction):
if value.currency.code not in _registry:
Expand Down Expand Up @@ -258,11 +266,13 @@ def schema(currency_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
)

@staticmethod
def validator_from_registry(registry: CurrencyRegistry) -> GeneralValidatorFunction:
def validator_from_registry(
registry: CurrencyRegistry[C],
) -> GeneralValidatorFunction:
def validate_overdraft(
value: OverdraftDict | Overdraft[Currency],
*args: object,
_registry: CurrencyRegistry = registry,
_registry: CurrencyRegistry[C] = registry,
) -> Overdraft[Currency]:
if isinstance(value, Overdraft):
if value.money.currency.code not in _registry:
Expand Down Expand Up @@ -336,9 +346,7 @@ def build_generic_currency_schema(
)


def build_currency_schema(
cls: type[Currency],
) -> core_schema.CoreSchema:
def build_currency_schema(cls: type[C]) -> core_schema.CoreSchema:
if abc.ABC not in cls.__bases__:
raise NotImplementedError(
"Using concrete Currency types as Pydantic fields is not yet supported."
Expand All @@ -349,7 +357,7 @@ def build_currency_schema(
def validate_currency(
value: str,
*args: object,
registry: CurrencyRegistry = cls_registry,
registry: CurrencyRegistry[Currency] = cls_registry,
) -> Currency:
if isinstance(value, str):
return registry[value]
Expand Down
4 changes: 2 additions & 2 deletions src/immoney/currencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .registry import CurrencyCollector
from .registry import CurrencyRegistry

__currencies: Final = CurrencyCollector()
__currencies: Final = CurrencyCollector[Currency]()


@final
Expand Down Expand Up @@ -3090,5 +3090,5 @@ class ZWNType(Currency):
__currencies.add(ZWN)


registry: Final[CurrencyRegistry] = __currencies.finalize()
registry: Final[CurrencyRegistry[Currency]] = __currencies.finalize()
del __currencies
13 changes: 8 additions & 5 deletions src/immoney/registry.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from collections.abc import Mapping
from typing import Generic
from typing import TypeAlias
from typing import TypeVar

from immutables import Map

from ._base import Currency

CurrencyRegistry: TypeAlias = Mapping[str, Currency]
C = TypeVar("C", bound=Currency)
CurrencyRegistry: TypeAlias = Mapping[str, C]


class CurrencyCollector:
class CurrencyCollector(Generic[C]):
__slots__ = ("__collection",)

def __init__(self) -> None:
self.__collection = list[tuple[str, Currency]]()
self.__collection = list[tuple[str, C]]()

def add(self, currency: Currency) -> None:
def add(self, currency: C) -> None:
self.__collection.append((currency.code, currency))

def finalize(self) -> CurrencyRegistry:
def finalize(self) -> CurrencyRegistry[C]:
return Map(self.__collection)
9 changes: 7 additions & 2 deletions tests/custom_currency.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from __future__ import annotations

import abc
from typing import Final

from typing_extensions import assert_type

from immoney import Currency
from immoney.registry import CurrencyCollector
from immoney.registry import CurrencyRegistry


class CustomCurrency(Currency, abc.ABC):
@classmethod
def get_default_registry(cls) -> CurrencyRegistry:
def get_default_registry(cls) -> CurrencyRegistry[CustomCurrency]:
return registry


__currencies: Final = CurrencyCollector()
__currencies: Final = CurrencyCollector[CustomCurrency]()


class JupiterCoinType(CustomCurrency):
Expand All @@ -34,4 +38,5 @@ class MoonCoinType(CustomCurrency):


registry: Final = __currencies.finalize()
assert_type(registry, CurrencyRegistry[CustomCurrency])
del __currencies
2 changes: 1 addition & 1 deletion tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from immoney import currencies
from immoney.registry import CurrencyCollector

collector: Final = CurrencyCollector()
collector: Final = CurrencyCollector[Currency]()


@final
Expand Down

0 comments on commit 17ee450

Please sign in to comment.