diff --git a/testing/src/scenario/state.py b/testing/src/scenario/state.py index eb4f804ec..59ba99d13 100644 --- a/testing/src/scenario/state.py +++ b/testing/src/scenario/state.py @@ -49,7 +49,7 @@ from . import Context class _StateKwargs(TypedDict, total=False): - config: dict[str, str | int | float | bool] + config: Mapping[str, str | int | float | bool] relations: Iterable[RelationBase] networks: Iterable[Network] containers: Iterable[Container] @@ -60,7 +60,7 @@ class _StateKwargs(TypedDict, total=False): secrets: Iterable[Secret] resources: Iterable[Resource] planned_units: int - deferred: Sequence[DeferredEvent] + deferred: Iterable[DeferredEvent] stored_states: Iterable[StoredState] app_status: _EntityStatus unit_status: _EntityStatus @@ -68,7 +68,7 @@ class _StateKwargs(TypedDict, total=False): AnyJson = str | bool | dict[str, 'AnyJson'] | int | float | list['AnyJson'] -RawSecretRevisionContents = RawDataBagContents = dict[str, str] +RawSecretRevisionContents = RawDataBagContents = Mapping[str, str] UnitID = int CharmType = TypeVar('CharmType', bound=CharmBase) @@ -154,14 +154,14 @@ class JujuLogLine: """The log message.""" -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(frozen=True, kw_only=True, init=False) class CloudCredential: # noqa: D101 __doc__ = ops.CloudCredential.__doc__ auth_type: str """Authentication type.""" - attributes: Mapping[str, str] = dataclasses.field(default_factory=dict) + attributes: Mapping[str, str] """A dictionary containing cloud credentials. For example, for AWS, it contains `access-key` and `secret-key`; @@ -169,11 +169,19 @@ class CloudCredential: # noqa: D101 can be found here. """ - redacted: Sequence[str] = dataclasses.field(default_factory=list) + redacted: Sequence[str] """A list of redacted generic cloud API secrets.""" - def __post_init__(self): - _deepcopy_mutable_fields(self) + def __init__( + self, + *, + auth_type: str, + attributes: Mapping[str, str] = {}, + redacted: Iterable[str] = (), + ): + object.__setattr__(self, 'auth_type', auth_type) + object.__setattr__(self, 'attributes', dict(attributes)) + object.__setattr__(self, 'redacted', list(redacted)) def _to_ops(self) -> CloudCredential_Ops: return CloudCredential_Ops( @@ -183,44 +191,64 @@ def _to_ops(self) -> CloudCredential_Ops: ) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, init=False) class CloudSpec: # noqa: D101 __doc__ = ops.CloudSpec.__doc__ type: str """Type of the cloud.""" - _: dataclasses.KW_ONLY - - name: str = 'localhost' + name: str """Juju cloud name.""" - region: str | None = None + region: str | None """Region of the cloud.""" - endpoint: str | None = None + endpoint: str | None """Endpoint of the cloud.""" - identity_endpoint: str | None = None + identity_endpoint: str | None """Identity endpoint of the cloud.""" - storage_endpoint: str | None = None + storage_endpoint: str | None """Storage endpoint of the cloud.""" - credential: CloudCredential | None = None + credential: CloudCredential | None """Cloud credentials with key-value attributes.""" - ca_certificates: Sequence[str] = dataclasses.field(default_factory=list) + ca_certificates: Sequence[str] """A list of CA certificates.""" - skip_tls_verify: bool = False + skip_tls_verify: bool """Whether to skip TLS verification.""" - is_controller_cloud: bool = False + is_controller_cloud: bool """If this is the cloud used by the controller.""" - def __post_init__(self): - _deepcopy_mutable_fields(self) + def __init__( + self, + type: str, + *, + name: str = 'localhost', + region: str | None = None, + endpoint: str | None = None, + identity_endpoint: str | None = None, + storage_endpoint: str | None = None, + credential: CloudCredential | None = None, + ca_certificates: Iterable[str] = (), + skip_tls_verify: bool = False, + is_controller_cloud: bool = False, + ): + object.__setattr__(self, 'type', type) + object.__setattr__(self, 'name', name) + object.__setattr__(self, 'region', region) + object.__setattr__(self, 'endpoint', endpoint) + object.__setattr__(self, 'identity_endpoint', identity_endpoint) + object.__setattr__(self, 'storage_endpoint', storage_endpoint) + object.__setattr__(self, 'credential', credential) + object.__setattr__(self, 'ca_certificates', list(ca_certificates)) + object.__setattr__(self, 'skip_tls_verify', skip_tls_verify) + object.__setattr__(self, 'is_controller_cloud', is_controller_cloud) def _to_ops(self) -> CloudSpec_Ops: return CloudSpec_Ops( @@ -244,7 +272,7 @@ def _generate_secret_id(): return f'secret:{secret_id}' -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, init=False) class Secret: """A Juju secret. @@ -257,21 +285,19 @@ class Secret: This is the content the charm will receive with a :meth:`ops.Secret.get_content` call.""" - _: dataclasses.KW_ONLY - - latest_content: RawSecretRevisionContents | None = None + latest_content: RawSecretRevisionContents """The content of the latest revision of the secret. This is the content the charm will receive with a :meth:`ops.Secret.peek_content` call.""" - id: str = dataclasses.field(default_factory=_generate_secret_id) + id: str """The Juju ID of the secret. This is automatically assigned and should not usually need to be explicitly set. """ - owner: Literal['unit', 'app', None] = None + owner: Literal['unit', 'app'] | None """Indicates if the secret is owned by *this* unit, *this* application, or another application/unit. @@ -279,40 +305,64 @@ class Secret: to this unit. """ - remote_grants: Mapping[int, set[str]] = dataclasses.field(default_factory=dict) + remote_grants: Mapping[int, set[str]] """Mapping from relation IDs to remote units and applications to which this secret has been granted.""" - label: str | None = None + label: str | None """A human-readable label the charm can use to retrieve the secret. If this is set, it implies that the charm has previously set the label. """ - description: str | None = None + description: str | None """A human-readable description of the secret.""" - expire: datetime.datetime | None = None + expire: datetime.datetime | None """The time at which the secret will expire.""" - rotate: SecretRotate | None = None + rotate: SecretRotate | None """The rotation policy for the secret.""" # what revision is currently tracked by this charm. Only meaningful if owner=False - _tracked_revision: int = 1 + _tracked_revision = 1 # what revision is the latest for this secret. - _latest_revision: int = 1 + _latest_revision = 1 + + def __init__( + self, + tracked_content: RawSecretRevisionContents, + *, + latest_content: RawSecretRevisionContents | None = None, + id: str | None = None, + owner: Literal['unit', 'app'] | None = None, + remote_grants: Mapping[int, set[str]] = {}, + label: str | None = None, + description: str | None = None, + expire: datetime.datetime | None = None, + rotate: SecretRotate | None = None, + ): + self._validate_content(tracked_content, 'tracked_content') + if latest_content is not None: + self._validate_content(latest_content, 'latest_content') + object.__setattr__(self, 'tracked_content', tracked_content) + object.__setattr__( + self, + 'latest_content', + latest_content if latest_content is not None else tracked_content, + ) + object.__setattr__(self, 'id', id if id is not None else _generate_secret_id()) + object.__setattr__(self, 'owner', owner) + object.__setattr__(self, 'remote_grants', dict(remote_grants)) + object.__setattr__(self, 'label', label) + object.__setattr__(self, 'description', description) + object.__setattr__(self, 'expire', expire) + object.__setattr__(self, 'rotate', rotate) + object.__setattr__(self, '_tracked_revision', 1) + object.__setattr__(self, '_latest_revision', 1) + _deepcopy_mutable_fields(self) def __hash__(self) -> int: return hash(self.id) - def __post_init__(self): - self._validate_content(self.tracked_content, 'tracked_content') - if self.latest_content is not None: - self._validate_content(self.latest_content, 'latest_content') - if self.latest_content is None: - # bypass frozen dataclass - object.__setattr__(self, 'latest_content', self.tracked_content) - _deepcopy_mutable_fields(self) - @staticmethod def _validate_content(content: dict[str, str], name: str): if not isinstance(content, dict): @@ -425,7 +475,7 @@ def _hook_tool_output_fmt(self): return dct -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, init=False) class Network: """A Juju network space. @@ -448,28 +498,44 @@ class Network: binding_name: str """The name of the network space.""" - bind_addresses: Sequence[BindAddress] = dataclasses.field( - default_factory=lambda: [BindAddress([Address('192.0.2.0')])], - ) + bind_addresses: Sequence[BindAddress] """Addresses that the charm's application should bind to.""" - _: dataclasses.KW_ONLY - - ingress_addresses: Sequence[str] = dataclasses.field( - default_factory=lambda: ['192.0.2.0'], - ) + ingress_addresses: Sequence[str] """Addresses other applications should use to connect to the unit.""" - egress_subnets: Sequence[str] = dataclasses.field( - default_factory=lambda: ['192.0.2.0/24'], - ) + egress_subnets: Sequence[str] """Subnets that other units will see the charm connecting from.""" + def __init__( + self, + binding_name: str, + bind_addresses: Iterable[BindAddress] | None = None, + *, + ingress_addresses: Iterable[str] | None = None, + egress_subnets: Iterable[str] | None = None, + ): + object.__setattr__(self, 'binding_name', binding_name) + object.__setattr__( + self, + 'bind_addresses', + list(bind_addresses) + if bind_addresses is not None + else [BindAddress([Address('192.0.2.0')])], + ) + object.__setattr__( + self, + 'ingress_addresses', + list(ingress_addresses) if ingress_addresses is not None else ['192.0.2.0'], + ) + object.__setattr__( + self, + 'egress_subnets', + list(egress_subnets) if egress_subnets is not None else ['192.0.2.0/24'], + ) + def __hash__(self) -> int: return hash(self.binding_name) - def __post_init__(self): - _deepcopy_mutable_fields(self) - def _hook_tool_output_fmt(self): # dumps itself to dict in the same format the hook command would return { @@ -603,7 +669,7 @@ class Relation(RelationBase): remote_app_data: RawDataBagContents = dataclasses.field(default_factory=dict) """The current content of the application databag.""" - remote_units_data: dict[UnitID, RawDataBagContents] = dataclasses.field( + remote_units_data: Mapping[UnitID, RawDataBagContents] = dataclasses.field( default_factory=lambda: {0: _DEFAULT_JUJU_DATABAG.copy()}, # dedup ) """The current content of the databag for each unit in the relation.""" @@ -688,7 +754,7 @@ def remote_unit_name(self) -> str: class PeerRelation(RelationBase): """A relation to share data between units of the charm.""" - peers_data: dict[UnitID, RawDataBagContents] = dataclasses.field(default_factory=dict) + peers_data: Mapping[UnitID, RawDataBagContents] = dataclasses.field(default_factory=dict) """Current contents of the peer databags. Note that this does not include data for the unit being tested. Data for @@ -786,26 +852,24 @@ def _generate_new_change_id(): return _CHANGE_IDS -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, init=False) class Exec: """Mock data for simulated :meth:`ops.Container.exec` calls.""" - command_prefix: Sequence[str] + command_prefix: Sequence[str, ...] - _: dataclasses.KW_ONLY - - return_code: int = 0 + return_code: int """The return code of the process. Use 0 to mock the process ending successfully, and other values for failure. """ - stdout: str = '' + stdout: str """Any content written to stdout by the process. Provide content that the real process would write to stdout, which can be read by the charm. """ - stderr: str = '' + stderr: str """Any content written to stderr by the process. Provide content that the real process would write to stderr, which can be @@ -813,14 +877,24 @@ class Exec: """ # change ID: used internally to keep track of mocked processes - _change_id: int = dataclasses.field(default_factory=_generate_new_change_id) + _change_id: int - def __post_init__(self): - # The command prefix can be any sequence type, and a list is tidier to - # write when there's only one string. However, this object needs to be - # hashable, so can't contain a list. We 'freeze' the sequence to a tuple - # to support that. - object.__setattr__(self, 'command_prefix', tuple(self.command_prefix)) + def __init__( + self, + command_prefix: Sequence[str], + *, + return_code: int = 0, + stdout: str = '', + stderr: str = '', + _change_id: int | None = None, + ): + object.__setattr__(self, 'command_prefix', tuple(command_prefix)) + object.__setattr__(self, 'return_code', return_code) + object.__setattr__(self, 'stdout', stdout) + object.__setattr__(self, 'stderr', stderr) + object.__setattr__( + self, '_change_id', _change_id if _change_id is not None else _generate_new_change_id() + ) def _run(self) -> int: return self._change_id @@ -856,7 +930,7 @@ def _next_notice_id(*, update: bool = True): return str(cur) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, init=False) class Notice: """A Pebble notice.""" @@ -867,44 +941,72 @@ class Notice: ``canonical.com/postgresql/backup`` or ``example.com/mycharm/notice``. """ - _: dataclasses.KW_ONLY - - id: str = dataclasses.field(default_factory=_next_notice_id) + id: str """Unique ID for this notice.""" - user_id: int | None = None + user_id: int | None """UID of the user who may view this notice (None means notice is public).""" - type: pebble.NoticeType | str = pebble.NoticeType.CUSTOM + type: pebble.NoticeType | str """Type of the notice.""" - first_occurred: datetime.datetime = dataclasses.field(default_factory=_now_utc) + first_occurred: datetime.datetime """The first time one of these notices (type and key combination) occurs.""" - last_occurred: datetime.datetime = dataclasses.field(default_factory=_now_utc) + last_occurred: datetime.datetime """The last time one of these notices occurred.""" - last_repeated: datetime.datetime = dataclasses.field(default_factory=_now_utc) + last_repeated: datetime.datetime """The time this notice was last repeated. See Pebble's `Notices documentation `_ for an explanation of what "repeated" means. """ - occurrences: int = 1 + occurrences: int """The number of times one of these notices has occurred.""" - last_data: Mapping[str, str] = dataclasses.field(default_factory=dict) + last_data: Mapping[str, str] """Additional data captured from the last occurrence of one of these notices.""" - repeat_after: datetime.timedelta | None = None + repeat_after: datetime.timedelta | None """Minimum time after one of these was last repeated before Pebble will repeat it again.""" - expire_after: datetime.timedelta | None = None + expire_after: datetime.timedelta | None """How long since one of these last occurred until Pebble will drop the notice.""" - def __post_init__(self): - _deepcopy_mutable_fields(self) + def __init__( + self, + key: str, + *, + id: str | None = None, + user_id: int | None = None, + type: pebble.NoticeType | str = pebble.NoticeType.CUSTOM, + first_occurred: datetime.datetime | None = None, + last_occurred: datetime.datetime | None = None, + last_repeated: datetime.datetime | None = None, + occurrences: int = 1, + last_data: Mapping[str, str] = {}, + repeat_after: datetime.timedelta | None = None, + expire_after: datetime.timedelta | None = None, + ): + object.__setattr__(self, 'key', key) + object.__setattr__(self, 'id', id if id is not None else _next_notice_id()) + object.__setattr__(self, 'user_id', user_id) + object.__setattr__(self, 'type', type) + object.__setattr__( + self, 'first_occurred', first_occurred if first_occurred is not None else _now_utc() + ) + object.__setattr__( + self, 'last_occurred', last_occurred if last_occurred is not None else _now_utc() + ) + object.__setattr__( + self, 'last_repeated', last_repeated if last_repeated is not None else _now_utc() + ) + object.__setattr__(self, 'occurrences', occurrences) + object.__setattr__(self, 'last_data', dict(last_data)) + object.__setattr__(self, 'repeat_after', repeat_after) + object.__setattr__(self, 'expire_after', expire_after) def _to_ops(self) -> pebble.Notice: return pebble.Notice( @@ -1012,16 +1114,14 @@ def _to_ops(self) -> pebble.CheckInfo: ) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, init=False) class Container: """A Kubernetes container where a charm's workload runs.""" name: str """Name of the container, as found in the charm metadata.""" - _: dataclasses.KW_ONLY - - can_connect: bool = False + can_connect: bool """When False, all Pebble operations will fail.""" # This is the base plan. On top of it, one can add layers. @@ -1029,10 +1129,10 @@ class Container: # pebble or derive them from the resulting plan (which one CAN get from pebble). # So if we are instantiating Container by fetching info from a 'live' charm, the 'layers' # will be unknown. all that we can know is the resulting plan (the 'computed plan'). - _base_plan: Mapping[str, Any] = dataclasses.field(default_factory=dict) + _base_plan: Mapping[str, Any] # We expect most of the user-facing testing to be covered by this 'layers' attribute, # as it is all that will be known when unit-testing. - layers: Mapping[str, pebble.Layer] = dataclasses.field(default_factory=dict) + layers: Mapping[str, pebble.Layer] """All :class:`ops.pebble.Layer` definitions that have already been added to the container. Note that the layers should be added to the dictionary in the order in which they would have @@ -1041,12 +1141,10 @@ class Container: this means adding them in the order of the API calls. """ - service_statuses: Mapping[str, pebble.ServiceStatus] = dataclasses.field( - default_factory=dict, - ) + service_statuses: Mapping[str, pebble.ServiceStatus] """The current status of each Pebble service running in the container.""" - mounts: Mapping[str, Mount] = dataclasses.field(default_factory=dict) + mounts: Mapping[str, Mount] """Provides access to the contents of the simulated container filesystem. For example, suppose you want to express that your container has: @@ -1069,7 +1167,7 @@ class Container: be safely modified. """ - execs: Iterable[Exec] = frozenset() + execs: frozenset[Exec] """Simulate executing commands in the container. Specify each command the charm might run in the container and an :class:`Exec` @@ -1090,21 +1188,39 @@ class Container: ) """ - notices: Sequence[Notice] = dataclasses.field(default_factory=list) + notices: Sequence[Notice] """Any Pebble notices that already exist in the container.""" - check_infos: Iterable[CheckInfo] = frozenset() + check_infos: frozenset[CheckInfo] """All Pebble health checks that have been added to the container.""" + def __init__( + self, + name: str, + *, + can_connect: bool = False, + _base_plan: Mapping[str, Any] = {}, + layers: Mapping[str, pebble.Layer] = {}, + service_statuses: Mapping[str, pebble.ServiceStatus] = {}, + mounts: Mapping[str, Mount] = {}, + execs: Iterable[Exec] = (), + notices: Iterable[Notice] = (), + check_infos: Iterable[CheckInfo] = (), + ): + object.__setattr__(self, 'name', name) + object.__setattr__(self, 'can_connect', can_connect) + object.__setattr__(self, '_base_plan', dict(_base_plan)) + object.__setattr__(self, 'layers', dict(layers)) + object.__setattr__(self, 'service_statuses', dict(service_statuses)) + object.__setattr__(self, 'mounts', dict(mounts)) + object.__setattr__(self, 'execs', frozenset(execs)) + object.__setattr__(self, 'notices', list(notices)) + object.__setattr__(self, 'check_infos', frozenset(check_infos)) + _deepcopy_mutable_fields(self) + def __hash__(self) -> int: return hash(self.name) - def __post_init__(self): - if not isinstance(self.execs, frozenset): - # Allow passing a regular set (or other iterable) of Execs. - object.__setattr__(self, 'execs', frozenset(self.execs)) - _deepcopy_mutable_fields(self) - def _render_services(self): services: dict[str, pebble.Service] = {} for layer in self.layers.values(): @@ -1360,11 +1476,11 @@ def __init__(self, message: str = ''): ) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, init=False) class StoredState: """Represents unit-local state that persists across events.""" - name: str = '_stored' + name: str """The attribute in the parent Object where the state is stored. For example, ``_stored`` in this class:: @@ -1374,9 +1490,7 @@ class MyCharm(ops.CharmBase): """ - _: dataclasses.KW_ONLY - - owner_path: str | None = None + owner_path: str | None """The path to the owner of this StoredState instance. If ``None``, the owner is the Framework. Otherwise, /-separated object names, @@ -1387,18 +1501,29 @@ class MyCharm(ops.CharmBase): # However, it's complex to describe those types, since it's a recursive # definition - even in TypeShed the _Marshallable type includes containers # like list[Any], which seems to defeat the point. - content: Mapping[str, Any] = dataclasses.field(default_factory=dict) + content: Mapping[str, Any] """The content of the :class:`ops.StoredState` instance.""" - _data_type_name: str = 'StoredStateData' + _data_type_name: str + + def __init__( + self, + name: str = '_stored', + *, + owner_path: str | None = None, + content: Mapping[str, Any] = {}, + _data_type_name: str = 'StoredStateData', + ): + object.__setattr__(self, 'name', name) + object.__setattr__(self, 'owner_path', owner_path) + object.__setattr__(self, 'content', dict(content)) + object.__setattr__(self, '_data_type_name', _data_type_name) + _deepcopy_mutable_fields(self) @property def _handle_path(self): return f'{self.owner_path or ""}/{self._data_type_name}[{self.name}]' - def __post_init__(self): - _deepcopy_mutable_fields(self) - def __hash__(self) -> int: return hash(self._handle_path) @@ -1559,7 +1684,7 @@ class Resource: """A local path that will be provided to the charm as the content of the resource.""" -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(frozen=True, kw_only=True, init=False) class State: """Represents the Juju-owned portion of a unit's state. @@ -1568,13 +1693,11 @@ class State: return data from `State.leader`, and so on. """ - config: dict[str, str | int | float | bool] = dataclasses.field( - default_factory=dict, - ) + config: Mapping[str, str | int | float | bool] """The present configuration of this charm.""" - relations: Iterable[RelationBase] = dataclasses.field(default_factory=frozenset) + relations: frozenset[RelationBase] """All relations that currently exist for this charm.""" - networks: Iterable[Network] = dataclasses.field(default_factory=frozenset) + networks: frozenset[Network] """Manual overrides for any relation and extra bindings currently provisioned for this charm. If a metadata-defined relation endpoint is not explicitly mapped to a Network in this field, it will be defaulted. @@ -1584,28 +1707,28 @@ class State: support it, but use at your own risk. If a metadata-defined extra-binding is left empty, it will be defaulted. """ - containers: Iterable[Container] = dataclasses.field(default_factory=frozenset) + containers: frozenset[Container] """All containers (whether they can connect or not) that this charm is aware of.""" - storages: Iterable[Storage] = dataclasses.field(default_factory=frozenset) + storages: frozenset[Storage] """All **attached** storage instances for this charm. If a storage is not attached, omit it from this listing.""" # we don't use sets to make json serialization easier - opened_ports: Iterable[Port] = dataclasses.field(default_factory=frozenset) + opened_ports: frozenset[Port] """Ports opened by Juju on this charm.""" - leader: bool = False + leader: bool """Whether this charm has leadership.""" - model: Model = dataclasses.field(default_factory=Model) + model: Model """The model this charm lives in.""" - secrets: Iterable[Secret] = dataclasses.field(default_factory=frozenset) + secrets: frozenset[Secret] """The secrets this charm has access to (as an owner, or as a grantee). The presence of a secret in this list entails that the charm can read it. Whether it can manage it or not depends on the individual secret's `owner` flag.""" - resources: Iterable[Resource] = dataclasses.field(default_factory=frozenset) + resources: frozenset[Resource] """All resources that this charm can access.""" - planned_units: int = 1 + planned_units: int """Number of non-dying planned units that are expected to be running this application. Use with caution.""" @@ -1614,45 +1737,72 @@ class State: # dispatched, and represent the events that had been deferred during the previous run. # If the charm defers any events during "this execution", they will be appended # to this list. - deferred: Sequence[DeferredEvent] = dataclasses.field(default_factory=list) + deferred: Sequence[DeferredEvent] """Events that have been deferred on this charm by some previous execution.""" - stored_states: Iterable[StoredState] = dataclasses.field( - default_factory=frozenset, - ) + stored_states: frozenset[StoredState] """Contents of a charm's stored state.""" # the current statuses. - app_status: _EntityStatus = dataclasses.field(default_factory=UnknownStatus) + app_status: _EntityStatus """Status of the application.""" - unit_status: _EntityStatus = dataclasses.field(default_factory=UnknownStatus) + unit_status: _EntityStatus """Status of the unit.""" - workload_version: str = '' + workload_version: str """Workload version.""" - def __post_init__(self): - # Let people pass in the ops classes, and convert them to the appropriate Scenario classes. - for name in ['app_status', 'unit_status']: - val = getattr(self, name) - if isinstance(val, _EntityStatus): - pass + def __init__( + self, + *, + config: Mapping[str, str | int | float | bool] | None = None, + relations: Iterable[RelationBase] = (), + networks: Iterable[Network] = (), + containers: Iterable[Container] = (), + storages: Iterable[Storage] = (), + opened_ports: Iterable[Port] = (), + leader: bool = False, + model: Model | None = None, + secrets: Iterable[Secret] = (), + resources: Iterable[Resource] = (), + planned_units: int = 1, + deferred: Iterable[DeferredEvent] = (), + stored_states: Iterable[StoredState] = (), + app_status: _EntityStatus | StatusBase | None = None, + unit_status: _EntityStatus | StatusBase | None = None, + workload_version: str = '', + ): + object.__setattr__(self, 'config', config if config is not None else {}) + object.__setattr__(self, 'leader', leader) + object.__setattr__(self, 'model', model if model is not None else Model()) + object.__setattr__(self, 'planned_units', planned_units) + object.__setattr__(self, 'deferred', list(deferred)) + object.__setattr__(self, 'workload_version', workload_version) + + # Handle status conversion from ops types. + for name, val in [('app_status', app_status), ('unit_status', unit_status)]: + if val is None: + object.__setattr__(self, name, UnknownStatus()) + elif isinstance(val, _EntityStatus): + object.__setattr__(self, name, val) elif isinstance(val, StatusBase): object.__setattr__(self, name, _EntityStatus.from_ops(val)) else: raise TypeError(f'Invalid status.{name}: {val!r}') - normalised_ports = [ + + # Normalise ops.Port to scenario Port. + normalised_ports = frozenset( Port(protocol=port.protocol, port=port.port) if isinstance(port, ops.Port) else port - for port in self.opened_ports - ] - if self.opened_ports != normalised_ports: - object.__setattr__(self, 'opened_ports', normalised_ports) - normalised_storage = [ + for port in opened_ports + ) + object.__setattr__(self, 'opened_ports', normalised_ports) + + # Normalise ops.Storage to scenario Storage. + normalised_storage = frozenset( Storage(name=storage.name, index=storage.index) if isinstance(storage, ops.Storage) else storage - for storage in self.storages - ] - if self.storages != normalised_storage: - object.__setattr__(self, 'storages', normalised_storage) + for storage in storages + ) + object.__setattr__(self, 'storages', normalised_storage) # ops.Container, ops.Model, ops.Relation, ops.Secret should not be instantiated by # charmers. @@ -1660,23 +1810,12 @@ def __post_init__(self): # ops.Resources does not contain the source of the resource, so cannot be converted. # ops.StoredState is not convenient to initialise with data, so not useful here. - # It's convenient to pass a set, but we really want the attributes to be - # frozen sets to increase the immutability of State objects. - for name in [ - 'relations', - 'containers', - 'storages', - 'networks', - 'opened_ports', - 'secrets', - 'resources', - 'stored_states', - ]: - val = getattr(self, name) - # It's ok to pass any iterable (of hashable objects), but you'll get - # a frozenset as the actual attribute. - if not isinstance(val, frozenset): - object.__setattr__(self, name, frozenset(val)) + object.__setattr__(self, 'relations', frozenset(relations)) + object.__setattr__(self, 'networks', frozenset(networks)) + object.__setattr__(self, 'containers', frozenset(containers)) + object.__setattr__(self, 'secrets', frozenset(secrets)) + object.__setattr__(self, 'resources', frozenset(resources)) + object.__setattr__(self, 'stored_states', frozenset(stored_states)) def _update_workload_version(self, new_workload_version: str): """Update the current app version and record the previous one.""" @@ -1817,7 +1956,7 @@ def from_context( ctx: Context[CharmType], *, # If provided, these merge with or replace the generated versions. - config: dict[str, str | int | float | bool] | None = None, + config: Mapping[str, str | int | float | bool] | None = None, relations: Iterable[RelationBase] | None = None, containers: Iterable[Container] | None = None, storages: Iterable[Storage] | None = None, @@ -1999,9 +2138,9 @@ class _CharmSpec(Generic[CharmType]): """Charm spec.""" charm_type: type[CharmBase] - meta: dict[str, Any] - actions: dict[str, Any] | None = None - config: dict[str, Any] | None = None + meta: Mapping[str, Any] + actions: Mapping[str, Any] | None = None + config: Mapping[str, Any] | None = None # autoloaded means: we are running a 'real' charm class, living in some # /src/charm.py, and the metadata files are 'real' metadata files. diff --git a/testing/tests/test_e2e/test_state.py b/testing/tests/test_e2e/test_state.py index 2d327f901..99e2cdfcb 100644 --- a/testing/tests/test_e2e/test_state.py +++ b/testing/tests/test_e2e/test_state.py @@ -373,7 +373,6 @@ def test_replace_state(): (CloudCredential, 'attributes', {'auth_type': 'foo'}), (Secret, 'tracked_content', {}), (Secret, 'latest_content', {'tracked_content': {'password': 'password'}}), - (Secret, 'remote_grants', {'tracked_content': {'password': 'password'}}), (Relation, 'local_app_data', {'endpoint': 'foo'}), (Relation, 'local_unit_data', {'endpoint': 'foo'}), (Relation, 'remote_app_data', {'endpoint': 'foo'}), @@ -387,7 +386,6 @@ def test_replace_state(): (Container, 'layers', {'name': 'foo'}), (Container, 'service_statuses', {'name': 'foo'}), (Container, 'mounts', {'name': 'foo'}), - (Container, 'notices', {'name': 'foo'}), (StoredState, 'content', {}), ], ) @@ -408,6 +406,19 @@ def test_immutable_content_dict( assert getattr(obj2, attribute) == {'foo': 'bar'} +def test_immutable_remote_grants(): + content = {0: {'app1', 'app2'}} + obj1 = Secret(tracked_content={'password': 'password'}, remote_grants=content) + obj2 = Secret(tracked_content={'password': 'password'}, remote_grants=content) + assert obj1.remote_grants == obj2.remote_grants == {0: frozenset({'app1', 'app2'})} + assert obj1.remote_grants is not obj2.remote_grants + content[1] = {'app3'} + assert obj1.remote_grants == obj2.remote_grants == {0: frozenset({'app1', 'app2'})} + object.__setattr__(obj1, 'remote_grants', {1: frozenset({'app3'})}) + assert obj1.remote_grants == {1: frozenset({'app3'})} + assert obj2.remote_grants == {0: frozenset({'app1', 'app2'})} + + @pytest.mark.parametrize( 'component,attribute,required_args', [ @@ -417,6 +428,7 @@ def test_immutable_content_dict( (Network, 'bind_addresses', {'binding_name': 'foo'}), (Network, 'ingress_addresses', {'binding_name': 'foo'}), (Network, 'egress_subnets', {'binding_name': 'foo'}), + (Container, 'notices', {'name': 'foo'}), ], ) def test_immutable_content_list( @@ -464,6 +476,53 @@ def test_immutable_content_dict_of_dicts( assert getattr(obj2, attribute) == {0: {'foo': 'bar'}, 1: {'baz': 'qux'}} +@pytest.mark.parametrize( + 'component,attribute,expected_type,input_value,required_args', + [ + # Mapping -> dict + (CloudCredential, 'attributes', dict, {'a': 'b'}, {'auth_type': 'foo'}), + (Secret, 'remote_grants', dict, {1: {'app'}}, {'tracked_content': {'k': 'v'}}), + (Notice, 'last_data', dict, {'k': 'v'}, {'key': 'foo'}), + (Container, 'layers', dict, {}, {'name': 'foo'}), + (Container, 'service_statuses', dict, {}, {'name': 'foo'}), + (Container, 'mounts', dict, {}, {'name': 'foo'}), + (StoredState, 'content', dict, {'k': 'v'}, {}), + # Iterable -> list + (CloudCredential, 'redacted', list, ('a', 'b'), {'auth_type': 'foo'}), + (CloudSpec, 'ca_certificates', list, ('a', 'b'), {'type': 'foo'}), + ( + Network, + 'bind_addresses', + list, + iter([BindAddress([Address('192.0.2.0')])]), + {'binding_name': 'foo'}, + ), + (Network, 'ingress_addresses', list, ('1.2.3.4',), {'binding_name': 'foo'}), + (Network, 'egress_subnets', list, ('1.2.3.0/24',), {'binding_name': 'foo'}), + (Container, 'notices', list, (Notice(key='foo'),), {'name': 'foo'}), + (State, 'deferred', list, (), {}), + # Iterable -> frozenset + (Container, 'execs', frozenset, (), {'name': 'foo'}), + (Container, 'check_infos', frozenset, (), {'name': 'foo'}), + (State, 'relations', frozenset, (Relation(endpoint='foo'),), {}), + (State, 'networks', frozenset, (Network(binding_name='foo'),), {}), + (State, 'containers', frozenset, (Container(name='foo'),), {}), + (State, 'secrets', frozenset, (Secret(tracked_content={'k': 'v'}),), {}), + (State, 'stored_states', frozenset, (), {}), + ], +) +def test_init_converts_to_concrete_type( + component: type[object], + attribute: str, + expected_type: type, + input_value: Any, + required_args: dict[str, Any], +): + """Verify that __init__ converts broader input types to concrete attribute types.""" + obj = component(**required_args, **{attribute: input_value}) + assert isinstance(getattr(obj, attribute), expected_type) + + @pytest.mark.parametrize( 'obj_in,attribute,get_method,key_attr', [