Skip to content

Commit 9107c63

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Rename with_dll_constraint to with_layout_constraint for the upcoming Layout API rename!
PiperOrigin-RevId: 753636449
1 parent e189cd4 commit 9107c63

File tree

4 files changed

+24
-24
lines changed

4 files changed

+24
-24
lines changed

jax/_src/pjit.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -3120,49 +3120,49 @@ def use_explicit_axes(*axes):
31203120
with mesh_lib.use_abstract_mesh(new_mesh):
31213121
yield
31223122

3123-
# -------------------- with_dll_constraint --------------------
3123+
# -------------------- with_layout_constraint --------------------
31243124

3125-
def with_dll_constraint(x, layouts):
3125+
def with_layout_constraint(x, layouts):
31263126
x_flat, tree = tree_flatten(x)
3127-
layouts_flat = tuple(flatten_axes("with_dll_constraint layouts", tree,
3127+
layouts_flat = tuple(flatten_axes("with_layout_constraint layouts", tree,
31283128
layouts))
31293129
if any(not isinstance(l, DeviceLocalLayout) for l in layouts_flat):
31303130
raise ValueError(
3131-
'layouts passed to `with_dll_constraint` must be of type'
3131+
'layouts passed to `with_layout_constraint` must be of type'
31323132
f' `DeviceLocalLayout`. Got {[type(l) for l in layouts_flat]}')
31333133
check_aval_layout_compatibility(
31343134
layouts_flat, x_flat, ("",) * len(layouts_flat),
3135-
"with_dll_constraint arguments")
3136-
outs = [dll_constraint_p.bind(xf, layout=l)
3135+
"with_layout_constraint arguments")
3136+
outs = [layout_constraint_p.bind(xf, layout=l)
31373137
for xf, l in zip(x_flat, layouts_flat)]
31383138
return tree_unflatten(tree, outs)
31393139

3140-
dll_constraint_p = core.Primitive('dll_constraint')
3141-
dll_constraint_p.def_abstract_eval(lambda x, **_: x)
3142-
ad.deflinear2(dll_constraint_p,
3143-
lambda ct, _, **params: (dll_constraint_p.bind(ct, **params),))
3140+
layout_constraint_p = core.Primitive('layout_constraint')
3141+
layout_constraint_p.def_abstract_eval(lambda x, **_: x)
3142+
ad.deflinear2(layout_constraint_p,
3143+
lambda ct, _, **params: (layout_constraint_p.bind(ct, **params),))
31443144

3145-
def _dll_constraint_impl(x, *, layout):
3145+
def _layout_constraint_impl(x, *, layout):
31463146
if not isinstance(x, xc.ArrayImpl):
31473147
raise ValueError(
3148-
'with_dll_constraint in eager mode can only be applied to'
3148+
'with_layout_constraint in eager mode can only be applied to'
31493149
f' jax.Arrays. Got {type(x)}')
31503150
if x.layout.device_local_layout == layout: # type: ignore
31513151
return x
31523152
return api.jit(_identity_fn, out_shardings=Layout(layout, x.sharding))(x)
3153-
dll_constraint_p.def_impl(_dll_constraint_impl)
3153+
layout_constraint_p.def_impl(_layout_constraint_impl)
31543154

3155-
def _dll_constraint_hlo_lowering(ctx, x_node, *, layout):
3155+
def _layout_constraint_hlo_lowering(ctx, x_node, *, layout):
31563156
aval, = ctx.avals_in
31573157
out_aval, = ctx.avals_out
31583158
return [mlir.wrap_with_layout_op(ctx, x_node, out_aval, layout, aval)]
3159-
mlir.register_lowering(dll_constraint_p,
3160-
_dll_constraint_hlo_lowering)
3159+
mlir.register_lowering(layout_constraint_p,
3160+
_layout_constraint_hlo_lowering)
31613161

3162-
def _dll_constraint_batcher(axis_data, vals_in, dims_in, layout):
3162+
def _layout_constraint_batcher(axis_data, vals_in, dims_in, layout):
31633163
raise NotImplementedError
3164-
batching.fancy_primitive_batchers[dll_constraint_p] = _dll_constraint_batcher
3165-
batching.skippable_batchers[dll_constraint_p] = lambda _: ()
3164+
batching.fancy_primitive_batchers[layout_constraint_p] = _layout_constraint_batcher
3165+
batching.skippable_batchers[layout_constraint_p] = lambda _: ()
31663166

31673167
# -------------------- helpers --------------------
31683168

jax/experimental/jax2tf/tests/primitives_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def test_primitive_coverage(self):
178178
continue
179179
if p.name == "sharding_constraint":
180180
continue
181-
if p.name == "dll_constraint":
181+
if p.name == "layout_constraint":
182182
continue
183183
if p.name == "mesh_cast":
184184
continue

jax/experimental/layout.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@
1717
Layout as Layout,
1818
)
1919
from jax._src.pjit import (
20-
with_dll_constraint as with_dll_constraint,
20+
with_layout_constraint as with_layout_constraint,
2121
)

tests/layout_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from jax._src import config
2424
from jax._src import test_util as jtu
2525
from jax._src.util import safe_zip
26-
from jax.experimental.layout import (with_dll_constraint, Layout,
26+
from jax.experimental.layout import (with_layout_constraint, Layout,
2727
DeviceLocalLayout as DLL)
2828
from jax.experimental.compute_on import compute_on
2929

@@ -745,7 +745,7 @@ def f(x):
745745
self.assertArraysEqual(out, np_inp * 2)
746746
self.assertEqual(out.layout, out_layout)
747747

748-
def test_with_dll_constraint(self):
748+
def test_with_layout_constraint(self):
749749
if not jtu.test_device_matches(['tpu']):
750750
self.skipTest('Only works for TPU')
751751
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@@ -761,7 +761,7 @@ def f(x):
761761
y = x.T
762762
# Constrain `y` to the original layout of `arr` because without it,
763763
# the layout of `y` would be the transpose of `arr`.
764-
y = with_dll_constraint(y, custom_dll)
764+
y = with_layout_constraint(y, custom_dll)
765765
return y * 2
766766

767767
f(arr) # doesn't crash

0 commit comments

Comments
 (0)