Skip to content

Commit

Permalink
Merge pull request #172 from flatironinstitute/development
Browse files Browse the repository at this point in the history
remove warn invalid entry
  • Loading branch information
BalzaniEdoardo authored Jun 18, 2024
2 parents 05e7202 + 4355776 commit 1c36083
Show file tree
Hide file tree
Showing 5 changed files with 0 additions and 277 deletions.
3 changes: 0 additions & 3 deletions src/nemos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,6 @@ def _validate(
# validate input and params consistency
init_params = self._check_params(init_params)

# validate input
validation.warn_invalid_entry(X, y)

# validate input and params consistency
self._check_input_and_params_consistency(init_params, X=X, y=y)

Expand Down
3 changes: 0 additions & 3 deletions src/nemos/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,9 +760,6 @@ def simulate(
# validate input and params consistency
self._check_input_and_params_consistency(params, X=feedforward_input)

# warn if nans in the input
validation.warn_invalid_entry(feedforward_input)

predicted_rate = self._predict(params, feedforward_input)
return (
self._observation_model.sample_generator(
Expand Down
28 changes: 0 additions & 28 deletions src/nemos/validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Collection of methods utilities."""

import warnings
from typing import Any, Optional, Union

import jax
Expand All @@ -11,33 +10,6 @@
from .tree_utils import get_valid_multitree, pytree_map_and_reduce


def warn_invalid_entry(*pytree: Any):
"""
Warns if any entry in the provided pytrees contains NaN or Infinite (Inf) values.
Parameters
----------
*pytree :
Variable number of pytrees to check for invalid entries. A pytree is a nested structure of lists, tuples,
dictionaries, or other containers, with leaves that are arrays.
"""
any_infs = pytree_map_and_reduce(
jnp.any, any, jax.tree_util.tree_map(jnp.isinf, pytree)
)
any_nans = pytree_map_and_reduce(
jnp.any, any, jax.tree_util.tree_map(jnp.isnan, pytree)
)
if any_infs and any_nans:
warnings.warn(
message="The provided trees contain Infs and Nans!", category=UserWarning
)
elif any_infs:
warnings.warn(message="The provided trees contain Infs!", category=UserWarning)
elif any_nans:
warnings.warn(message="The provided trees contain Nans!", category=UserWarning)


def error_invalid_entry(*pytree: Any):
"""
Raise an error if any entry in the provided pytrees contains NaN or Infinite (Inf) values.
Expand Down
218 changes: 0 additions & 218 deletions tests/test_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,51 +142,6 @@ def test_fit_param_length(
with expectation:
model.fit(X, y, init_params=init_params)

@pytest.mark.parametrize(
"add_entry, add_to, expectation",
[
(0, "X", does_not_raise()),
(
np.nan,
"X",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(
np.inf,
"X",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(0, "y", does_not_raise()),
(
np.nan,
"y",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(
np.inf,
"y",
pytest.warns(UserWarning, match="The provided trees contain"),
),
],
)
def test_fit_param_values(
self, add_entry, add_to, expectation, poissonGLM_model_instantiation
):
"""
Test the `fit` method with altered X or y values. Ensure the method raises exceptions for NaN or Inf values.
"""
X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
if add_to == "X":
# get an index to be edited
idx = np.unravel_index(np.random.choice(X.size), X.shape)
X[idx] = add_entry
elif add_to == "y":
idx = np.unravel_index(np.random.choice(y.size), y.shape)
y = np.asarray(y, dtype=np.float32)
y[idx] = add_entry
with expectation:
model.fit(X, y, init_params=true_params)

@pytest.mark.parametrize(
"dim_weights, expectation",
[
Expand Down Expand Up @@ -871,51 +826,6 @@ def test_initialize_solver_param_length(
with expectation:
model.initialize_solver(X, y, init_params=init_params)

@pytest.mark.parametrize(
"add_entry, add_to, expectation",
[
(0, "X", does_not_raise()),
(
np.nan,
"X",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(
np.inf,
"X",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(0, "y", does_not_raise()),
(
np.nan,
"y",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(
np.inf,
"y",
pytest.warns(UserWarning, match="The provided trees contain"),
),
],
)
def test_initialize_solver_param_values(
self, add_entry, add_to, expectation, poissonGLM_model_instantiation
):
"""
Test the `initialize_solver` method with altered X or y values. Ensure the method raises exceptions for NaN or Inf values.
"""
X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
if add_to == "X":
# get an index to be edited
idx = np.unravel_index(np.random.choice(X.size), X.shape)
X[idx] = add_entry
elif add_to == "y":
idx = np.unravel_index(np.random.choice(y.size), y.shape)
y = np.asarray(y, dtype=np.float32)
y[idx] = add_entry
with expectation:
model.initialize_solver(X, y, init_params=true_params)

@pytest.mark.parametrize(
"dim_weights, expectation",
[
Expand Down Expand Up @@ -1404,24 +1314,6 @@ def test_simulate_feedforward_glm(self, poissonGLM_model_instantiation):
# check the time point number is that expected (same as the input)
assert ysim.shape[0] == X.shape[0]

@pytest.mark.parametrize(
"insert, expectation",
[
(0, does_not_raise()),
(np.nan, pytest.warns(UserWarning, match=r"The provided trees contain")),
(np.inf, pytest.warns(UserWarning, match=r"The provided trees contain")),
],
)
def test_simulate_invalid_feedforward(
self, insert, expectation, poissonGLM_model_instantiation
):
X, y, model, params, rate = poissonGLM_model_instantiation
model.coef_ = params[0]
model.intercept_ = params[1]
X[0] = insert
with expectation:
model.simulate(jax.random.key(123), X)

@pytest.mark.parametrize("inv_link", [jnp.exp, lambda x: 1 / x])
def test_simulate_gamma_glm(self, inv_link, gammaGLM_model_instantiation):
X, y, model, true_params, firing_rate = gammaGLM_model_instantiation
Expand Down Expand Up @@ -1589,52 +1481,6 @@ def test_estimate_dof_resid(
num = model.estimate_resid_degrees_of_freedom(X, n_samples=n_samples)
assert int(num) == num


@pytest.mark.parametrize(
"add_entry, add_to, expectation",
[
(0, "X", does_not_raise()),
(
np.nan,
"X",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(
np.inf,
"X",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(0, "y", does_not_raise()),
(
np.nan,
"y",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(
np.inf,
"y",
pytest.warns(UserWarning, match="The provided trees contain"),
),
],
)
def test_fit_param_values(
self, add_entry, add_to, expectation, poisson_population_GLM_model
):
"""
Test the `fit` method with altered X or y values. Ensure the method raises exceptions for NaN or Inf values.
"""
X, y, model, true_params, firing_rate = poisson_population_GLM_model
if add_to == "X":
# get an index to be edited
idx = np.unravel_index(np.random.choice(X.size), X.shape)
X[idx] = add_entry
elif add_to == "y":
idx = np.unravel_index(np.random.choice(y.size), y.shape)
y = np.asarray(y, dtype=np.float32)
y[idx] = add_entry
with expectation:
model.fit(X, y, init_params=true_params)

@pytest.mark.parametrize(
"dim_weights, expectation",
[
Expand Down Expand Up @@ -2019,51 +1865,6 @@ def test_initialize_solver_param_length(
with expectation:
model.initialize_solver(X, y, init_params=init_params)

@pytest.mark.parametrize(
"add_entry, add_to, expectation",
[
(0, "X", does_not_raise()),
(
np.nan,
"X",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(
np.inf,
"X",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(0, "y", does_not_raise()),
(
np.nan,
"y",
pytest.warns(UserWarning, match="The provided trees contain"),
),
(
np.inf,
"y",
pytest.warns(UserWarning, match="The provided trees contain"),
),
],
)
def test_initialize_solver_param_values(
self, add_entry, add_to, expectation, poisson_population_GLM_model
):
"""
Test the `initialize_solver` method with altered X or y values. Ensure the method raises exceptions for NaN or Inf values.
"""
X, y, model, true_params, firing_rate = poisson_population_GLM_model
if add_to == "X":
# get an index to be edited
idx = np.unravel_index(np.random.choice(X.size), X.shape)
X[idx] = add_entry
elif add_to == "y":
idx = np.unravel_index(np.random.choice(y.size), y.shape)
y = np.asarray(y, dtype=np.float32)
y[idx] = add_entry
with expectation:
model.initialize_solver(X, y, init_params=true_params)

@pytest.mark.parametrize(
"dim_weights, expectation",
[
Expand Down Expand Up @@ -2828,25 +2629,6 @@ def test_simulate_feedforward_glm(self, poissonGLM_model_instantiation):
# check the time point number is that expected (same as the input)
assert ysim.shape[0] == X.shape[0]

@pytest.mark.parametrize(
"insert, expectation",
[
(0, does_not_raise()),
(np.nan, pytest.warns(UserWarning, match=r"The provided trees contain")),
(np.inf, pytest.warns(UserWarning, match=r"The provided trees contain")),
],
)
def test_simulate_invalid_feedforward(
self, insert, expectation, poisson_population_GLM_model
):
X, y, model, params, rate = poisson_population_GLM_model
model.coef_ = params[0]
model.intercept_ = params[1]
model._initialize_feature_mask(X, y)
X[0] = insert
with expectation:
model.simulate(jax.random.key(123), X)

@pytest.mark.parametrize("inv_link", [jnp.exp, lambda x: 1 / x])
def test_simulate_gamma_glm(self, inv_link, gamma_population_GLM_model):
X, y, model, true_params, firing_rate = gamma_population_GLM_model
Expand Down
25 changes: 0 additions & 25 deletions tests/test_vallidation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,6 @@ def test_error_invalid_entry(tree, expectation):
validation.error_invalid_entry(valid_data, tree)


@pytest.mark.parametrize(
"tree, expectation",
[
(jnp.array([[1], [2], [3]]), does_not_raise()),
(
jnp.array([[1], [2], [jnp.nan]]),
pytest.warns(UserWarning, match="The provided trees contain Nans"),
),
(
jnp.array([[1], [jnp.inf], [3]]),
pytest.warns(UserWarning, match="The provided trees contain Infs"),
),
(
jnp.array([[1], [jnp.inf], [jnp.nan]]),
pytest.warns(UserWarning, match="The provided trees contain Infs and Nans"),
),
],
)
def test_warn_invalid_entry(tree, expectation):
"""Test validation of trees generates the correct exceptions."""
valid_data = jnp.array([[1], [2], [3]])
with expectation:
validation.warn_invalid_entry(valid_data, tree)


@pytest.mark.parametrize(
"fill_val, expectation",
[
Expand Down

0 comments on commit 1c36083

Please sign in to comment.