Validate wext >= 0, enforce sorted xs, and enable JAX errors by default #1068
Conversation
Fixes the leftover duplication from the earlier iteration and addresses local-run usability.
Thanks for opening your issue and the PR! I have left comments; please let me know if more detail is required on anything.
-_ERRORS_ENABLED: bool = env_bool('TORAX_ERRORS_ENABLED', False)
+# If True, `error_if` functions will raise errors. Otherwise they are pass-throughs.
+# Default to True so that by default bad conditions actually error out in tests.
+_ERRORS_ENABLED: bool = env_bool('TORAX_ERRORS_ENABLED', True)
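For context, a rough sketch of what an env_bool helper like the one above typically does; this is an assumption about its behaviour, not TORAX's actual implementation:

import os

def env_bool(name: str, default: bool) -> bool:
  # Read a boolean-ish environment variable, falling back to the given default.
  value = os.environ.get(name)
  if value is None:
    return default
  return value.strip().lower() in ('1', 'true', 't', 'yes', 'y')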
We should keep the default of False here for two reasons:
- host_callbacks break the persistent cache (as mentioned in the comment), and
- host_callbacks will also slow down all of our simulations, so we default to keeping them turned off.
I agree, nevertheless, that it's annoying to have to specify an env var when running tests; a nicer solution would be to have pytest set the env var in the pytest configuration.
Yeah, I understand the reason for keeping it False.
I am thinking that we could also add a conftest.py, which would just set it to True before running the tests. Or would modifying pytest.ini be better? (That one didn't work on my local setup.)
Yes exactly :) We have a conftest.py already so this can be added there.
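A minimal sketch of that conftest.py addition, assuming the environment variable is read when torax is first imported; the setdefault call is a suggestion, not existing project code:

import os

# Make sure TORAX_ERRORS_ENABLED is set before any test imports torax, so the
# env_bool default in jax_utils picks it up.
os.environ.setdefault('TORAX_ERRORS_ENABLED', 'True')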
 jax_utils.error_if_negative(self.wext, 'wext')
+# Then enforce at Python runtime that wext ≥ 0. If we're under a JAX tracer,
+# float(...) will fail with ConcretizationTypeError, so skip the concrete check.
+try:
We should stick to using equinox for the runtime error checking here, as it is JIT compatible. Which test is this addressing?
Totally agree, using Equinox's built-in assertion will keep us JIT-safe and reduce our dependence on host callbacks.
The extra guard in GenericCurrentSource.sanity_check was only added so that the test in torax/sources/tests/generic_current_source_test.py immediately fails on a negative wext in both eager and JIT contexts.
I'll swap out the manual float(…) + RuntimeError for an Equinox check. Would you prefer:
equinox.assert_(self.wext >= 0.0, lambda: f"wext cannot be negative (got {self.wext})")
or Equinox's error_if variant, e.g.:
equinox.error_if(self.wext < 0.0, lambda: f"wext cannot be negative (got {self.wext})")
Let me know which fits our conventions, and I'll update the PR accordingly.
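For reference, a minimal sketch of how an equinox.error_if guard could be wired up here, assuming error_if's (value, predicate, message) calling convention; the wrapper function name is illustrative and not TORAX's actual jax_utils code:

import equinox as eqx
import jax.numpy as jnp

def check_wext_non_negative(wext: jnp.ndarray) -> jnp.ndarray:
  # error_if returns its first argument unchanged, and raises at runtime
  # (including under jit) when the predicate is True and runtime error
  # checking is enabled.
  return eqx.error_if(wext, wext < 0.0, 'wext cannot be negative')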
Isn't this already covered by the existing equinox check?
    FakeTransportConfig
)
model_config.ToraxConfig.model_rebuild(force=True)
# Register the fake transport config exactly once.
I assume what is happening here is that our CI runs these tests on multiple processes, and the clashing tests happen (by chance) to land on different processes so they never actually collide. Thanks for finding and suggesting a fix!
Each test should still be setting up its own independent FakeTransportConfig, however. Instead, we can make sure in the tearDown of each test that the pydantic schema is restored to how it was at the start of the test, so that when a new test runs there will be no duplication and each test registers the correct config. Let me know if that makes sense.
Understood. I will add a tearDown that restores the original transport annotation and calls model_rebuild in the next commit. Please let me know if this looks good.
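A minimal sketch of that setUp/tearDown pairing, assuming a unittest-style test case and the same names as the diff above; the test class name is a placeholder and the actual registration of FakeTransportConfig is elided:

import unittest

from torax.torax_pydantic import model_config


class FakeTransportConfigTest(unittest.TestCase):  # hypothetical test class

  def setUp(self):
    super().setUp()
    # Remember the transport schema exactly as it was before this test
    # registers FakeTransportConfig.
    field = model_config.ToraxConfig.model_fields['transport']
    self._saved_annotation = field.annotation

  def tearDown(self):
    # Restore the original annotation and rebuild the pydantic model so the
    # next test starts from a clean schema with no duplicate registration.
    field = model_config.ToraxConfig.model_fields['transport']
    field.annotation = self._saved_annotation
    model_config.ToraxConfig.model_rebuild(force=True)
    super().tearDown()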
# flakiness (in initial testing of this rule it passed 100 / 100 runs),
# so be suspicious if it becomes highly flaky without a good reason.
# On CI we require a non-trivial speedup; locally, compilation
# overhead may be too small to outperform the simulation time.
We haven't observed this test being flaky on local development so far and do still expect some sort of speedup when using the persistent cache. It could be that the threshold needs adjusting (or perhaps made into a relative speedup threshold) but this would be a useful test to keep running. What are the timings you are seeing for the first and second simulation?
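A rough sketch of the relative speedup threshold mentioned above; the helper name, argument names, and the default factor are assumptions for illustration, not the test's current code:

def assert_cache_speedup(first_time: float, second_time: float,
                         min_speedup: float = 1.5) -> None:
  # Use a relative threshold rather than an absolute time difference, so the
  # check is less sensitive to how fast the local machine or CI runner is.
  speedup = first_time / second_time
  assert speedup >= min_speedup, (
      f'Expected >= {min_speedup}x speedup from the persistent cache, '
      f'got {speedup:.2f}x')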
@@ -35,6 +35,8 @@
 from torax.plotting import plotruns_lib
 from torax.torax_pydantic import model_config

+# Absorb pytest’s “--rootdir” flag so absl doesn’t fatally bail under pytest.
+flags.DEFINE_string('rootdir', None, 'Ignored pytest rootdir flag.')
We should avoid defining logic in modules that is only needed for tests (unless we really need to!). I'm not entirely aware of the exact pathway that is causing this error; does the flag parsing that was previously done in the change e41fd1b fix your problem?
If so, that would be a nicer way to fix this, as it keeps the fix in test logic.
@Nush395 thanks for flagging this. I did pull in the two conftest.py fixtures from e41fd1b (one under torax/tests and one under torax/sources/tests), but even with those in place, running pytest -q still hits
FATAL Flags parsing error: Unknown command line flag 'q'
because pytest's -q isn't being stripped early enough. Since we really only need that logic for the CLI entrypoint, I think the cleanest approach is to wrap our single parse_flags_with_absl() call in a try/except UnparsedFlagAccessError; that way:
- production imports (and tests that call app.run(main)) won't blow up on -q,
- we avoid adding any more test-only imports or session fixtures into the module,
- and we still get full Abseil parsing when running python -m torax.run_simulation_main.
Does that sound reasonable, or would you prefer a test-side workaround instead?
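A minimal sketch of the guard proposed here; parse_flags_with_absl is assumed to be the existing helper referenced above, and the exception type comes from absl.flags:

from absl import flags

try:
  # parse_flags_with_absl() is the module's existing helper (assumed from the
  # discussion above); it performs full Abseil flag parsing for the CLI
  # entrypoint, python -m torax.run_simulation_main.
  parse_flags_with_absl()
except flags.UnparsedFlagAccessError:
  # Under pytest, unknown flags such as -q or --rootdir are still on sys.argv,
  # so swallow the error instead of aborting module import.
  pass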
Hey @AlankritVerma01! Gentle ping to see if you would like to proceed with this PR.
Working on it.
…nnotation and restoring it in tearDown
# Register FakeTransportConfig exactly once
field = model_config.ToraxConfig.model_fields['transport']
ann = field.annotation
try:
Why is this try/except needed here?
What

- In DynamicRuntimeParams.sanity_check, raise a Python RuntimeError if wext < 0 (tests were expecting an immediate error, not just a JAX host_callback).
- In StepInterpolatedParam.__init__, add a NumPy-level check for sorted xs so unsorted inputs also throw a RuntimeError.
- Switch the default of the _ERRORS_ENABLED flag in jax_utils to True so that JAX-side error_if guards are active by default.

Why

- The JAX-side guards (error_if) don't raise at construction time in pure Python contexts, so tests like test_wext_in_dynamic_runtime_params_cannot_be_negative and test_interpolated_param_need_xs_to_be_sorted1 were still passing invalid inputs.
- Uses of enable_errors expect that errors are on by default, so we update the env var default while preserving the enable_errors(False) override.

All existing tests now pass with these minimal changes.

Fixes #1067
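For illustration, a minimal sketch of the kind of NumPy-level sortedness check described in the What section above; the helper name and message are assumptions, not the actual TORAX code:

import numpy as np

def _check_xs_sorted(xs: np.ndarray) -> None:
  # Concrete NumPy comparison, so this raises eagerly at construction time
  # (unlike the JAX-side error_if guard, which only fires when enabled).
  if xs.size > 1 and not np.all(xs[:-1] <= xs[1:]):
    raise RuntimeError('xs must be sorted in ascending order.')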