Skip to content

Commit 23c78ba

Browse files
committed
Fix computation log types
1 parent d9d3501 commit 23c78ba

File tree

19 files changed

+202
-178
lines changed

19 files changed

+202
-178
lines changed

noxfile.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("lint", "style"))
10-
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
10+
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
1111
def style(session, numpy):
1212
"""Run tests."""
1313

@@ -21,7 +21,7 @@ def style(session, numpy):
2121

2222

2323
@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("lint", "docs"))
24-
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
24+
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
2525
def docs(session, numpy):
2626
"""Run tests."""
2727

@@ -35,7 +35,7 @@ def docs(session, numpy):
3535

3636

3737
@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("lint", "mypy"))
38-
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
38+
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
3939
def mypy(session, numpy):
4040
"""Run tests."""
4141

@@ -49,7 +49,7 @@ def mypy(session, numpy):
4949

5050

5151
@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("lint", "mypy-hxc"))
52-
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
52+
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
5353
def mypy_hxc(session, numpy):
5454
"""Run tests."""
5555

@@ -63,7 +63,7 @@ def mypy_hxc(session, numpy):
6363

6464

6565
@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("test", "test-core"))
66-
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
66+
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
6767
def test_core(session, numpy):
6868
"""Run tests."""
6969

@@ -79,7 +79,7 @@ def test_core(session, numpy):
7979

8080

8181
@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("test", "test-country"))
82-
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
82+
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
8383
def test_country(session, numpy):
8484
"""Run tests."""
8585

@@ -95,7 +95,7 @@ def test_country(session, numpy):
9595

9696

9797
@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("test", "test-extension"))
98-
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
98+
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
9999
def test_extension(session, numpy):
100100
"""Run tests."""
101101

openfisca_core/commons/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
* :func:`.average_rate`
99
* :func:`.concat`
1010
* :func:`.empty_clone`
11+
* :func:`.flatten`
1112
* :func:`.marginal_rate`
1213
* :func:`.stringify_array`
1314
* :func:`.switch`
@@ -53,11 +54,11 @@
5354
# Official Public API
5455

5556
from .formulas import apply_thresholds, concat, switch # noqa: F401
56-
from .misc import empty_clone, stringify_array # noqa: F401
57+
from .misc import empty_clone, flatten, stringify_array # noqa: F401
5758
from .rates import average_rate, marginal_rate # noqa: F401
5859

5960
__all__ = ["apply_thresholds", "concat", "switch"]
60-
__all__ = ["empty_clone", "stringify_array", *__all__]
61+
__all__ = ["empty_clone", "flatten", "stringify_array", *__all__]
6162
__all__ = ["average_rate", "marginal_rate", *__all__]
6263

6364
# Deprecated

openfisca_core/commons/formulas.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Sequence
1+
from typing import Any, Dict, Sequence, Union
22

33
import numpy
44

@@ -29,21 +29,20 @@ def apply_thresholds(
2929
3030
Examples:
3131
>>> input = numpy.array([4, 5, 6, 7, 8])
32-
>>> thresholds =
33-
[5, 7]
32+
>>> thresholds = [5, 7]
3433
>>> choices = [10, 15, 20]
3534
>>> apply_thresholds(input, thresholds, choices)
3635
array([10, 10, 15, 15, 20])
3736
3837
"""
3938

40-
condlist: Sequence[numpy.bool_]
39+
condlist: Sequence[Union[bool, numpy.bool_]]
4140
condlist = [input <= threshold for threshold in thresholds]
4241

4342
if len(condlist) == len(choices) - 1:
4443
# If a choice is provided for input > highest threshold, last condition
4544
# must be true to return it.
46-
condlist += numpy.array([True])
45+
condlist += [True]
4746

4847
assert len(condlist) == len(choices), \
4948
" ".join([
@@ -119,7 +118,7 @@ def switch(
119118

120119
condlist = [
121120
conditions == condition
122-
for condition in tuple(value_by_condition.keys())
121+
for condition in value_by_condition.keys()
123122
]
124123

125124
return numpy.select(condlist, tuple(value_by_condition.values()))

openfisca_core/commons/misc.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
from typing import TypeVar
1+
from typing import Any, Iterator, Optional, Sequence, TypeVar
22

3-
import numpy
3+
import itertools
4+
5+
from openfisca_core import types
46

57
T = TypeVar("T")
68

79

810
def empty_clone(original: T) -> T:
9-
"""Creates an empty instance of the same class of the original object.
11+
"""Create an empty instance of the same class of the original object.
1012
1113
Args:
1214
original: An object to clone.
@@ -43,8 +45,31 @@ def empty_clone(original: T) -> T:
4345
return new
4446

4547

46-
def stringify_array(array: numpy.ndarray) -> str:
47-
"""Generates a clean string representation of a numpy array.
48+
def flatten(seqs: Sequence[Sequence[T]]) -> Iterator[T]:
49+
"""Flatten a sequence of sequences.
50+
51+
Args:
52+
seqs: Any sequence of sequences.
53+
54+
Returns:
55+
An iterator with the values.
56+
57+
Examples:
58+
>>> list(flatten([(1, 2), (3, 4)]))
59+
[1, 2, 3, 4]
60+
61+
>>> list(flatten(["ab", "cd"]))
62+
['a', 'b', 'c', 'd']
63+
64+
.. versionadded:: 36.0.0
65+
66+
"""
67+
68+
return itertools.chain.from_iterable(seqs)
69+
70+
71+
def stringify_array(array: Optional[types.Array[Any]]) -> str:
72+
"""Generate a clean string representation of a numpy array.
4873
4974
Args:
5075
array: An array.

openfisca_core/indexed_enums/enum_array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _forbidden_operation(self, other: Any) -> NoReturn:
6363
__and__ = _forbidden_operation
6464
__or__ = _forbidden_operation
6565

66-
def decode(self) -> numpy.object_:
66+
def decode(self) -> numpy.ndarray:
6767
"""
6868
Return the array of enum items corresponding to self.
6969
@@ -82,7 +82,7 @@ def decode(self) -> numpy.object_:
8282
list(self.possible_values),
8383
)
8484

85-
def decode_to_str(self) -> numpy.str_:
85+
def decode_to_str(self) -> numpy.ndarray:
8686
"""
8787
Return the array of string identifiers corresponding to self.
8888

openfisca_core/populations/population.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from openfisca_core import periods, projectors
1111
from openfisca_core.holders import Holder, MemoryUsage
1212
from openfisca_core.projectors import Projector
13-
from openfisca_core.types import Array, Entity, Period, Role, Simulation
13+
from openfisca_core.types import Entity, Period, Role, Simulation
1414

1515
from . import config
1616

@@ -21,14 +21,14 @@ class Population:
2121
entity: Entity
2222
_holders: Dict[str, Holder]
2323
count: int
24-
ids: Array[str]
24+
ids: numpy.ndarray
2525

2626
def __init__(self, entity: Entity) -> None:
2727
self.simulation = None
2828
self.entity = entity
2929
self._holders = {}
3030
self.count = 0
31-
self.ids = []
31+
self.ids = numpy.array([])
3232

3333
def clone(self, simulation: Simulation) -> Population:
3434
result = Population(self.entity)
@@ -38,14 +38,14 @@ def clone(self, simulation: Simulation) -> Population:
3838
result.ids = self.ids
3939
return result
4040

41-
def empty_array(self) -> Array[float]:
41+
def empty_array(self) -> numpy.ndarray:
4242
return numpy.zeros(self.count)
4343

4444
def filled_array(
4545
self,
4646
value: Union[float, bool],
4747
dtype: Optional[numpy.dtype] = None,
48-
) -> Union[Array[float], Array[bool]]:
48+
) -> numpy.ndarray:
4949
return numpy.full(self.count, value, dtype)
5050

5151
def __getattr__(self, attribute: str) -> Projector:
@@ -64,7 +64,7 @@ def get_index(self, id: str) -> int:
6464

6565
def check_array_compatible_with_entity(
6666
self,
67-
array: Array[float],
67+
array: numpy.ndarray,
6868
) -> None:
6969
if self.count == array.size:
7070
return None
@@ -95,7 +95,7 @@ def __call__(
9595
variable_name: str,
9696
period: Optional[Union[int, str, Period]] = None,
9797
options: Optional[Sequence[str]] = None,
98-
) -> Optional[Array[float]]:
98+
) -> Optional[Sequence[float]]:
9999
"""
100100
Calculate the variable ``variable_name`` for the entity and the period ``period``, using the variable formula if it exists.
101101
@@ -169,7 +169,7 @@ def get_memory_usage(
169169
})
170170

171171
@projectors.projectable
172-
def has_role(self, role: Role) -> Optional[Array[bool]]:
172+
def has_role(self, role: Role) -> Optional[Sequence[bool]]:
173173
"""
174174
Check if a person has a given role within its `GroupEntity`
175175
@@ -195,10 +195,10 @@ def has_role(self, role: Role) -> Optional[Array[bool]]:
195195
@projectors.projectable
196196
def value_from_partner(
197197
self,
198-
array: Array[float],
198+
array: numpy.ndarray,
199199
entity: Projector,
200200
role: Role,
201-
) -> Optional[Array[float]]:
201+
) -> Optional[numpy.ndarray]:
202202
self.check_array_compatible_with_entity(array)
203203
self.entity.check_role_validity(role)
204204

@@ -218,9 +218,9 @@ def value_from_partner(
218218
def get_rank(
219219
self,
220220
entity: Population,
221-
criteria: Array[float],
221+
criteria: Sequence[float],
222222
condition: bool = True,
223-
) -> Array[int]:
223+
) -> numpy.ndarray:
224224
"""
225225
Get the rank of a person within an entity according to a criteria.
226226
The person with rank 0 has the minimum value of criteria.

openfisca_core/tracers/computation_log.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,31 @@
11
from __future__ import annotations
22

3-
import typing
4-
from typing import List, Optional, Union
3+
from typing import Any, Optional, Sequence
54

6-
import numpy
7-
8-
from .. import tracers
9-
from openfisca_core.indexed_enums import EnumArray
5+
import sys
106

11-
if typing.TYPE_CHECKING:
12-
from numpy.typing import ArrayLike
7+
import numpy
138

14-
Array = Union[EnumArray, ArrayLike]
9+
from openfisca_core import commons, types
1510

1611

1712
class ComputationLog:
13+
_full_tracer: types.FullTracer
1814

19-
_full_tracer: tracers.FullTracer
20-
21-
def __init__(self, full_tracer: tracers.FullTracer) -> None:
15+
def __init__(self, full_tracer: types.FullTracer) -> None:
2216
self._full_tracer = full_tracer
2317

24-
def display(
25-
self,
26-
value: Optional[Array],
27-
) -> str:
28-
if isinstance(value, EnumArray):
18+
def display(self, value: types.Array[Any]) -> str:
19+
if isinstance(value, types.EnumArray):
2920
value = value.decode_to_str()
3021

31-
return numpy.array2string(value, max_line_width = None)
22+
return numpy.array2string(value, max_line_width = sys.maxsize)
3223

3324
def lines(
3425
self,
3526
aggregate: bool = False,
3627
max_depth: Optional[int] = None,
37-
) -> List[str]:
28+
) -> Sequence[str]:
3829
depth = 1
3930

4031
lines_by_tree = [
@@ -43,7 +34,7 @@ def lines(
4334
in self._full_tracer.trees
4435
]
4536

46-
return self._flatten(lines_by_tree)
37+
return tuple(commons.flatten(lines_by_tree))
4738

4839
def print_log(self, aggregate = False, max_depth = None) -> None:
4940
"""
@@ -67,11 +58,14 @@ def print_log(self, aggregate = False, max_depth = None) -> None:
6758

6859
def _get_node_log(
6960
self,
70-
node: tracers.TraceNode,
61+
node: types.TraceNode,
7162
depth: int,
7263
aggregate: bool,
7364
max_depth: Optional[int],
74-
) -> List[str]:
65+
) -> Sequence[str]:
66+
67+
node_log: Sequence[str]
68+
children_log: Sequence[Sequence[str]]
7569

7670
if max_depth is not None and depth > max_depth:
7771
return []
@@ -84,12 +78,12 @@ def _get_node_log(
8478
in node.children
8579
]
8680

87-
return node_log + self._flatten(children_logs)
81+
return [*node_log, *commons.flatten(children_logs)]
8882

8983
def _print_line(
9084
self,
9185
depth: int,
92-
node: tracers.TraceNode,
86+
node: types.TraceNode,
9387
aggregate: bool,
9488
max_depth: Optional[int],
9589
) -> str:
@@ -114,9 +108,3 @@ def _print_line(
114108
formatted_value = self.display(value)
115109

116110
return f"{indent}{node.name}<{node.period}> >> {formatted_value}"
117-
118-
def _flatten(
119-
self,
120-
lists: List[List[str]],
121-
) -> List[str]:
122-
return [item for list_ in lists for item in list_]

0 commit comments

Comments
 (0)