@@ -750,16 +750,16 @@ def _infer_input_type(fun: Callable, dbg: core.DebugInfo,
750
750
for i , x in enumerate (explicit_args ):
751
751
avals .append (core .shaped_abstractify (x ))
752
752
except OverflowError :
753
- arg_path = f"argument path is { dbg .arg_names [i ]} "
753
+ arg_path = f"argument path is { dbg .arg_names [i ]} " # pytype: disable=name-error
754
754
raise OverflowError (
755
755
"An overflow was encountered while parsing an argument to a jitted "
756
756
f"computation, whose { arg_path } ."
757
757
) from None
758
758
except TypeError :
759
- arg_description = f"path { dbg .arg_names [i ]} "
759
+ arg_description = f"path { dbg .arg_names [i ]} " # pytype: disable=name-error
760
760
raise TypeError (
761
761
f"Error interpreting argument to { fun } as an abstract array."
762
- f" The problematic value is of type { type (x )} and was passed to"
762
+ f" The problematic value is of type { type (x )} and was passed to" # pytype: disable=name-error
763
763
f" the function at { arg_description } .\n "
764
764
"This typically means that a jit-wrapped function was called with a non-array"
765
765
" argument, and this argument was not marked as static using the"
@@ -2035,8 +2035,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
2035
2035
elif isinstance (axis_ctx , sharding_impls .SPMDAxisContext ):
2036
2036
num_devices = axis_ctx .mesh .size
2037
2037
key = (pjit_p , name , jaxpr , effects , num_devices ,
2038
- pxla .SemanticallyEqualShardings (in_shardings , jaxpr .in_avals ),
2039
- pxla .SemanticallyEqualShardings (out_shardings , jaxpr .out_avals ),
2038
+ pxla .SemanticallyEqualShardings (in_shardings , jaxpr .in_avals ), # pytype: disable=wrong-arg-types
2039
+ pxla .SemanticallyEqualShardings (out_shardings , jaxpr .out_avals ), # pytype: disable=wrong-arg-types
2040
2040
in_layouts , out_layouts , api_name )
2041
2041
2042
2042
func = mod_ctx .cached_primitive_lowerings .get (key , None )
0 commit comments