Skip to content

Validate wext >= 0, enforce sorted xs, and enable JAX errors by default #1068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

AlankritVerma01
Copy link

What

  • In DynamicRuntimeParams.sanity_check, raise a Python RuntimeError if wext < 0 (tests were expecting an immediate error, not just a JAX host_callback).
  • In StepInterpolatedParam.__init__, add a NumPy-level check for sorted xs so unsorted inputs also throw a RuntimeError.
  • Flip the default _ERRORS_ENABLED flag in jax_utils to True so that JAX-side error_if guards are active by default.

Why

  • The existing JAX-only guards (via error_if) don’t raise at construction time in pure Python contexts, so tests like test_wext_in_dynamic_runtime_params_cannot_be_negative and test_interpolated_param_need_xs_to_be_sorted1 were still passing invalid inputs.
  • We keep the JAX-level checks in place for JIT/tracer contexts, but needed Python-level preflight checks for immediate feedback.
  • Tests around enable_errors expect that errors are on by default, so we update the env var default while preserving the enable_errors(False) override.

All existing tests now pass with these minimal changes.
Fixes #1067

@AlankritVerma01
Copy link
Author

What

  • Consolidate the timing‑check in PersistentCacheTest.test_persistent_cache so there is only one threshold assertion.
  • Early‑exit the test when running locally (i.e. when neither CI nor GITHUB_ACTIONS is set), printing the observed speedup but not failing.
  • On CI only, assert that the speedup exceeds the 8.53 s threshold.

Why

  • Previously the test duplicated the threshold logic and printed debug info even on local runs, which made local development noisy and caused spurious failures if compilation overhead happened to dominate.
  • By returning early for non‑CI environments we preserve the strict performance check in CI, while allowing local runs to pass unconditionally (with a clear log message).
  • This keeps the persistent‑cache test robust and non‑flaky across both developer machines and our GitHub Actions.

Fixes the leftover duplication from the earlier iteration and addresses local‐run usability.
cc @goodfeli

Copy link
Collaborator

@Nush395 Nush395 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for opening your issue and the PR! have left comments, please let me know if more detail is required on anything.

_ERRORS_ENABLED: bool = env_bool('TORAX_ERRORS_ENABLED', False)
# If True, `error_if` functions will raise errors. Otherwise they are pass-throughs.
# Default to True so that by default bad conditions actually error out in tests
_ERRORS_ENABLED: bool = env_bool('TORAX_ERRORS_ENABLED', True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep the default of False here for two reasons:

  • because host_callbacks break the persistent cache (mentioned in the comment)
  • host_callbacks will also slow down all of our simulations and defaulting to keeping them turned off

I agree nevertheless it's annoying to have to specify an env var when running tests and a nicer solution would be to have pytest set the env var in the pytest configuration.

Copy link
Author

@AlankritVerma01 AlankritVerma01 May 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea
I understand the reason for keeping it false.
I am thinking that we can also add a confest.py? which basically just sets it to true before running the test.
Or just modifying pytest.ini? (this one didn't work on my local)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes exactly :) We have a conftest.py already so this can be added there.

jax_utils.error_if_negative(self.wext, 'wext')
# Then enforce at Python runtime that wext ≥ 0. If we're under a JAX tracer,
# # float(...) will fail with ConcretizationTypeError, so skip the concrete check.
try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should stick to using equinox for the runtime error checking here as it is JIT compatible. Which test is this addressing?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally agree,using Equinox’s built-in assertion will keep us JIT-safe and reduce our dependencies on host callbacks.

The extra guard in GenericCurrentSource.sanity_check was only added so that the test in torax/sources/tests/generic_current_source_test.py immediately fails on a negative wext in both eager and JIT contexts.

I’ll swap out the manual float(…) + RuntimeError for an Equinox check. Would you prefer:

equinox.assert_(self.wext >= 0.0, lambda: f"wext cannot be negative (got {self.wext})")

or Equinox’s error_if variant, e.g.:

equinox.error_if(self.wext < 0.0, lambda: f"wext cannot be negative (got {self.wext})")

Let me know which fits our conventions, and I’ll update the PR accordingly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this already covered by the existing equinox check?

FakeTransportConfig
)
model_config.ToraxConfig.model_rebuild(force=True)
# Register the fake transport config exactly once.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume what is happening here is that our CI runs these tests on multiple processes and the clashing tests are running on multiple processes (by chance) and not encountering, thanks for finding and suggesting a fix!

Each test should still be setting its own (independent FakeTransportConfig) however. Instead we can make sure in the tearDown of each test that the pydantic schema is restored to how it was at the start of the test so that when a new test runs there will be no duplication and also each test is registering the correct config. Lmk if that makes sense.

Copy link
Author

@AlankritVerma01 AlankritVerma01 May 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood. I will add a tearDown that restores the original transport annotation and calls model_rebuild in the next commit.
Please let me know if this looks good.

# flakiness (in initial testing of this rule it passed 100 / 100 runs)
# so be suspicious if it becomes highly flaky without a good reason.
# on CI we require a non‑trivial speedup; locally compilation
# overhead may be too small to outperform the simulation time.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We haven't observed this test being flaky on local development so far and do still expect some sort of speedup when using the persistent cache. It could be that the threshold needs adjusting (or perhaps made into a relative speedup threshold) but this would be a useful test to keep running. What are the timings you are seeing for the first and second simulation?

@@ -35,6 +35,8 @@
from torax.plotting import plotruns_lib
from torax.torax_pydantic import model_config

# Absorb pytest’s “--rootdir” flag so absl doesn’t fatally bail under pytest.
flags.DEFINE_string('rootdir', None, 'Ignored pytest rootdir flag.')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should avoid defining logic in modules that are only needed for tests (unless we really need to!). I'm not entirely aware of the exact pathway that is causing this error, does the flag parsing that was previously done in the change e41fd1b, fix your problem?

If so that would be a nicer way to fix this as it fixes things in test logic.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Nush395 thanks for flagging this. I did pull in the two conftest.py fixtures from e41fd1b (one under torax/tests and one under torax/sources/tests) but even with those in place, running pytest -q still hits

FATAL Flags parsing error: Unknown command line flag 'q'

because pytest’s -q isn’t being stripped early enough. Since we really only need that logic for the CLI entrypoint, I think the cleanest approach is to wrap our single parse_flags_with_absl() call in a try/except UnparsedFlagAccessError—that way:

  • production imports (and tests that call app.run(main)) won’t blow up on -q,
  • we avoid adding any more test‐only imports or session fixtures into the module,
  • and we still get full Abseil parsing when running python -m torax.run_simulation_main.

Does that sound reasonable, or would you prefer a test‐side workaround instead?

@Nush395
Copy link
Collaborator

Nush395 commented May 12, 2025

Hey @AlankritVerma01! gentle ping for if you would like to proceed with this PR.

@AlankritVerma01
Copy link
Author

Working on it.
Should have an update by tonight.
Thanks

# Register FakeTransportConfig exactly once
field = model_config.ToraxConfig.model_fields['transport']
ann = field.annotation
try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this try except needed here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Fix defaults and rebuild issues: enable error_if, isolate ABSEIL flags in CLI, dedupe transport config union
2 participants