Skip to content

Commit 0719af4

Browse files
committed
Merge remote-tracking branch 'upstream/main' into compliant-column
2 parents 4505d59 + 11fe33f commit 0719af4

File tree

11 files changed

+579
-7
lines changed

11 files changed

+579
-7
lines changed

docs/api-reference/expr.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
- filter
2424
- clip
2525
- is_between
26+
- is_close
2627
- is_duplicated
2728
- is_finite
2829
- is_first_distinct

docs/api-reference/series.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
- hist
3636
- implementation
3737
- is_between
38+
- is_close
3839
- is_duplicated
3940
- is_empty
4041
- is_finite

narwhals/_compliant/expr.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,22 @@ def is_between(
821821
"is_between", lower_bound=lower_bound, upper_bound=upper_bound, closed=closed
822822
)
823823

824+
def is_close(
825+
self,
826+
other: Self | NumericLiteral,
827+
*,
828+
abs_tol: float,
829+
rel_tol: float,
830+
nans_equal: bool,
831+
) -> Self:
832+
return self._reuse_series(
833+
"is_close",
834+
other=other,
835+
abs_tol=abs_tol,
836+
rel_tol=rel_tol,
837+
nans_equal=nans_equal,
838+
)
839+
824840
@property
825841
def cat(self) -> EagerExprCatNamespace[Self]:
826842
return EagerExprCatNamespace(self)

narwhals/_compliant/series.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
NativeSeriesT,
1818
NativeSeriesT_co,
1919
)
20+
from narwhals._compliant.utils import IsClose
2021
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
2122
from narwhals._typing_compat import TypeVar, assert_never
2223
from narwhals._utils import (
@@ -73,6 +74,7 @@ class HistData(TypedDict, Generic[NativeSeriesT, "_CountsT_co"]):
7374

7475

7576
class CompliantSeries(
77+
IsClose,
7678
NumpyConvertible["_1DArray", "Into1DArray"],
7779
FromIterable,
7880
FromNative[NativeSeriesT],

narwhals/_compliant/utils.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Protocol
4+
5+
if TYPE_CHECKING:
6+
from typing_extensions import Self
7+
8+
from narwhals.typing import NumericLiteral, TemporalLiteral
9+
10+
11+
class IsClose(Protocol):
12+
"""Every member defined here is a dependency of the `is_close` method."""
13+
14+
def __and__(self, other: Any) -> Self: ...
15+
def __or__(self, other: Any) -> Self: ...
16+
def __invert__(self) -> Self: ...
17+
def __sub__(self, other: Any) -> Self: ...
18+
def __mul__(self, other: Any) -> Self: ...
19+
def __eq__(self, other: Self | Any) -> Self: ... # type: ignore[override]
20+
def __gt__(self, other: Any) -> Self: ...
21+
def __le__(self, other: Any) -> Self: ...
22+
def abs(self) -> Self: ...
23+
def is_nan(self) -> Self: ...
24+
def is_finite(self) -> Self: ...
25+
def clip(
26+
self,
27+
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
28+
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
29+
) -> Self: ...
30+
def is_close(
31+
self,
32+
other: Self | NumericLiteral,
33+
*,
34+
abs_tol: float,
35+
rel_tol: float,
36+
nans_equal: bool,
37+
) -> Self:
38+
from decimal import Decimal
39+
40+
other_abs: Self | NumericLiteral
41+
other_is_nan: Self | bool
42+
other_is_inf: Self | bool
43+
other_is_not_inf: Self | bool
44+
45+
if isinstance(other, (float, int, Decimal)):
46+
from math import isinf, isnan
47+
48+
# NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447
49+
other_abs = other.__abs__()
50+
other_is_nan = isnan(other)
51+
other_is_inf = isinf(other)
52+
53+
# Define the other_is_not_inf variable to prevent triggering the following warning:
54+
# > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be
55+
# > removed in Python 3.16.
56+
other_is_not_inf = not other_is_inf
57+
58+
else:
59+
other_abs, other_is_nan = other.abs(), other.is_nan()
60+
other_is_not_inf = other.is_finite() | other_is_nan
61+
other_is_inf = ~other_is_not_inf
62+
63+
rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol
64+
tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None)
65+
66+
self_is_nan = self.is_nan()
67+
self_is_not_inf = self.is_finite() | self_is_nan
68+
69+
# Values are close if abs_diff <= tolerance and neither is infinite
70+
is_close = (
71+
((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf
72+
)
73+
74+
# Handle infinity cases: infinities are close/equal if they have the same sign
75+
self_sign, other_sign = self > 0, other > 0
76+
is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign)
77+
78+
# Handle nan cases:
79+
# * If any value is NaN, then False (via `& ~either_nan`)
80+
# * However, if `nans_equal = True` and if _both_ values are NaN, then True
81+
either_nan = self_is_nan | other_is_nan
82+
result = (is_close | is_same_inf) & ~either_nan
83+
84+
if nans_equal:
85+
both_nan = self_is_nan & other_is_nan
86+
result = result | both_nan
87+
88+
return result

narwhals/_polars/expr.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from narwhals._polars.dataframe import Method, PolarsDataFrame
2828
from narwhals._polars.namespace import PolarsNamespace
2929
from narwhals._utils import Version, _LimitedContext
30-
from narwhals.typing import IntoDType
30+
from narwhals.typing import IntoDType, NumericLiteral
3131

3232

3333
class PolarsExpr:
@@ -232,6 +232,46 @@ def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover
232232

233233
return PolarsNamespace(version=self._version)
234234

235+
def is_close(
236+
self,
237+
other: Self | NumericLiteral,
238+
*,
239+
abs_tol: float,
240+
rel_tol: float,
241+
nans_equal: bool,
242+
) -> Self:
243+
left = self.native
244+
right = other.native if isinstance(other, PolarsExpr) else pl.lit(other)
245+
246+
if self._backend_version < (1, 32, 0):
247+
lower_bound = right.abs()
248+
tolerance = (left.abs().clip(lower_bound) * rel_tol).clip(abs_tol)
249+
250+
# Values are close if abs_diff <= tolerance, and both finite
251+
abs_diff = (left - right).abs()
252+
all_ = pl.all_horizontal
253+
is_close = all_((abs_diff <= tolerance), left.is_finite(), right.is_finite())
254+
255+
# Handle infinity cases: infinities are "close" only if they have the same sign
256+
is_same_inf = all_(
257+
left.is_infinite(), right.is_infinite(), (left.sign() == right.sign())
258+
)
259+
260+
# Handle nan cases:
261+
# * nans_equal = True => if both values are NaN, then True
262+
# * nans_equal = False => if any value is NaN, then False
263+
left_is_nan, right_is_nan = left.is_nan(), right.is_nan()
264+
either_nan = left_is_nan | right_is_nan
265+
result = (is_close | is_same_inf) & either_nan.not_()
266+
267+
if nans_equal:
268+
result = result | (left_is_nan & right_is_nan)
269+
else:
270+
result = left.is_close(
271+
right, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
272+
)
273+
return self._with_native(result)
274+
235275
@property
236276
def dt(self) -> PolarsExprDateTimeNamespace:
237277
return PolarsExprDateTimeNamespace(self)

narwhals/_polars/series.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
IntoDType,
4242
MultiIndexSelector,
4343
NonNestedLiteral,
44+
NumericLiteral,
4445
_1DArray,
4546
)
4647

@@ -90,6 +91,7 @@
9091
"gather_every",
9192
"head",
9293
"is_between",
94+
"is_close",
9395
"is_duplicated",
9496
"is_empty",
9597
"is_finite",
@@ -133,6 +135,11 @@
133135
class PolarsSeries:
134136
_implementation = Implementation.POLARS
135137

138+
_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
139+
True: ["breakpoint", "count"],
140+
False: ["count"],
141+
}
142+
136143
def __init__(self, series: pl.Series, *, version: Version) -> None:
137144
self._native_series: pl.Series = series
138145
self._version = version
@@ -480,10 +487,27 @@ def __contains__(self, other: Any) -> bool:
480487
except Exception as e: # noqa: BLE001
481488
raise catch_polars_exception(e) from None
482489

483-
_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
484-
True: ["breakpoint", "count"],
485-
False: ["count"],
486-
}
490+
def is_close(
491+
self,
492+
other: Self | NumericLiteral,
493+
*,
494+
abs_tol: float,
495+
rel_tol: float,
496+
nans_equal: bool,
497+
) -> PolarsSeries:
498+
if self._backend_version < (1, 32, 0):
499+
name = self.name
500+
ns = self.__narwhals_namespace__()
501+
other_expr = other._to_expr() if isinstance(other, PolarsSeries) else other
502+
expr = ns.col(name).is_close(
503+
other_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
504+
)
505+
return self.to_frame().select(expr).get_column(name)
506+
other_series = other.native if isinstance(other, PolarsSeries) else other
507+
result = self.native.is_close(
508+
other_series, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
509+
)
510+
return self._with_native(result)
487511

488512
def hist_from_bins(
489513
self, bins: list[float], *, include_breakpoint: bool

narwhals/expr.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from narwhals._utils import _validate_rolling_arguments, ensure_type, flatten
1515
from narwhals.dtypes import _validate_dtype
16-
from narwhals.exceptions import InvalidOperationError
16+
from narwhals.exceptions import ComputeError, InvalidOperationError
1717
from narwhals.expr_cat import ExprCatNamespace
1818
from narwhals.expr_dt import ExprDateTimeNamespace
1919
from narwhals.expr_list import ExprListNamespace
@@ -2254,6 +2254,97 @@ def sqrt(self) -> Self:
22542254
"""
22552255
return self._with_elementwise(lambda plx: self._to_compliant_expr(plx).sqrt())
22562256

2257+
def is_close(
2258+
self,
2259+
other: Self | NumericLiteral,
2260+
*,
2261+
abs_tol: float = 0.0,
2262+
rel_tol: float = 1e-09,
2263+
nans_equal: bool = False,
2264+
) -> Self:
2265+
r"""Check if this expression is close, i.e. almost equal, to the other expression.
2266+
2267+
Two values `a` and `b` are considered close if the following condition holds:
2268+
2269+
$$
2270+
|a-b| \le \max \{ \text{rel\_tol} \cdot \max \{ |a|, |b| \}, \text{abs\_tol} \}
2271+
$$
2272+
2273+
Arguments:
2274+
other: Values to compare with.
2275+
abs_tol: Absolute tolerance. This is the maximum allowed absolute difference
2276+
between two values. Must be non-negative.
2277+
rel_tol: Relative tolerance. This is the maximum allowed difference between
2278+
two values, relative to the larger absolute value. Must be in the range
2279+
[0, 1).
2280+
nans_equal: Whether NaN values should be considered equal.
2281+
2282+
Returns:
2283+
Expression of Boolean data type.
2284+
2285+
Notes:
2286+
The implementation of this method is symmetric and mirrors the behavior of
2287+
`math.isclose`. Specifically note that this behavior is different to
2288+
`numpy.isclose`.
2289+
2290+
Examples:
2291+
>>> import duckdb
2292+
>>> import pyarrow as pa
2293+
>>> import narwhals as nw
2294+
>>>
2295+
>>> data = {
2296+
... "x": [1.0, float("inf"), 1.41, None, float("nan")],
2297+
... "y": [1.2, float("inf"), 1.40, None, float("nan")],
2298+
... }
2299+
>>> _table = pa.table(data)
2300+
>>> df_native = duckdb.table("_table")
2301+
>>> df = nw.from_native(df_native)
2302+
>>> df.with_columns(
2303+
... is_close=nw.col("x").is_close(
2304+
... nw.col("y"), abs_tol=0.1, nans_equal=True
2305+
... )
2306+
... )
2307+
┌──────────────────────────────┐
2308+
| Narwhals LazyFrame |
2309+
|------------------------------|
2310+
|┌────────┬────────┬──────────┐|
2311+
|│ x │ y │ is_close │|
2312+
|│ double │ double │ boolean │|
2313+
|├────────┼────────┼──────────┤|
2314+
|│ 1.0 │ 1.2 │ false │|
2315+
|│ inf │ inf │ true │|
2316+
|│ 1.41 │ 1.4 │ true │|
2317+
|│ NULL │ NULL │ NULL │|
2318+
|│ nan │ nan │ true │|
2319+
|└────────┴────────┴──────────┘|
2320+
└──────────────────────────────┘
2321+
"""
2322+
if abs_tol < 0:
2323+
msg = f"`abs_tol` must be non-negative but got {abs_tol}"
2324+
raise ComputeError(msg)
2325+
2326+
if not (0 <= rel_tol < 1):
2327+
msg = f"`rel_tol` must be in the range [0, 1) but got {rel_tol}"
2328+
raise ComputeError(msg)
2329+
2330+
kwargs = {"abs_tol": abs_tol, "rel_tol": rel_tol, "nans_equal": nans_equal}
2331+
return self.__class__(
2332+
lambda plx: apply_n_ary_operation(
2333+
plx,
2334+
lambda *exprs: exprs[0].is_close(exprs[1], **kwargs),
2335+
self,
2336+
other,
2337+
str_as_lit=False,
2338+
),
2339+
combine_metadata(
2340+
self,
2341+
other,
2342+
str_as_lit=False,
2343+
allow_multi_output=False,
2344+
to_single_output=False,
2345+
),
2346+
)
2347+
22572348
@property
22582349
def str(self) -> ExprStringNamespace[Self]:
22592350
return ExprStringNamespace(self)

0 commit comments

Comments
 (0)