Improve type annotations in sklearn.metrics._regression #357

Open
wants to merge 3 commits into
base: main


matkozak

Improve type annotations in scikit-learn's regression metrics to more accurately represent the relationship between input parameters and return types. Specifically:

  1. Add overloads that tie the return type to the multioutput parameter: functions return an ndarray when multioutput="raw_values" and a float for the other multioutput options.
  2. Fix inconsistent float annotations: use the built-in float for functions that explicitly convert their result with float(result), and the Float type alias for functions that return NumPy floating-point types (np.float64, etc.).
  3. Correct the d2_tweedie_score return type: the docstring claims "float or ndarray of floats", but no code path returns an ndarray.
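
The overload pattern from item 1 can be sketched as follows. This is a minimal illustration, not the stub code from this PR: `toy_mean_absolute_error` and the `RawValues` alias are hypothetical names, `list[float]` stands in for `ndarray` so the sketch runs without NumPy, and the array-like `multioutput` weights that the real functions also accept are left out.

```python
from typing import Literal, Sequence, overload

# Hypothetical stand-in for numpy.ndarray, so the sketch has no third-party imports.
RawValues = list[float]


@overload
def toy_mean_absolute_error(
    y_true: Sequence[Sequence[float]],
    y_pred: Sequence[Sequence[float]],
    *,
    multioutput: Literal["raw_values"],
) -> RawValues: ...
@overload
def toy_mean_absolute_error(
    y_true: Sequence[Sequence[float]],
    y_pred: Sequence[Sequence[float]],
    *,
    multioutput: Literal["uniform_average"] = ...,
) -> float: ...
def toy_mean_absolute_error(y_true, y_pred, *, multioutput="uniform_average"):
    """Toy metric (hypothetical): mean absolute error per output column."""
    n_samples = len(y_true)
    n_outputs = len(y_true[0])
    per_output = [
        sum(abs(t[j] - p[j]) for t, p in zip(y_true, y_pred)) / n_samples
        for j in range(n_outputs)
    ]
    if multioutput == "raw_values":
        # One value per output: the array-like return path.
        return per_output
    # Scalar path: explicit conversion to a built-in float.
    return float(sum(per_output) / n_outputs)


y_true = [[1.0, 2.0], [3.0, 4.0]]
y_pred = [[2.0, 2.0], [5.0, 2.0]]
raw = toy_mean_absolute_error(y_true, y_pred, multioutput="raw_values")  # [1.5, 1.0]
avg = toy_mean_absolute_error(y_true, y_pred)  # 1.25
```

With overloads like these, a type checker infers `list[float]` for `raw` and `float` for `avg` instead of a union that every caller has to narrow.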

Various sklearn metrics return floats or ndarrays based on the value
of the `multioutput` parameter. This commit adds overloads for the separate
paths.

Various sklearn metrics return either a standard Python float
or a NumPy floating-point scalar type.

E.g.

```
>>> import numpy as np
>>> from sklearn.metrics import mean_absolute_error, median_absolute_error
>>> a = np.array([1, 2, 3])
>>> b = np.array([4, 5, 6])
>>> type(mean_absolute_error(a, b))
<class 'float'>
>>> type(median_absolute_error(a, b))
<class 'numpy.float64'>
```

This commit fixes the type annotations for the following functions:
- `mean_absolute_error`
- `mean_absolute_percentage_error`
- `mean_squared_error`
- `r2_score`
- `mean_tweedie_deviance`
- `d2_pinball_score`
- `d2_absolute_error_score`

The docs for `d2_tweedie_score` say float or ndarray, but there is no ndarray return path.
@matkozak
Author

@microsoft-github-policy-service agree

@debonte
Contributor

debonte commented Mar 25, 2025

@matkozak, thanks for the contribution! Can you please add some unit tests for these overloads similar to https://github.com/microsoft/python-type-stubs/blob/main/tests/sklearn/preprocessing_tests.py?
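
A test for such overloads might look roughly like the sketch below. It assumes the tests are static assertions verified by a type checker through `assert_type`; the layout of the linked test file is not reproduced here, and `toy_score` is a hypothetical stand-in for a metric, with `list[float]` in place of `ndarray`.

```python
from typing import Literal, overload

try:  # assert_type is in the standard library from Python 3.11
    from typing import assert_type
except ImportError:  # runtime-only fallback; real tests would import it from typing_extensions
    def assert_type(val, typ):
        return val


@overload
def toy_score(values: list[float], *, multioutput: Literal["raw_values"]) -> list[float]: ...
@overload
def toy_score(values: list[float], *, multioutput: Literal["uniform_average"] = ...) -> float: ...
def toy_score(values, *, multioutput="uniform_average"):
    """Hypothetical metric: echo the per-output values or return their mean."""
    if multioutput == "raw_values":
        return list(values)
    return float(sum(values) / len(values))


# A type checker reports an error if the inferred type differs from the second
# argument; at runtime assert_type simply returns its first argument.
raw = assert_type(toy_score([1.0, 3.0], multioutput="raw_values"), list[float])
avg = assert_type(toy_score([1.0, 3.0]), float)
```

Each `assert_type` call pins one overload, so a regression in the stubs shows up as a type-checking failure rather than a runtime one.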


2 participants