Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,28 @@ class VectorsStd:
velocity: Vector2d
```

Default values for nested class fields cannot be set directly, as Python doesn't
allow using mutable default values in dataclasses. To get around this, pass
`frozen=True` to the inner class' `dataclass_struct` decorator. Alternatively,
pass a zero-argument callable that returns an instance of the class to the
`default_factory` keyword argument of
[`dataclasses.field`](https://docs.python.org/3/library/dataclasses.html#dataclasses.field).
For example:

```python
from dataclasses import field

@dcs.dataclass_struct()
class VectorsStd:
direction: Vector2d
velocity: Vector2d = field(default_factory=lambda: Vector2d(0, 0))
```

The return type of the `default_factory` will be validated unless
`validate_defaults=False` is passed to the `dataclass_struct` decorator. Note
that this means the callable passed to `default_factory` will be called once
during class creation.

#### Fixed-length arrays

Fixed-length arrays can be represented by annotating a `list` field with
Expand Down Expand Up @@ -381,6 +403,22 @@ class TwoDimArray:
TwoDimArray(fixed=[[1, 2], [3, 4], [5, 6]])
```

As with [nested structs](#nested-structs), a `default_factory` must be used to
set a default value. For example:

```python
from dataclasses import field
from typing import Annotated

@dcs.dataclass_struct()
class DefaultArray:
x: Annotated[list[int], 3] = field(default_factory=lambda: [1, 2, 3])
```

The returned default value's length and type and values of its items will be
validated unless `validate_defaults=False` is passed to the `dataclass_struct`
decorator.

#### Manual padding

Padding can be manually controlled by annotating a type with `PadBefore` or
Expand Down
50 changes: 42 additions & 8 deletions dataclasses_struct/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,15 @@ def format(self) -> str:
def __repr__(self) -> str:
return f"{super().__repr__()}({self.item_field!r}, {self.n})"

def validate_default(self, val: list) -> None:
n = len(val)
if n != self.n:
msg = f"fixed-length array must have length of {self.n}, got {n}"
raise ValueError(msg)

for i in val:
_validate_field_default(self.item_field, i)


def _validate_modes_match(mode: str, nested_mode: str) -> None:
if mode != nested_mode:
Expand Down Expand Up @@ -382,6 +391,27 @@ def _resolve_field(
return field, type_, pad_before, pad_after


def _get_default_from_dataclasses_field(field: dataclasses.Field) -> Any:
if field.default is not dataclasses.MISSING:
return field.default

if field.default_factory is not dataclasses.MISSING:
return field.default_factory()

return dataclasses.MISSING


def _validate_field_default(field: Field[T], val: Any) -> None:
if not isinstance(val, field.field_type):
msg = (
"invalid type for field: expected "
f"{field.field_type} got {type(val)}"
)
raise TypeError(msg)

field.validate_default(val)


def _validate_and_parse_field(
cls: type,
*,
Expand All @@ -403,18 +433,22 @@ def _validate_and_parse_field(
elif not field.is_std:
raise TypeError(f"field {field} only supported in native size mode")

if validate_defaults and hasattr(cls, name):
init_field = init
if hasattr(cls, name):
val = getattr(cls, name)
if not isinstance(val, field.field_type):
raise TypeError(
"invalid type for field: expected "
f"{field.field_type} got {type(val)}"
)
field.validate_default(val)
if isinstance(val, dataclasses.Field):
if not val.init:
init_field = False

if validate_defaults:
val = _get_default_from_dataclasses_field(val)

if validate_defaults and val is not dataclasses.MISSING:
_validate_field_default(field, val)

return (
_format_str_with_padding(field.format(), pad_before, pad_after),
_FieldInfo(name, field, type_, init),
_FieldInfo(name, field, type_, init_field),
)


Expand Down
115 changes: 115 additions & 0 deletions test/test_dataclasses_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from dataclasses import field
from typing import Any

import pytest
from conftest import (
raises_default_value_invalid_type_error,
raises_default_value_out_of_range_error,
)

import dataclasses_struct as dcs


def test_dataclasses_field_empty() -> None:
@dcs.dataclass_struct()
class T:
x: int = field()

T(12)

with pytest.raises(
TypeError,
match="missing 1 required positional argument",
):
T()


def parametrize_field_kwargs(val: Any) -> pytest.MarkDecorator:
"""
Parametrise dataclasses.field kwargs on 'default' and 'default_kwargs'.
"""
return pytest.mark.parametrize(
"field_kwargs",
({"default": val}, {"default_factory": lambda: val}),
ids=("default", "default_factory"),
)


@parametrize_field_kwargs(100)
def test_dataclasses_field_default(field_kwargs) -> None:
@dcs.dataclass_struct()
class T:
x: int = field(**field_kwargs)

t = T()
assert t.x == 100

t = T(200)
assert t.x == 200


@parametrize_field_kwargs(100.0)
def test_dataclasses_field_default_wrong_type_fails(field_kwargs) -> None:
with raises_default_value_invalid_type_error():

@dcs.dataclass_struct()
class _:
x: int = field(**field_kwargs)


@parametrize_field_kwargs(-100)
def test_dataclasses_field_default_invalid_value_fails(field_kwargs) -> None:
with raises_default_value_out_of_range_error():

@dcs.dataclass_struct()
class _:
x: dcs.UnsignedInt = field(**field_kwargs)


@parametrize_field_kwargs(200)
def test_dataclasses_field_no_init(field_kwargs) -> None:
@dcs.dataclass_struct()
class T:
x: int = field()
y: int = field(init=False, **field_kwargs)

t = T(100)
assert t.x == 100
assert t.y == 200

with pytest.raises(
TypeError,
match=r"takes 2 positional arguments but 3 were given$",
):
T(1, 2)


class DefaultFactoryCallCounter:
def __init__(self):
self.call_count = 0

def __call__(self) -> int:
self.call_count += 1
return 1


def test_default_factory_called_once_during_class_creation() -> None:
factory = DefaultFactoryCallCounter()

@dcs.dataclass_struct()
class _:
x: int = field(default_factory=factory)

assert factory.call_count == 1


def test_default_factory_not_called_during_class_creation_if_validate_defaults_is_false() -> ( # noqa: E501
None
):
factory = DefaultFactoryCallCounter()

@dcs.dataclass_struct(validate_defaults=False)
class _:
x: int = field(default_factory=factory)

assert factory.call_count == 0
49 changes: 49 additions & 0 deletions test/test_pack_unpack.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import itertools
import struct
from typing import Annotated
Expand Down Expand Up @@ -531,3 +532,51 @@ def __init__(self) -> None:
unpacked = T.from_packed(t.pack())
assert unpacked.x == 3
assert unpacked.y == 4


def test_pack_unpack_with_specific_field_no_init() -> None:
@dcs.dataclass_struct()
class T:
x: int = dataclasses.field(default=-100)
y: int = dataclasses.field(default=200, init=False)

t = T(x=100)
assert t.x == 100
assert t.y == 200

t.y = -200
unpacked = T.from_packed(t.pack())
assert unpacked.x == 100
assert unpacked.y == -200


def test_pack_unpack_with_no_init_in_decorator_overriding_fields_init() -> (
None
):
@dcs.dataclass_struct(init=False)
class T:
x: int = dataclasses.field(init=True, default=100)
y: int = dataclasses.field(init=True, default=200)

t = T()
assert t.x == 100
assert t.y == 200

t.x = -100
t.y = -200
unpacked = T.from_packed(t.pack())
assert unpacked.x == -100
assert unpacked.y == -200


def test_pack_unpack_no_init_fields_with_validate_defaults_false() -> None:
@dcs.dataclass_struct(validate_defaults=False)
class T:
x: int = dataclasses.field(init=False, default=1)

t = T()
assert t.x == 1

t.x = -1
unpacked = T.from_packed(t.pack())
assert unpacked.x == -1
Loading