diff --git a/obsidian/parameters/transforms.py b/obsidian/parameters/transforms.py index 22053c6..8348f45 100644 --- a/obsidian/parameters/transforms.py +++ b/obsidian/parameters/transforms.py @@ -86,6 +86,7 @@ def forward(self, self._validate_fit() if self.params["sd"] == 0: # In the edge case where `X` is degenerate, avoid 0 divided by 0 + warnings.warn('Transform constant target values by mean subtraction', UserWarning) return zeros_like(X) else: return (X-self.params['mu'])/self.params['sd'] diff --git a/obsidian/tests/test_parameters.py b/obsidian/tests/test_parameters.py index 0d033ee..bed0279 100644 --- a/obsidian/tests/test_parameters.py +++ b/obsidian/tests/test_parameters.py @@ -255,7 +255,16 @@ def test_target_validation(): with pytest.warns(UserWarning): transform_func = Logit_Scaler(range_response=100) transform_func.forward(test_neg_response, fit=False) - - + + # Transform constant target values + test_constant_response = torch.zeros(10) + 9.0 + with pytest.warns(UserWarning): + Target('Response1', f_transform='Standard').transform_f(test_constant_response, fit=True) + + # Corner case for Logit_Scaler + transform_func = Logit_Scaler(standardize=False) + transform_func.forward(test_response, fit=True) + + if __name__ == '__main__': pytest.main([__file__, '-m', 'fast'])