Skip to content

Commit 603f730

Browse files
h-joo authored and Google-ML-Automation committed
Automated Code Change
PiperOrigin-RevId: 753039040
1 parent 806190d commit 603f730

File tree

6 files changed

+13
-13
lines changed

6 files changed

+13
-13
lines changed

jax/_src/core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2047,7 +2047,7 @@ def standard_insert_pvary(*args):
20472047
if not args:
20482048
return args
20492049
in_vma = [frozenset() if (aval := get_aval(a)) is abstract_token
2050-
else aval.vma for a in args]
2050+
else aval.vma for a in args] # pytype: disable=attribute-error
20512051
out_vma = frozenset.union(*in_vma)
20522052
return [pvary(arg, tuple(n for n in out_vma if n not in src))
20532053
if out_vma - src else arg for arg, src in zip(args, in_vma)]

jax/_src/interpreters/mlir.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1097,7 +1097,7 @@ class UnconstrainedVariants(NamedTuple):
10971097

10981098
def _get_unconstrained_variants(s, aval) -> UnconstrainedVariants:
10991099
us = contains_unconstrained(s)
1100-
unconstrained_dims = ({i for i, p in enumerate(s.spec)
1100+
unconstrained_dims = ({i for i, p in enumerate(s.spec) # pytype: disable=attribute-error
11011101
if p is PartitionSpec.UNCONSTRAINED} if us else None)
11021102
return UnconstrainedVariants(
11031103
contains_unconstrained=us, all_unconstrained=all_unconstrained(s, aval),

jax/_src/lax/lax.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -5172,7 +5172,7 @@ def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers):
51725172
lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch)
51735173
rhs_group = ()
51745174
if isinstance(dimension_numbers, RaggedDotDimensionNumbers):
5175-
rhs_group = tuple(dimension_numbers.rhs_group_dimensions)
5175+
rhs_group = tuple(dimension_numbers.rhs_group_dimensions) # pytype: disable=attribute-error
51765176
rhs_contract_or_batch_or_group = tuple(
51775177
sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group)
51785178
)
@@ -6017,7 +6017,7 @@ def grad_x_dims():
60176017
unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y
60186018
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
60196019
raise unimplemented('grad_x_dims', mode)
6020-
return dims, unsorted_axes
6020+
return dims, unsorted_axes # pytype: disable=name-error
60216021

60226022
def grad_y_dims():
60236023
match mode:
@@ -6036,7 +6036,7 @@ def grad_y_dims():
60366036
)
60376037
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
60386038
raise unimplemented('grad_y_dims', mode)
6039-
return dims, unsorted_axes
6039+
return dims, unsorted_axes # pytype: disable=name-error
60406040

60416041
def _ragged_dot_grad(lhs, rhs, dims_fn, aval):
60426042
dims, unsorted_axes = dims_fn()
@@ -6238,7 +6238,7 @@ def expand(x, dim, gs, *axes):
62386238
lhs,
62396239
rhs,
62406240
dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
6241-
)
6241+
) # pytype: disable=bad-return-type
62426242

62436243

62446244
def _ragged_dot_general_lower(

jax/_src/pallas/triton/pallas_call_registration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]:
4040

4141

4242
def avals_to_layouts(avals):
43-
return [list(reversed(range(aval.ndim))) for aval in avals]
43+
return [list(reversed(range(aval.ndim))) for aval in avals] # pytype: disable=attribute-error
4444

4545

4646
def pallas_call_lowering(

jax/_src/pjit.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -750,16 +750,16 @@ def _infer_input_type(fun: Callable, dbg: core.DebugInfo,
750750
for i, x in enumerate(explicit_args):
751751
avals.append(core.shaped_abstractify(x))
752752
except OverflowError:
753-
arg_path = f"argument path is {dbg.arg_names[i]}"
753+
arg_path = f"argument path is {dbg.arg_names[i]}" # pytype: disable=name-error
754754
raise OverflowError(
755755
"An overflow was encountered while parsing an argument to a jitted "
756756
f"computation, whose {arg_path}."
757757
) from None
758758
except TypeError:
759-
arg_description = f"path {dbg.arg_names[i]}"
759+
arg_description = f"path {dbg.arg_names[i]}" # pytype: disable=name-error
760760
raise TypeError(
761761
f"Error interpreting argument to {fun} as an abstract array."
762-
f" The problematic value is of type {type(x)} and was passed to"
762+
f" The problematic value is of type {type(x)} and was passed to" # pytype: disable=name-error
763763
f" the function at {arg_description}.\n"
764764
"This typically means that a jit-wrapped function was called with a non-array"
765765
" argument, and this argument was not marked as static using the"
@@ -2035,8 +2035,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
20352035
elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
20362036
num_devices = axis_ctx.mesh.size
20372037
key = (pjit_p, name, jaxpr, effects, num_devices,
2038-
pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals),
2039-
pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals),
2038+
pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals), # pytype: disable=wrong-arg-types
2039+
pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals), # pytype: disable=wrong-arg-types
20402040
in_layouts, out_layouts, api_name)
20412041

20422042
func = mod_ctx.cached_primitive_lowerings.get(key, None)

jax/_src/tpu_custom_call.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def _tpu_custom_call_abstract_eval(*_, out_avals, **__):
222222

223223

224224
def _avals_to_layouts(avals) -> Sequence[Sequence[int]]:
225-
return [tuple(range(a.ndim - 1, -1, -1)) for a in avals]
225+
return [tuple(range(a.ndim - 1, -1, -1)) for a in avals] # pytype: disable=attribute-error
226226

227227

228228
def _tpu_custom_call_lowering(

0 commit comments

Comments (0)