@@ -3120,49 +3120,49 @@ def use_explicit_axes(*axes):
3120
3120
with mesh_lib .use_abstract_mesh (new_mesh ):
3121
3121
yield
3122
3122
3123
- # -------------------- with_dll_constraint --------------------
3123
+ # -------------------- with_layout_constraint --------------------
3124
3124
3125
- def with_dll_constraint (x , layouts ):
3125
+ def with_layout_constraint (x , layouts ):
3126
3126
x_flat , tree = tree_flatten (x )
3127
- layouts_flat = tuple (flatten_axes ("with_dll_constraint layouts" , tree ,
3127
+ layouts_flat = tuple (flatten_axes ("with_layout_constraint layouts" , tree ,
3128
3128
layouts ))
3129
3129
if any (not isinstance (l , DeviceLocalLayout ) for l in layouts_flat ):
3130
3130
raise ValueError (
3131
- 'layouts passed to `with_dll_constraint ` must be of type'
3131
+ 'layouts passed to `with_layout_constraint ` must be of type'
3132
3132
f' `DeviceLocalLayout`. Got { [type (l ) for l in layouts_flat ]} ' )
3133
3133
check_aval_layout_compatibility (
3134
3134
layouts_flat , x_flat , ("" ,) * len (layouts_flat ),
3135
- "with_dll_constraint arguments" )
3136
- outs = [dll_constraint_p .bind (xf , layout = l )
3135
+ "with_layout_constraint arguments" )
3136
+ outs = [layout_constraint_p .bind (xf , layout = l )
3137
3137
for xf , l in zip (x_flat , layouts_flat )]
3138
3138
return tree_unflatten (tree , outs )
3139
3139
3140
- dll_constraint_p = core .Primitive ('dll_constraint ' )
3141
- dll_constraint_p .def_abstract_eval (lambda x , ** _ : x )
3142
- ad .deflinear2 (dll_constraint_p ,
3143
- lambda ct , _ , ** params : (dll_constraint_p .bind (ct , ** params ),))
3140
+ layout_constraint_p = core .Primitive ('layout_constraint ' )
3141
+ layout_constraint_p .def_abstract_eval (lambda x , ** _ : x )
3142
+ ad .deflinear2 (layout_constraint_p ,
3143
+ lambda ct , _ , ** params : (layout_constraint_p .bind (ct , ** params ),))
3144
3144
3145
- def _dll_constraint_impl (x , * , layout ):
3145
+ def _layout_constraint_impl (x , * , layout ):
3146
3146
if not isinstance (x , xc .ArrayImpl ):
3147
3147
raise ValueError (
3148
- 'with_dll_constraint in eager mode can only be applied to'
3148
+ 'with_layout_constraint in eager mode can only be applied to'
3149
3149
f' jax.Arrays. Got { type (x )} ' )
3150
3150
if x .layout .device_local_layout == layout : # type: ignore
3151
3151
return x
3152
3152
return api .jit (_identity_fn , out_shardings = Layout (layout , x .sharding ))(x )
3153
- dll_constraint_p .def_impl (_dll_constraint_impl )
3153
+ layout_constraint_p .def_impl (_layout_constraint_impl )
3154
3154
3155
- def _dll_constraint_hlo_lowering (ctx , x_node , * , layout ):
3155
+ def _layout_constraint_hlo_lowering (ctx , x_node , * , layout ):
3156
3156
aval , = ctx .avals_in
3157
3157
out_aval , = ctx .avals_out
3158
3158
return [mlir .wrap_with_layout_op (ctx , x_node , out_aval , layout , aval )]
3159
- mlir .register_lowering (dll_constraint_p ,
3160
- _dll_constraint_hlo_lowering )
3159
+ mlir .register_lowering (layout_constraint_p ,
3160
+ _layout_constraint_hlo_lowering )
3161
3161
3162
- def _dll_constraint_batcher (axis_data , vals_in , dims_in , layout ):
3162
+ def _layout_constraint_batcher (axis_data , vals_in , dims_in , layout ):
3163
3163
raise NotImplementedError
3164
- batching .fancy_primitive_batchers [dll_constraint_p ] = _dll_constraint_batcher
3165
- batching .skippable_batchers [dll_constraint_p ] = lambda _ : ()
3164
+ batching .fancy_primitive_batchers [layout_constraint_p ] = _layout_constraint_batcher
3165
+ batching .skippable_batchers [layout_constraint_p ] = lambda _ : ()
3166
3166
3167
3167
# -------------------- helpers --------------------
3168
3168
0 commit comments