diff --git a/ops/_private/harness.py b/ops/_private/harness.py index 898046978..1fe5324d2 100644 --- a/ops/_private/harness.py +++ b/ops/_private/harness.py @@ -3000,15 +3000,39 @@ def secret_remove(self, id: str, *, revision: int | None = None) -> None: else: self._secrets = [s for s in self._secrets if not self._secret_ids_are_equal(s.id, id)] - def open_port(self, protocol: str, port: int | None = None): + def open_port( + self, + protocol: str, + port: int | None = None, + *, + to_port: int | None = None, + endpoints: Sequence[str] = '*', + ): self._check_protocol_and_port(protocol, port) protocol_lit = cast('Literal["tcp", "udp", "icmp"]', protocol) - self._opened_ports.add(model.Port(protocol_lit, port)) - - def close_port(self, protocol: str, port: int | None = None): + endpoints = tuple(endpoints) if endpoints != '*' else '*' + # This only really works for the happy path. + # We should really be checking for overlapping port ranges (and erroring), + # and merging with existing entries for endpoints, but since harness is deprecated, + # and this is only needed for charms adopting this new feature, we can get away with it. + self._opened_ports.add(model.Port(protocol_lit, port, to_port, endpoints=endpoints)) + + def close_port( + self, + protocol: str, + port: int | None = None, + *, + to_port: int | None = None, + endpoints: Sequence[str] = '*', + ): self._check_protocol_and_port(protocol, port) protocol_lit = cast('Literal["tcp", "udp", "icmp"]', protocol) - self._opened_ports.discard(model.Port(protocol_lit, port)) + endpoints = tuple(endpoints) if endpoints != '*' else '*' + # This only really works for the happy path. + # We should really be checking for overlapping port ranges (and erroring), + # and merging with existing entries for endpoints, but since harness is deprecated, + # and this is only needed for charms adopting this new feature, we can get away with it. + self._opened_ports.discard(model.Port(protocol_lit, port, to_port, endpoints=endpoints)) def opened_ports(self) -> set[model.Port]: return set(self._opened_ports) diff --git a/ops/hookcmds/_port.py b/ops/hookcmds/_port.py index f8f2f55a4..a352ff03f 100644 --- a/ops/hookcmds/_port.py +++ b/ops/hookcmds/_port.py @@ -29,7 +29,7 @@ def close_port( *, to_port: int | None = None, endpoints: str | Iterable[str] | None = None, -) -> None: ... +) -> str | None: ... @overload def close_port( protocol: str | None, @@ -37,14 +37,14 @@ def close_port( *, to_port: int | None = None, endpoints: str | Iterable[str] | None = None, -) -> None: ... +) -> str | None: ... def close_port( protocol: str | None = None, port: int | None = None, *, to_port: int | None = None, endpoints: str | Iterable[str] | None = None, -): +) -> str | None: """Register a request to close a port or port range. For more details, see: @@ -58,13 +58,16 @@ def close_port( if port is None: if protocol is None: raise TypeError('You must provide a port or protocol.') + if to_port is not None: + raise TypeError('to_port cannot be specified if port is not specified') args.append(protocol) else: port_arg = f'{port}-{to_port}' if to_port is not None else str(port) if protocol is not None: port_arg = f'{port_arg}/{protocol}' args.append(port_arg) - run('close-port', *args) + result = run('close-port', *args).strip() + return result or None @overload @@ -74,7 +77,7 @@ def open_port( *, to_port: int | None = None, endpoints: str | Iterable[str] | None = None, -) -> None: ... +) -> str | None: ... @overload def open_port( protocol: str | None, @@ -82,14 +85,14 @@ def open_port( *, to_port: int | None = None, endpoints: str | Iterable[str] | None = None, -) -> None: ... +) -> str | None: ... def open_port( protocol: str | None = None, port: int | None = None, *, to_port: int | None = None, endpoints: str | Iterable[str] | None = None, -): +) -> str | None: """Register a request to open a port or port range. For more details, see: @@ -113,13 +116,21 @@ def open_port( if port is None: if protocol is None: raise TypeError('You must provide a port or protocol.') + if to_port is not None: + raise TypeError('to_port can only be specified if port is also specified') args.append(protocol) else: port_arg = f'{port}-{to_port}' if to_port is not None else str(port) if protocol is not None: port_arg = f'{port_arg}/{protocol}' args.append(port_arg) - run('open-port', *args) + # In the happy case (already open or opened successfully) open-ports exits silently with 0. + # If open-port exits with an error code, then run will raise an Error. + # Specifying a non-existent endpoint exits with 0, but prints an error message, + # **and does not open the port**. In this case, we return the error message. + # Ops will use this to raise an error. + result = run('open-port', *args).strip() + return result or None def opened_ports(*, endpoints: bool = False) -> list[Port]: @@ -146,7 +157,8 @@ def opened_ports(*, endpoints: bool = False) -> list[Port]: # '8000-8999/tcp' or '8000-8999/udp' (where the two numbers can be any ports) # '8000-8999' (where these could be any port number) # If ``--endpoints`` is used, then each port will be followed by a - # (possibly empty) tuple of endpoints. + # (non-empty) tuple of endpoints, e.g. '8000-8999/tcp (ep1,ep2)' or '80/tcp (*)'. + # (*) indicates that the port applies to all endpoints. for port in result: if endpoints: port, port_endpoints = port.rsplit(' ', 1) diff --git a/ops/hookcmds/_types.py b/ops/hookcmds/_types.py index 6898e4772..d57fcabc9 100644 --- a/ops/hookcmds/_types.py +++ b/ops/hookcmds/_types.py @@ -242,9 +242,6 @@ def _from_dict(cls, d: dict[str, Any]) -> Network: return cls(bind_addresses=bind, egress_subnets=egress, ingress_addresses=ingress) -# Note that we intend to merge this with model.py's `Port` in the future, and -# that does not have `kw_only=True`. That means that we should not use it here, -# either, so that merging can be backwards compatible. @dataclasses.dataclass(frozen=True) class Port: """A port that Juju has opened for the charm.""" diff --git a/ops/model.py b/ops/model.py index a2c9cd31a..01e0230b8 100644 --- a/ops/model.py +++ b/ops/model.py @@ -37,7 +37,7 @@ import warnings import weakref from abc import ABC, abstractmethod -from collections.abc import Callable, Generator, Iterable, Mapping, MutableMapping +from collections.abc import Callable, Generator, Iterable, Mapping, MutableMapping, Sequence from pathlib import Path, PurePath from typing import ( Any, @@ -719,7 +719,11 @@ def add_secret( ) def open_port( - self, protocol: typing.Literal['tcp', 'udp', 'icmp'], port: int | None = None + self, + protocol: typing.Literal['tcp', 'udp', 'icmp'], + port: int | tuple[int, int | None] | None = None, + *, + endpoints: Sequence[str] = '*', ) -> None: """Open a port with the given protocol for this unit. @@ -736,18 +740,30 @@ def open_port( protocol: String representing the protocol; must be one of 'tcp', 'udp', or 'icmp' (lowercase is recommended, but uppercase is also supported). - port: The port to open. Required for TCP and UDP; not allowed - for ICMP. + port: The port to open. Required for TCP and UDP; not allowed for ICMP. + May be a tuple of two integers to specify a port range. + endpoints: The endpoints for which to open the port. + '*' means to open the port for all endpoints. Raises: ModelError: If ``port`` is provided when ``protocol`` is 'icmp' or ``port`` is not provided when ``protocol`` is 'tcp' or 'udp'. """ - self._backend.open_port(protocol.lower(), port) + if isinstance(port, tuple): + port, to_port = port + else: + port, to_port = port, None + if not endpoints: + raise TypeError('endpoints must be a non-empty string or sequence of strings') + self._backend.open_port(protocol.lower(), port, to_port=to_port, endpoints=endpoints) def close_port( - self, protocol: typing.Literal['tcp', 'udp', 'icmp'], port: int | None = None + self, + protocol: typing.Literal['tcp', 'udp', 'icmp'], + port: int | tuple[int, int | None] | None = None, + *, + endpoints: Sequence[str] = '*', ) -> None: """Close a port with the given protocol for this unit. @@ -765,21 +781,29 @@ def close_port( protocol: String representing the protocol; must be one of 'tcp', 'udp', or 'icmp' (lowercase is recommended, but uppercase is also supported). - port: The port to open. Required for TCP and UDP; not allowed - for ICMP. + port: The port to open. Required for TCP and UDP; not allowed for ICMP. + May be a tuple of two integers to specify a port range. + endpoints: The endpoints for which to close the port. + '*' means to close the port for all endpoints. Raises: ModelError: If ``port`` is provided when ``protocol`` is 'icmp' or ``port`` is not provided when ``protocol`` is 'tcp' or 'udp'. """ - self._backend.close_port(protocol.lower(), port) + if isinstance(port, tuple): + port, to_port = port + else: + port, to_port = port, None + if not endpoints: + raise TypeError('endpoints must be a non-empty string or sequence of strings') + self._backend.close_port(protocol.lower(), port, to_port=to_port, endpoints=endpoints) def opened_ports(self) -> set[Port]: """Return a list of opened ports for this unit.""" return self._backend.opened_ports() - def set_ports(self, *ports: int | Port) -> None: + def set_ports(self, *ports: int | tuple[int, int | None] | Port) -> None: """Set the open ports for this unit, closing any others that are open. Some behaviour, such as whether the port is opened or closed externally without @@ -800,16 +824,19 @@ def set_ports(self, *ports: int | Port) -> None: ``port`` is not ``None``, or where ``protocol`` is 'tcp' or 'udp' and ``port`` is ``None``. """ - # Normalise to get easier comparisons. - existing = {(port.protocol, port.port) for port in self._backend.opened_ports()} + existing = self._backend.opened_ports() desired = { - ('tcp', port) if isinstance(port, int) else (port.protocol, port.port) + Port('tcp', port) + if isinstance(port, int) + else Port('tcp', port[0], to_port=port[1]) + if isinstance(port, tuple) + else port for port in ports } - for protocol, port in existing - desired: - self._backend.close_port(protocol, port) - for protocol, port in desired - existing: - self._backend.open_port(protocol, port) + for p in existing - desired: + self._backend.close_port(p.protocol, p.port, to_port=p.to_port, endpoints=p.endpoints) + for p in desired - existing: + self._backend.open_port(p.protocol, p.port, to_port=p.to_port, endpoints=p.endpoints) def reboot(self, now: bool = False) -> None: """Reboot the host machine. @@ -846,6 +873,21 @@ class Port: port: int | None """The port number. Will be ``None`` if protocol is ``'icmp'``.""" + to_port: int | None = None + """The end of the port range, if a range was specified. + + Will be ``None`` if a single port was specified (or if protocol is ``'icmp'``). + """ + + _: dataclasses.KW_ONLY + + endpoints: tuple[str, ...] | Literal['*'] = '*' + """The endpoints for which the port is open. + + Will be ``"*"`` if open for all endpoints, + or a tuple of endpoint names if specified for particular endpoints. + """ + OpenedPort = Port """Alias to Port for backwards compatibility. @@ -4037,26 +4079,59 @@ def secret_remove(self, id: str, *, revision: int | None = None): with self._wrap_hookcmd('secret-remove', id=id, revision=revision): hookcmds.secret_remove(id, revision=revision) - def open_port(self, protocol: str, port: int | None = None): - with self._wrap_hookcmd('open-port', protocol=protocol, port=port): - hookcmds.open_port(protocol, port) + def open_port( + self, + protocol: str, + port: int | None = None, + *, + to_port: int | None = None, + endpoints: Sequence[str] = '*', + ): + with self._wrap_hookcmd( + 'open-port', protocol=protocol, port=port, to_port=to_port, endpoints=endpoints + ): + result = hookcmds.open_port(protocol, port, to_port=to_port, endpoints=endpoints) + if result is not None: + raise ModelError(result) - def close_port(self, protocol: str, port: int | None = None): - with self._wrap_hookcmd('close-port', protocol=protocol, port=port): - hookcmds.close_port(protocol, port) + def close_port( + self, + protocol: str, + port: int | None = None, + *, + to_port: int | None = None, + endpoints: Sequence[str] = '*', + ): + with self._wrap_hookcmd( + 'close-port', protocol=protocol, port=port, to_port=to_port, endpoints=endpoints + ): + result = hookcmds.close_port(protocol, port, to_port=to_port, endpoints=endpoints) + if result is not None: + raise ModelError(result) def opened_ports(self) -> set[Port]: - with self._wrap_hookcmd('opened-ports'): - results = hookcmds.opened_ports() + with self._wrap_hookcmd('opened-ports', endpoints=True): + result = hookcmds.opened_ports(endpoints=True) ports: set[Port] = set() - for raw_port in results: - if raw_port.protocol not in ('tcp', 'udp', 'icmp'): - logger.warning('Unexpected opened-ports protocol: %s', raw_port.protocol) + for port in result: + if port.protocol not in ('tcp', 'udp', 'icmp'): + logger.warning('Unexpected opened-ports protocol: %s', port.protocol) + continue + if not port.endpoints: + logger.warning('opened-ports result with no endpoints: %s', port) continue - if raw_port.to_port is not None: - logger.warning('Ignoring opened-ports port range: %s', raw_port) - port = Port(raw_port.protocol or 'tcp', raw_port.port) - ports.add(port) + match port.endpoints: + case ['*']: + model_endpoints = '*' + case _: + model_endpoints = tuple(port.endpoints) + model_port = Port( + port.protocol, + port.port, + to_port=port.to_port, + endpoints=model_endpoints, + ) + ports.add(model_port) return ports def reboot(self, now: bool = False): diff --git a/test/test_hookcmds.py b/test/test_hookcmds.py index 0a414ac1f..19e895e65 100644 --- a/test/test_hookcmds.py +++ b/test/test_hookcmds.py @@ -259,6 +259,16 @@ def test_close_port_range(run: Run): hookcmds.close_port(protocol='tcp', port=8080, to_port=8090) +def test_close_port_range_and_endpoints(run: Run): + run.handle(['close-port', '--endpoints', 'ep1,ep2', '8080-8090/tcp']) + hookcmds.close_port(protocol='tcp', port=8080, to_port=8090, endpoints=['ep1', 'ep2']) + + +def test_close_port_to_port_without_port(): + with pytest.raises(TypeError): + hookcmds.close_port(protocol='tcp', to_port=8090) + + def test_config_get(run: Run): run.handle(['config-get', '--format=json'], stdout='{"foo": "bar"}') result = hookcmds.config_get() @@ -429,6 +439,31 @@ def test_open_port_range(run: Run): hookcmds.open_port(protocol='tcp', port=8080, to_port=8090) +def test_open_port_range_and_endpoints(run: Run): + run.handle(['open-port', '--endpoints', 'ep1,ep2', '8080-8090/tcp']) + hookcmds.open_port(protocol='tcp', port=8080, to_port=8090, endpoints=['ep1', 'ep2']) + + +def test_open_port_to_port_without_port(): + with pytest.raises(TypeError): + hookcmds.open_port(protocol='tcp', to_port=8090) + + +def test_open_port_nonexistent_endpoint(run: Run): + # Juju exits 0 but prints an error to stdout when the endpoint does not exist. + # hookcmds.open_port returns the error string; the ModelError is raised higher up. + error_msg = ( + 'cannot open/close ports: open port range: endpoint "nonexistent-ep"' + ' for unit "app/0" not found' + ) + run.handle( + ['open-port', '--endpoints', 'nonexistent-ep', '8080/tcp'], + stdout=error_msg, + ) + result = hookcmds.open_port(protocol='tcp', port=8080, endpoints=['nonexistent-ep']) + assert result == error_msg + + def test_opened_ports(run: Run): run.handle( ['opened-ports', '--format=json'], @@ -449,12 +484,21 @@ def test_opened_ports(run: Run): def test_opened_ports_endpoints(run: Run): run.handle( ['opened-ports', '--endpoints', '--format=json'], - stdout='["8080/tcp (ep1,ep2)"]', + stdout='["8080/tcp (ep1,ep2)", "1234/ftp (ep1)", "1000-2000/udp (*)"]', ) result = hookcmds.opened_ports(endpoints=True) - assert result[0].port == 8080 assert result[0].protocol == 'tcp' + assert result[0].port == 8080 + assert result[0].to_port is None assert result[0].endpoints == ['ep1', 'ep2'] + assert result[1].protocol == 'ftp' + assert result[1].port == 1234 + assert result[1].to_port is None + assert result[1].endpoints == ['ep1'] + assert result[2].protocol == 'udp' + assert result[2].port == 1000 + assert result[2].to_port == 2000 + assert result[2].endpoints == ['*'] @pytest.mark.parametrize('id', [None, 123]) diff --git a/test/test_model.py b/test/test_model.py index 613c8a2f1..a0edb3ebf 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -4278,11 +4278,56 @@ def test_open_port(self, fake_script: FakeScript, unit: ops.Unit): unit.open_port('icmp') assert fake_script.calls(clear=True) == [ - ['open-port', '8080/tcp'], - ['open-port', '4000/udp'], - ['open-port', 'icmp'], + ['open-port', '--endpoints', '*', '8080/tcp'], + ['open-port', '--endpoints', '*', '4000/udp'], + ['open-port', '--endpoints', '*', 'icmp'], ] + def test_open_port_range(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('open-port', 'exit 0') + + unit.open_port('tcp', (8080, 8090)) + unit.open_port('UDP', (4000, 5000)) # type: ignore + unit.open_port('tcp', (8080, None)) + + assert fake_script.calls(clear=True) == [ + ['open-port', '--endpoints', '*', '8080-8090/tcp'], + ['open-port', '--endpoints', '*', '4000-5000/udp'], + ['open-port', '--endpoints', '*', '8080/tcp'], + ] + + def test_open_port_endpoints(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('open-port', 'exit 0') + + unit.open_port('tcp', 8080, endpoints=['ep1', 'ep2']) + unit.open_port('udp', 4000, endpoints=['ep1']) + + assert fake_script.calls(clear=True) == [ + ['open-port', '--endpoints', 'ep1,ep2', '8080/tcp'], + ['open-port', '--endpoints', 'ep1', '4000/udp'], + ] + + def test_open_port_range_and_endpoints(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('open-port', 'exit 0') + + unit.open_port('tcp', (8080, 8090), endpoints=['ep1', 'ep2']) + + assert fake_script.calls(clear=True) == [ + ['open-port', '--endpoints', 'ep1,ep2', '8080-8090/tcp'], + ] + + def test_open_port_range_none_port(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('open-port', 'exit 0') + + with pytest.raises(TypeError): + unit.open_port('tcp', (None, 8090)) # type: ignore + + def test_open_port_empty_endpoints(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('open-port', 'exit 0') + + with pytest.raises(TypeError): + unit.open_port('tcp', 8080, endpoints=[]) + def test_open_port_error(self, fake_script: FakeScript, unit: ops.Unit): fake_script.write('open-port', "echo 'ERROR bad protocol' >&2; exit 1") @@ -4291,7 +4336,22 @@ def test_open_port_error(self, fake_script: FakeScript, unit: ops.Unit): assert str(excinfo.value) == 'ERROR bad protocol\n' assert fake_script.calls(clear=True) == [ - ['open-port', '8080/ftp'], + ['open-port', '--endpoints', '*', '8080/ftp'], + ] + + def test_open_port_nonexistent_endpoint(self, fake_script: FakeScript, unit: ops.Unit): + # Juju exits 0 but prints an error to stdout when the endpoint does not exist. + error_msg = ( + 'cannot open/close ports: open port range: endpoint "nonexistent-ep"' + ' for unit "myapp/0" not found' + ) + fake_script.write('open-port', f"echo '{error_msg}'") + + with pytest.raises(ops.ModelError): + unit.open_port('tcp', 8080, endpoints=['nonexistent-ep']) + + assert fake_script.calls(clear=True) == [ + ['open-port', '--endpoints', 'nonexistent-ep', '8080/tcp'], ] def test_close_port(self, fake_script: FakeScript, unit: ops.Unit): @@ -4302,11 +4362,56 @@ def test_close_port(self, fake_script: FakeScript, unit: ops.Unit): unit.close_port('icmp') assert fake_script.calls(clear=True) == [ - ['close-port', '8080/tcp'], - ['close-port', '4000/udp'], - ['close-port', 'icmp'], + ['close-port', '--endpoints', '*', '8080/tcp'], + ['close-port', '--endpoints', '*', '4000/udp'], + ['close-port', '--endpoints', '*', 'icmp'], + ] + + def test_close_port_range(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('close-port', 'exit 0') + + unit.close_port('tcp', (8080, 8090)) + unit.close_port('UDP', (4000, 5000)) # type: ignore + unit.close_port('tcp', (8080, None)) + + assert fake_script.calls(clear=True) == [ + ['close-port', '--endpoints', '*', '8080-8090/tcp'], + ['close-port', '--endpoints', '*', '4000-5000/udp'], + ['close-port', '--endpoints', '*', '8080/tcp'], ] + def test_close_port_endpoints(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('close-port', 'exit 0') + + unit.close_port('tcp', 8080, endpoints=['ep1', 'ep2']) + unit.close_port('udp', 4000, endpoints=['ep1']) + + assert fake_script.calls(clear=True) == [ + ['close-port', '--endpoints', 'ep1,ep2', '8080/tcp'], + ['close-port', '--endpoints', 'ep1', '4000/udp'], + ] + + def test_close_port_range_and_endpoints(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('close-port', 'exit 0') + + unit.close_port('tcp', (8080, 8090), endpoints=['ep1', 'ep2']) + + assert fake_script.calls(clear=True) == [ + ['close-port', '--endpoints', 'ep1,ep2', '8080-8090/tcp'], + ] + + def test_close_port_range_none_port(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('close-port', 'exit 0') + + with pytest.raises(TypeError): + unit.close_port('tcp', (None, 8090)) # type: ignore + + def test_close_port_empty_endpoints(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('close-port', 'exit 0') + + with pytest.raises(TypeError): + unit.close_port('tcp', 8080, endpoints=[]) + def test_close_port_error(self, fake_script: FakeScript, unit: ops.Unit): fake_script.write('close-port', "echo 'ERROR bad protocol' >&2; exit 1") @@ -4315,11 +4420,26 @@ def test_close_port_error(self, fake_script: FakeScript, unit: ops.Unit): assert str(excinfo.value) == 'ERROR bad protocol\n' assert fake_script.calls(clear=True) == [ - ['close-port', '8080/ftp'], + ['close-port', '--endpoints', '*', '8080/ftp'], + ] + + def test_close_port_nonexistent_endpoint(self, fake_script: FakeScript, unit: ops.Unit): + # Juju exits 0 but prints an error to stdout when the endpoint does not exist. + error_msg = ( + 'cannot open/close ports: close port range: endpoint "nonexistent-ep"' + ' for unit "myapp/0" not found' + ) + fake_script.write('close-port', f"echo '{error_msg}'") + + with pytest.raises(ops.ModelError): + unit.close_port('tcp', 8080, endpoints=['nonexistent-ep']) + + assert fake_script.calls(clear=True) == [ + ['close-port', '--endpoints', 'nonexistent-ep', '8080/tcp'], ] def test_opened_ports(self, fake_script: FakeScript, unit: ops.Unit): - fake_script.write('opened-ports', """echo '["8080/tcp", "icmp"]'""") + fake_script.write('opened-ports', """echo '["8080-8081/tcp (ep1,ep2)", "icmp (*)"]'""") ports_set = unit.opened_ports() assert isinstance(ports_set, set) @@ -4328,24 +4448,30 @@ def test_opened_ports(self, fake_script: FakeScript, unit: ops.Unit): assert isinstance(ports[0], ops.Port) assert ports[0].protocol == 'icmp' assert ports[0].port is None + assert ports[0].to_port is None + assert ports[0].endpoints == '*' assert isinstance(ports[1], ops.Port) assert ports[1].protocol == 'tcp' assert ports[1].port == 8080 + assert ports[1].to_port == 8081 + assert ports[1].endpoints == ('ep1', 'ep2') assert fake_script.calls(clear=True) == [ - ['opened-ports', '--format=json'], + ['opened-ports', '--endpoints', '--format=json'], ] def test_opened_ports_warnings( self, caplog: pytest.LogCaptureFixture, fake_script: FakeScript, unit: ops.Unit ): - fake_script.write('opened-ports', """echo '["8080/tcp", "1234/ftp", "1000-2000/udp"]'""") + fake_script.write( + 'opened-ports', + """echo '["8080/tcp (*)", "1234/ftp (ep1)", "1000-2000/udp (ep1,ep2)"]'""", + ) with caplog.at_level(level='WARNING', logger='ops.model'): ports_set = unit.opened_ports() - assert len(caplog.records) == 2 + assert len(caplog.records) == 1 assert re.search(r'.*protocol.*', caplog.records[0].message) - assert re.search(r'.*range.*', caplog.records[1].message) assert isinstance(ports_set, set) ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) @@ -4353,12 +4479,16 @@ def test_opened_ports_warnings( assert isinstance(ports[0], ops.Port) assert ports[0].protocol == 'tcp' assert ports[0].port == 8080 + assert ports[0].to_port is None + assert ports[0].endpoints == '*' assert isinstance(ports[1], ops.Port) assert ports[1].protocol == 'udp' assert ports[1].port == 1000 + assert ports[1].to_port == 2000 + assert ports[1].endpoints == ('ep1', 'ep2') assert fake_script.calls(clear=True) == [ - ['opened-ports', '--format=json'], + ['opened-ports', '--endpoints', '--format=json'], ] def test_set_ports_all_open(self, fake_script: FakeScript, unit: ops.Unit): @@ -4367,57 +4497,77 @@ def test_set_ports_all_open(self, fake_script: FakeScript, unit: ops.Unit): fake_script.write('opened-ports', 'echo []') unit.set_ports(8000, 8025) calls = fake_script.calls(clear=True) - assert calls.pop(0) == ['opened-ports', '--format=json'] + assert calls.pop(0) == ['opened-ports', '--endpoints', '--format=json'] calls.sort() # We make no guarantee on the order the ports are opened. assert calls == [ - ['open-port', '8000/tcp'], - ['open-port', '8025/tcp'], + ['open-port', '--endpoints', '*', '8000/tcp'], + ['open-port', '--endpoints', '*', '8025/tcp'], ] def test_set_ports_mixed(self, fake_script: FakeScript, unit: ops.Unit): # Two open ports, leave one alone and open another one. fake_script.write('open-port', 'exit 0') fake_script.write('close-port', 'exit 0') - fake_script.write('opened-ports', """echo '["8025/tcp", "8028/tcp"]'""") + fake_script.write('opened-ports', """echo '["8025/tcp (ep1,ep2)", "8028/tcp (*)"]'""") unit.set_ports(ops.Port('udp', 8022), 8028) assert fake_script.calls(clear=True) == [ - ['opened-ports', '--format=json'], - ['close-port', '8025/tcp'], - ['open-port', '8022/udp'], + ['opened-ports', '--endpoints', '--format=json'], + ['close-port', '--endpoints', 'ep1,ep2', '8025/tcp'], + ['open-port', '--endpoints', '*', '8022/udp'], ] def test_set_ports_replace(self, fake_script: FakeScript, unit: ops.Unit): fake_script.write('open-port', 'exit 0') fake_script.write('close-port', 'exit 0') - fake_script.write('opened-ports', """echo '["8025/tcp", "8028/tcp"]'""") + fake_script.write('opened-ports', """echo '["8025/tcp (*)", "8028/tcp (ep)"]'""") unit.set_ports(8001, 8002) calls = fake_script.calls(clear=True) - assert calls.pop(0) == ['opened-ports', '--format=json'] + assert calls.pop(0) == ['opened-ports', '--endpoints', '--format=json'] calls.sort() assert calls == [ - ['close-port', '8025/tcp'], - ['close-port', '8028/tcp'], - ['open-port', '8001/tcp'], - ['open-port', '8002/tcp'], + ['close-port', '--endpoints', '*', '8025/tcp'], + ['close-port', '--endpoints', 'ep', '8028/tcp'], + ['open-port', '--endpoints', '*', '8001/tcp'], + ['open-port', '--endpoints', '*', '8002/tcp'], ] def test_set_ports_close_all(self, fake_script: FakeScript, unit: ops.Unit): fake_script.write('open-port', 'exit 0') fake_script.write('close-port', 'exit 0') - fake_script.write('opened-ports', """echo '["8022/udp"]'""") + fake_script.write('opened-ports', """echo '["8022/udp (ep1,ep2,ep3)"]'""") unit.set_ports() assert fake_script.calls(clear=True) == [ - ['opened-ports', '--format=json'], - ['close-port', '8022/udp'], + ['opened-ports', '--endpoints', '--format=json'], + ['close-port', '--endpoints', 'ep1,ep2,ep3', '8022/udp'], ] def test_set_ports_noop(self, fake_script: FakeScript, unit: ops.Unit): fake_script.write('open-port', 'exit 0') fake_script.write('close-port', 'exit 0') - fake_script.write('opened-ports', """echo '["8000/tcp"]'""") + fake_script.write('opened-ports', """echo '["8000/tcp (*)"]'""") unit.set_ports(ops.Port('tcp', 8000)) assert fake_script.calls(clear=True) == [ - ['opened-ports', '--format=json'], + ['opened-ports', '--endpoints', '--format=json'], + ] + + def test_set_ports_with_tuple(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('open-port', 'exit 0') + fake_script.write('close-port', 'exit 0') + fake_script.write('opened-ports', 'echo []') + unit.set_ports((8000, 8090)) + calls = fake_script.calls(clear=True) + assert calls.pop(0) == ['opened-ports', '--endpoints', '--format=json'] + assert calls == [ + ['open-port', '--endpoints', '*', '8000-8090/tcp'], + ] + + def test_set_ports_noop_with_range(self, fake_script: FakeScript, unit: ops.Unit): + fake_script.write('open-port', 'exit 0') + fake_script.write('close-port', 'exit 0') + fake_script.write('opened-ports', """echo '["8000-8090/tcp (*)"]'""") + unit.set_ports(ops.Port('tcp', 8000, to_port=8090)) + assert fake_script.calls(clear=True) == [ + ['opened-ports', '--endpoints', '--format=json'], ] diff --git a/test/test_testing.py b/test/test_testing.py index 4d32c32cc..3350ba926 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -6298,6 +6298,77 @@ def test_ports(self, request: pytest.FixtureRequest): ports_set = unit.opened_ports() assert ports_set == set() + def test_port_ranges(self, request: pytest.FixtureRequest): + harness = ops.testing.Harness(ops.CharmBase, meta='name: webapp') + request.addfinalizer(harness.cleanup) + unit = harness.model.unit + + unit.open_port('tcp', (8000, 8080)) + unit.open_port('udp', (5000, 5010)) + + ports_set = unit.opened_ports() + assert isinstance(ports_set, set) + ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) + assert len(ports) == 2 + assert isinstance(ports[0], ops.Port) + assert ports[0].protocol == 'tcp' + assert ports[0].port == 8000 + assert ports[0].to_port == 8080 + assert isinstance(ports[1], ops.Port) + assert ports[1].protocol == 'udp' + assert ports[1].port == 5000 + assert ports[1].to_port == 5010 + + unit.close_port('tcp', (8000, 8080)) + unit.close_port('tcp', (8000, 8080)) # closing same range again has no effect + + ports_set = unit.opened_ports() + ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) + assert len(ports) == 1 + assert ports[0].protocol == 'udp' + assert ports[0].port == 5000 + assert ports[0].to_port == 5010 + + unit.close_port('udp', (5000, 5010)) + + ports_set = unit.opened_ports() + assert ports_set == set() + + def test_endpoints(self, request: pytest.FixtureRequest): + harness = ops.testing.Harness(ops.CharmBase, meta='name: webapp') + request.addfinalizer(harness.cleanup) + unit = harness.model.unit + + unit.open_port('tcp', 8080, endpoints=['endpoint-a', 'endpoint-b']) + unit.open_port('udp', 4000, endpoints=['endpoint-c']) + + ports_set = unit.opened_ports() + assert isinstance(ports_set, set) + ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) + assert len(ports) == 2 + assert isinstance(ports[0], ops.Port) + assert ports[0].protocol == 'tcp' + assert ports[0].port == 8080 + assert ports[0].endpoints == ('endpoint-a', 'endpoint-b') + assert isinstance(ports[1], ops.Port) + assert ports[1].protocol == 'udp' + assert ports[1].port == 4000 + assert ports[1].endpoints == ('endpoint-c',) + + unit.close_port('tcp', 8080, endpoints=['endpoint-a', 'endpoint-b']) + + ports_set = unit.opened_ports() + ports = sorted(ports_set, key=lambda p: (p.protocol, p.port)) + assert len(ports) == 1 + assert ports[0].protocol == 'udp' + assert ports[0].port == 4000 + assert ports[0].endpoints == ('endpoint-c',) + + unit.close_port('udp', 4000, endpoints=['endpoint-c']) + + ports_set = unit.opened_ports() + assert ports_set == set() + def test_errors(self, request: pytest.FixtureRequest): harness = ops.testing.Harness(ops.CharmBase, meta='name: webapp') request.addfinalizer(harness.cleanup) diff --git a/testing/src/scenario/mocking.py b/testing/src/scenario/mocking.py index 0cc9fed7b..3ada50385 100644 --- a/testing/src/scenario/mocking.py +++ b/testing/src/scenario/mocking.py @@ -13,7 +13,7 @@ import io import shutil import uuid -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from pathlib import Path from typing import ( TYPE_CHECKING, @@ -61,6 +61,8 @@ SubordinateRelation, _EntityStatus, _port_cls_by_protocol, + _port_str, + _PortMap, _RawPortProtocolLiteral, ) @@ -147,18 +149,32 @@ def __init__( def opened_ports(self) -> set[Port_Ops]: return { - Port_Ops(protocol=port.protocol, port=port.port) for port in self._state.opened_ports + Port_Ops(protocol=port.protocol, port=port.port, to_port=port.to_port) + for port in self._state.opened_ports } def open_port( self, protocol: _RawPortProtocolLiteral, port: int | None = None, + *, + to_port: int | None = None, + endpoints: Sequence[str] = '*', ): - port_ = _port_cls_by_protocol[protocol](port=port) # type: ignore - ports = set(self._state.opened_ports) - if port_ not in ports: - ports.add(port_) + if port is None and to_port is not None: + raise TypeError('to_port cannot be specified if port is not specified') + endpoints = tuple(endpoints) if endpoints != '*' else '*' + port_ = _port_cls_by_protocol[protocol]( + port=port, # type: ignore + to_port=to_port, + endpoints=endpoints, # type: ignore + ) + port_map = _PortMap(self._state.opened_ports) + if (p := port_map.get_first_overlap(port_)) is not None: + e = f'cannot open {_port_str(port_)}: port range conflicts with {_port_str(p)}' + raise ModelError(e) + port_map.open_port(port_) + ports = port_map.get_ports() if ports != self._state.opened_ports: self._state._update_opened_ports(frozenset(ports)) @@ -166,11 +182,25 @@ def close_port( self, protocol: _RawPortProtocolLiteral, port: int | None = None, + *, + to_port: int | None = None, + endpoints: Sequence[str] = '*', ): - port_ = _port_cls_by_protocol[protocol](port=port) # type: ignore - ports = set(self._state.opened_ports) - if port_ in ports: - ports.remove(port_) + if port is None and to_port is not None: + raise TypeError('to_port cannot be specified if port is not specified') + endpoints = tuple(endpoints) if endpoints != '*' else '*' + port_ = _port_cls_by_protocol[protocol]( + port=port, # type: ignore + to_port=to_port, + endpoints=endpoints, # type: ignore + ) + port_map = _PortMap(self._state.opened_ports) + if (p := port_map.get_first_overlap(port_)) is not None: + e = f'cannot open {_port_str(port_)}: port range conflicts with {_port_str(p)}' + raise ModelError(e) + all_endpoints = [e for e, _ in self._charm_spec.get_all_relations()] + port_map.close_port(port_, all_endpoints) + ports = port_map.get_ports() if ports != self._state.opened_ports: self._state._update_opened_ports(frozenset(ports)) diff --git a/testing/src/scenario/state.py b/testing/src/scenario/state.py index a5274ebf6..17defd461 100644 --- a/testing/src/scenario/state.py +++ b/testing/src/scenario/state.py @@ -1394,19 +1394,64 @@ class Port: protocol: _RawPortProtocolLiteral = 'tcp' """The protocol that data transferred over the port will use.""" + to_port: int | None = None + + endpoints: Literal['*'] | tuple[str, Unpack[tuple[str, ...]]] = '*' + def __post_init__(self): if type(self) is Port: raise RuntimeError( 'Port cannot be instantiated directly; please use TCPPort, UDPPort, or ICMPPort', ) + self._validate_ports() def __eq__(self, other: object) -> bool: if isinstance(other, (Port, ops.Port)): - return (self.protocol, self.port) == (other.protocol, other.port) - return False + return ( + self.protocol == other.protocol + and self.port == other.port + and self.to_port == other.to_port + and self.endpoints == other.endpoints + ) + return False # FIXME: should be NotImplemented, but needs testing for compatibility. def _to_ops(self) -> ops.Port: - return ops.Port(port=self.port, protocol=self.protocol) + return ops.Port(protocol=self.protocol, port=self.port, to_port=self.to_port) + + def _validate_ports(self): + if self.port is None and self.to_port is not None: + # Raise TypeError following ops.hookcmds.open/close_port behaviour. + raise TypeError('to_port can only be specified if port is also specified') + for port_attr, port_value in (('port', self.port), ('to_port', self.to_port)): + if port_value is None: + continue + if port_value not in range(1, 65535 + 1): + raise StateValidationError( + f'`{port_attr}` outside bounds [1:65535], got {port_value}', + ) + + def _overlaps(self, other: Port | ops.Port) -> bool: + # Overlapping port ranges are allowed if the protocols are different. + if self.protocol != other.protocol: + return False + # Only an ICMP port has port=None, and since ICMP ports don't have port ranges, + # they can't overlap with each other. + if self.port is None or other.port is None: + return False + # If the ports are identical aside from the endpoints, they aren't considered overlapping. + # It's valid to open/close the same port for different endpoints. + if ( + self.protocol == other.protocol + and self.port == other.port + and self.to_port == other.to_port + ): + return False + # Same protocol, non-identical ports -- Juju will error if the ranges overlap. + # Note that if to_port is None, the range will just include the single port. + # (Port values are validated to be in the range [1:65535] when constructed.) + a = range(self.port, self.to_port or self.port + 1) + b = range(other.port, other.to_port or other.port + 1) + return a.start in b or b.start in a @dataclasses.dataclass(frozen=True) @@ -1420,13 +1465,9 @@ class TCPPort(Port): :meta private: """ + to_port: int | None = None - def __post_init__(self): - super().__post_init__() - if not (1 <= self.port <= 65535): - raise StateValidationError( - f'`port` outside bounds [1:65535], got {self.port}', - ) + endpoints: Literal['*'] | tuple[str, Unpack[tuple[str, ...]]] = '*' @dataclasses.dataclass(frozen=True) @@ -1440,13 +1481,9 @@ class UDPPort(Port): :meta private: """ + to_port: int | None = None - def __post_init__(self): - super().__post_init__() - if not (1 <= self.port <= 65535): - raise StateValidationError( - f'`port` outside bounds [1:65535], got {self.port}', - ) + endpoints: Literal['*'] | tuple[str, Unpack[tuple[str, ...]]] = '*' @dataclasses.dataclass(frozen=True, kw_only=True) @@ -1459,9 +1496,11 @@ class ICMPPort(Port): :meta private: """ + endpoints: Literal['*'] | tuple[str, Unpack[tuple[str, ...]]] = '*' + def __post_init__(self): super().__post_init__() - if self.port is not None: + if (self.port, self.to_port) != (None, None): raise StateValidationError('`port` cannot be set for `ICMPPort`') @@ -1472,6 +1511,68 @@ def __post_init__(self): } +class _PortMap: + def __init__(self, ports: Iterable[Port] | None = None): + self._map = {} if ports is None else self._make_map(ports) + + @staticmethod + def _make_map( + ports: Iterable[Port], + ) -> dict[tuple[_RawPortProtocolLiteral, int | None, int | None], set[str]]: + return {(port.protocol, port.port, port.to_port): set(port.endpoints) for port in ports} + + def open_port(self, port: Port) -> None: + key = (port.protocol, port.port, port.to_port) + self._map.setdefault(key, set()).update(port.endpoints) + + def close_port(self, port: Port, all_endpoints: Sequence[str]) -> None: + key = (port.protocol, port.port, port.to_port) + if (endpoints := self._map.get(key)) is None: + return + if port.endpoints == '*': + del self._map[key] + elif '*' in endpoints: + endpoints.clear() + endpoints.update(all_endpoints) + endpoints.difference_update(port.endpoints) + else: + endpoints.difference_update(port.endpoints) + + def get_ports(self) -> frozenset[Port]: + return frozenset( + _port_cls_by_protocol[protocol]( + protocol=protocol, + port=port, # type: ignore + to_port=to_port, + endpoints='*' if '*' in endpoints else tuple(sorted(endpoints)), # type: ignore + ) + for (protocol, port, to_port), endpoints in self._map.items() + ) + + def get_first_overlap(self, port: Port) -> ops.Port | None: + for (protocol, from_port, to_port), endpoints in self._map.items(): + endpoints = '*' if '*' in endpoints else tuple(sorted(endpoints)) + assert endpoints + this_port = ops.Port( + protocol=protocol, + port=from_port, + to_port=to_port, + endpoints=endpoints, + ) + if port._overlaps(this_port): + return this_port + return None + + +def _port_str(port: Port | ops.Port) -> str: + """Return the Juju string representation of a port (without endpoints).""" + if port.port is None: + return port.protocol + if port.to_port is None: + return f'{port.port}/{port.protocol}' + return f'{port.port}-{port.to_port}/{port.protocol}' + + _next_storage_index_counter = 0 # storage indices start at 0 @@ -1611,12 +1712,30 @@ def __post_init__(self): object.__setattr__(self, name, _EntityStatus.from_ops(val)) else: raise TypeError(f'Invalid status.{name}: {val!r}') + + # ports normalised_ports = [ - Port(protocol=port.protocol, port=port.port) if isinstance(port, ops.Port) else port + _port_cls_by_protocol[port.protocol]( + protocol=port.protocol, + port=port.port, # type: ignore + to_port=port.to_port, + endpoints=port.endpoints, + ) + if isinstance(port, ops.Port) + else port for port in self.opened_ports ] + port_map = _PortMap() + for port in normalised_ports: + if (p := port_map.get_first_overlap(port)) is not None: + e = f'cannot open {_port_str(port)}: port range conflicts with {_port_str(p)}' + raise StateValidationError(e) + port_map.open_port(port) + normalised_ports = port_map.get_ports() if self.opened_ports != normalised_ports: object.__setattr__(self, 'opened_ports', normalised_ports) + + # storage normalised_storage = [ Storage(name=storage.name, index=storage.index) if isinstance(storage, ops.Storage) diff --git a/testing/tests/test_e2e/test_ports.py b/testing/tests/test_e2e/test_ports.py index 9b37f5fb0..efa66c45f 100644 --- a/testing/tests/test_e2e/test_ports.py +++ b/testing/tests/test_e2e/test_ports.py @@ -8,8 +8,10 @@ import pytest from scenario import Context, State -from scenario.state import Port, StateValidationError, TCPPort, UDPPort +from scenario.errors import UncaughtCharmError +from scenario.state import ICMPPort, Port, StateValidationError, TCPPort, UDPPort +import ops from ops import CharmBase, Framework, StartEvent, StopEvent @@ -59,3 +61,177 @@ def test_port_port(klass): klass(port=0) with pytest.raises(StateValidationError): klass(port=65536) + + +# --- Port ranges and endpoints --- + + +class _RangeCharm(CharmBase): + META: Mapping[str, Any] = {'name': 'range-charm'} + + def __init__(self, framework: Framework): + super().__init__(framework) + framework.observe(self.on.start, self._on_start) + framework.observe(self.on.stop, self._on_stop) + + def _on_start(self, _: StartEvent): + self.unit.open_port('tcp', (8000, 8090)) + + def _on_stop(self, _: StopEvent): + self.unit.close_port('tcp', (8000, 8090)) + + +class _ICMPCharm(CharmBase): + META: Mapping[str, Any] = {'name': 'icmp-charm'} + + def __init__(self, framework: Framework): + super().__init__(framework) + framework.observe(self.on.start, self._on_start) + + def _on_start(self, _: StartEvent): + self.unit.open_port('icmp') + + +class _UDPRangeCharm(CharmBase): + META: Mapping[str, Any] = {'name': 'udp-range-charm'} + + def __init__(self, framework: Framework): + super().__init__(framework) + framework.observe(self.on.start, self._on_start) + + def _on_start(self, _: StartEvent): + self.unit.open_port('udp', (5000, 5010)) + + +class _EndpointCharm(CharmBase): + META: Mapping[str, Any] = {'name': 'endpoint-charm'} + + def __init__(self, framework: Framework): + super().__init__(framework) + framework.observe(self.on.start, self._on_start) + + def _on_start(self, _: StartEvent): + self.unit.open_port('tcp', 8080, endpoints=['ep1']) + + +class _OverlapCharm(CharmBase): + META: Mapping[str, Any] = {'name': 'overlap-charm'} + + def __init__(self, framework: Framework): + super().__init__(framework) + framework.observe(self.on.start, self._on_start) + + def _on_start(self, _: StartEvent): + # Overlaps with TCPPort(8000, to_port=8090) in the initial state. + self.unit.open_port('tcp', (8050, 8100)) + + +class _SetPortsCharm(CharmBase): + META: Mapping[str, Any] = {'name': 'set-ports-charm'} + + def __init__(self, framework: Framework): + super().__init__(framework) + framework.observe(self.on.start, self._on_start) + + def _on_start(self, _: StartEvent): + self.unit.set_ports(8000, ops.Port('udp', 5000), (9000, 9010)) + + +class _ReadPortsCharm(CharmBase): + """Charm that asserts the contents of opened_ports() from within the handler.""" + + META: Mapping[str, Any] = {'name': 'read-ports-charm'} + + def __init__(self, framework: Framework): + super().__init__(framework) + framework.observe(self.on.start, self._on_start) + + def _on_start(self, _: StartEvent): + ports = self.unit.opened_ports() + assert len(ports) == 1 + port = next(iter(ports)) + assert port.protocol == 'tcp' + assert port.port == 8000 + assert port.to_port == 8090 + + +def test_open_port_range(): + ctx = Context(_RangeCharm, meta=_RangeCharm.META) + out = ctx.run(ctx.on.start(), State()) + assert len(out.opened_ports) == 1 + port = next(iter(out.opened_ports)) + assert port.protocol == 'tcp' + assert port.port == 8000 + assert port.to_port == 8090 + + +def test_close_port_range(): + ctx = Context(_RangeCharm, meta=_RangeCharm.META) + out = ctx.run(ctx.on.stop(), State(opened_ports={TCPPort(8000, to_port=8090)})) + assert not out.opened_ports + + +def test_open_icmp_port(): + ctx = Context(_ICMPCharm, meta=_ICMPCharm.META) + out = ctx.run(ctx.on.start(), State()) + assert len(out.opened_ports) == 1 + port = next(iter(out.opened_ports)) + assert port.protocol == 'icmp' + assert port.port is None + + +def test_open_udp_range(): + ctx = Context(_UDPRangeCharm, meta=_UDPRangeCharm.META) + out = ctx.run(ctx.on.start(), State()) + assert len(out.opened_ports) == 1 + port = next(iter(out.opened_ports)) + assert port.protocol == 'udp' + assert port.port == 5000 + assert port.to_port == 5010 + + +def test_open_port_with_endpoint(): + ctx = Context(_EndpointCharm, meta=_EndpointCharm.META) + out = ctx.run(ctx.on.start(), State()) + assert len(out.opened_ports) == 1 + port = next(iter(out.opened_ports)) + assert port.protocol == 'tcp' + assert port.port == 8080 + assert port.endpoints == ('ep1',) + + +def test_overlapping_port_raises(): + ctx = Context(_OverlapCharm, meta=_OverlapCharm.META) + with pytest.raises(UncaughtCharmError) as exc_info: + ctx.run(ctx.on.start(), State(opened_ports={TCPPort(8000, to_port=8090)})) + assert isinstance(exc_info.value.__cause__, ops.ModelError) + + +def test_set_ports_via_charm(): + ctx = Context(_SetPortsCharm, meta=_SetPortsCharm.META) + out = ctx.run(ctx.on.start(), State()) + assert TCPPort(8000) in out.opened_ports + assert UDPPort(5000) in out.opened_ports + assert TCPPort(9000, to_port=9010) in out.opened_ports + + +def test_opened_ports_in_charm(): + # State has a TCP range port open; the charm asserts it can be read back + # correctly via unit.opened_ports(). + ctx = Context(_ReadPortsCharm, meta=_ReadPortsCharm.META) + ctx.run(ctx.on.start(), State(opened_ports={TCPPort(8000, to_port=8090)})) + + +def test_to_port_validation(): + with pytest.raises(StateValidationError): + TCPPort(8000, to_port=0) + with pytest.raises(StateValidationError): + TCPPort(8000, to_port=65536) + with pytest.raises(StateValidationError): + UDPPort(8000, to_port=65536) + + +def test_icmp_port_with_to_port(): + # to_port is not permitted for ICMP since ICMP has no port concept. + with pytest.raises(TypeError): + ICMPPort(to_port=80) # type: ignore