Skip to content

Commit 72a268c

Browse files
Make Numba the default backend for testing
Do not merge this commit; it is intended for temporary use in a draft PR.
1 parent 9436553 commit 72a268c

File tree

4 files changed

+35
-14
lines changed

4 files changed

+35
-14
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ jobs:
6666
runs-on: ubuntu-latest
6767
if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
6868
strategy:
69-
fail-fast: true
69+
fail-fast: false
7070
matrix:
7171
python-version: ["3.7", "3.9"]
7272
fast-compile: [0]
@@ -132,7 +132,7 @@ jobs:
132132
if [[ $FAST_COMPILE == "1" ]]; then export AESARA_FLAGS=$AESARA_FLAGS,mode=FAST_COMPILE; fi
133133
if [[ $FLOAT32 == "1" ]]; then export AESARA_FLAGS=$AESARA_FLAGS,floatX=float32; fi
134134
export AESARA_FLAGS=$AESARA_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
135-
python -m pytest -x -r A --verbose --runslow --cov=aesara/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART
135+
python -m pytest --verbose --runslow --cov=aesara/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART
136136
env:
137137
MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
138138
MKL_THREADING_LAYER: GNU

aesara/compile/mode.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
"c": CLinker(), # Don't support gc. so don't check allow_gc
4343
"c|py": OpWiseCLinker(), # Use allow_gc Aesara flag
4444
"c|py_nogc": OpWiseCLinker(allow_gc=False),
45-
"vm": VMLinker(use_cloop=False), # Use allow_gc Aesara flag
45+
"vm": NumbaLinker(), # VMLinker(use_cloop=False), # Use allow_gc Aesara flag
4646
"cvm": VMLinker(use_cloop=True), # Use allow_gc Aesara flag
47-
"vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
47+
"vm_nogc": NumbaLinker(), # VMLinker(allow_gc=False, use_cloop=False),
4848
"cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
4949
"jax": JAXLinker(),
5050
"numba": NumbaLinker(),
@@ -441,9 +441,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
441441
# Use VM_linker to allow lazy evaluation by default.
442442
FAST_COMPILE = Mode(VMLinker(use_cloop=False, c_thunks=False), "fast_compile")
443443
if config.cxx:
444-
FAST_RUN = Mode("cvm", "fast_run")
444+
FAST_RUN = Mode("numba", "fast_run")
445445
else:
446-
FAST_RUN = Mode("vm", "fast_run")
446+
FAST_RUN = Mode("numba", "fast_run")
447447

448448
JAX = Mode(
449449
JAXLinker(),

aesara/configdefaults.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def add_compile_configvars():
392392
config.add(
393393
"mode",
394394
"Default compilation mode",
395-
ConfigParam("Mode", apply=_filter_mode),
395+
ConfigParam("NUMBA", apply=_filter_mode),
396396
in_c_key=False,
397397
)
398398

@@ -463,7 +463,18 @@ def add_compile_configvars():
463463
"linker",
464464
"Default linker used if the aesara flags mode is Mode",
465465
EnumStr(
466-
"cvm", ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
466+
"numba",
467+
[
468+
"c|py",
469+
"py",
470+
"c",
471+
"c|py_nogc",
472+
"vm",
473+
"vm_nogc",
474+
"cvm_nogc",
475+
"numba",
476+
"jax",
477+
],
467478
),
468479
in_c_key=False,
469480
)
@@ -473,7 +484,7 @@ def add_compile_configvars():
473484
config.add(
474485
"linker",
475486
"Default linker used if the aesara flags mode is Mode",
476-
EnumStr("vm", ["py", "vm_nogc"]),
487+
EnumStr("numba", ["py", "vm_nogc", "vm", "numba", "jax"]),
477488
in_c_key=False,
478489
)
479490
if type(config).cxx.is_default:

tests/scan/test_basic.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ class TestScan:
248248
"rng_type",
249249
[
250250
np.random.default_rng,
251-
np.random.RandomState,
251+
# np.random.RandomState,
252252
],
253253
)
254254
def test_inner_graph_cloning(self, rng_type):
@@ -396,7 +396,7 @@ def f_pow2(x_tm1):
396396
assert all(i.value is None for i in scan_node.op.fn.input_storage)
397397
assert all(o.value is None for o in scan_node.op.fn.output_storage)
398398

399-
@pytest.mark.parametrize("mode", [Mode(linker="py"), Mode(linker="cvm")])
399+
@pytest.mark.parametrize("mode", ["NUMBA", Mode(linker="py"), Mode(linker="cvm")])
400400
@pytest.mark.parametrize(
401401
"x_init",
402402
[
@@ -421,7 +421,12 @@ def f_pow(x_tm1):
421421
assert res.dtype == exp_res.dtype
422422

423423
@pytest.mark.parametrize(
424-
"mode", [Mode(linker="py", optimizer=None), Mode(linker="cvm", optimizer=None)]
424+
"mode",
425+
[
426+
"NUMBA",
427+
Mode(linker="py", optimizer=None),
428+
Mode(linker="cvm", optimizer=None),
429+
],
425430
)
426431
@pytest.mark.parametrize(
427432
"x",
@@ -459,7 +464,12 @@ def inner_fn(x_seq, x_i):
459464
assert res.dtype == exp_res.dtype
460465

461466
@pytest.mark.parametrize(
462-
"mode", [Mode(linker="py", optimizer=None), Mode(linker="cvm", optimizer=None)]
467+
"mode",
468+
[
469+
"NUMBA",
470+
Mode(linker="py", optimizer=None),
471+
Mode(linker="cvm", optimizer=None),
472+
],
463473
)
464474
@pytest.mark.parametrize(
465475
"x",
@@ -1126,7 +1136,7 @@ def test_inner_grad(self):
11261136
utt.assert_allclose(out, vR)
11271137

11281138
@pytest.mark.parametrize(
1129-
"mode", [Mode(linker="cvm", optimizer=None), Mode(linker="cvm")]
1139+
"mode", ["NUMBA", Mode(linker="cvm", optimizer=None), Mode(linker="cvm")]
11301140
)
11311141
def test_sequence_is_scan(self, mode):
11321142
"""Make sure that a `Scan` can be used as a sequence input to another `Scan`."""

0 commit comments

Comments
 (0)