Skip to content

Commit 48001a2

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Make literal's dtype print with an empty shape so that it's consistent. So 1.0:f32 -> 1.0:f32[]
PiperOrigin-RevId: 753640777
1 parent dcdc25b commit 48001a2

File tree

6 files changed

+11
-11
lines changed

6 files changed

+11
-11
lines changed

docs/aot.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ some other features along the way. An example:
5050
>>> # Print the specialized, staged-out representation (as Jaxpr IR)
5151
>>> print(traced.jaxpr)
5252
{ lambda ; a:i32[] b:i32[]. let
53-
c:i32[] = mul 2:i32 a
53+
c:i32[] = mul 2:i32[] a
5454
d:i32[] = add c b
5555
in (d,) }
5656

jax/_src/api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2396,7 +2396,7 @@ def make_jaxpr(
23962396
c:f32[] = sin a
23972397
_:f32[] = sin b
23982398
d:f32[] = cos b
2399-
e:f32[] = mul 1.0:f32 d
2399+
e:f32[] = mul 1.0:f32[] d
24002400
f:f32[] = neg e
24012401
g:f32[] = mul f c
24022402
in (g,) }

jax/_src/core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True):
475475
del context # unused
476476
dtype = getattr(self.aval, 'dtype', None)
477477
if print_dtype and dtype:
478-
return f'{self.val}:{dtypes.short_dtype_name(dtype)}'
478+
return f'{self.val}:{self.aval.str_short(short_dtypes=True)}'
479479
else:
480480
return f'{self.val}'
481481

tests/api_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6679,7 +6679,7 @@ def fun(x):
66796679
return (x, 1., np.zeros(1, dtype=jnp.float32))
66806680

66816681
dtype = "f64" if config.enable_x64.value else "f32"
6682-
expected = f"{{ lambda a:f32[1]; b:f32[]. let in (b, 1.0:{dtype}, a) }}"
6682+
expected = f"{{ lambda a:f32[1]; b:f32[]. let in (b, 1.0:{dtype}[], a) }}"
66836683
jaxpr = api.make_jaxpr(fun)(jnp.float32(0.))
66846684
self.assertMultiLineStrippedEqual(expected, str(jaxpr))
66856685

@@ -6691,9 +6691,9 @@ def f(x):
66916691
x + 2.,
66926692
lambda xf: xf - x)
66936693
expected = """{ lambda ; a:f32[]. let
6694-
b:bool[] = ge a 0.0:f32
6695-
c:f32[] = add a 1.0:f32
6696-
d:f32[] = add a 2.0:f32
6694+
b:bool[] = ge a 0.0:f32[]
6695+
c:f32[] = add a 1.0:f32[]
6696+
d:f32[] = add a 2.0:f32[]
66976697
e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
66986698
f:f32[] = cond[
66996699
branches=(

tests/pjit_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1274,7 +1274,7 @@ def test_pretty_print_with_constant_pjit_arg(self):
12741274
b:f32[1] = pjit[
12751275
name=<lambda>
12761276
jaxpr={ lambda ; a:f32[1] c:f32[]. let b:f32[1] = mul a c in (b,) }
1277-
] a 1.0:f32
1277+
] a 1.0:f32[]
12781278
in (b,) }
12791279
""").strip(),
12801280
)
@@ -1308,7 +1308,7 @@ def test_pretty_print_with_literal_outvar(self):
13081308
{ lambda ; a:f32[1]. let
13091309
b:i32[] c:f32[1] = pjit[
13101310
name=<lambda>
1311-
jaxpr={ lambda ; a:f32[1]. let in (2:i32, a) }
1311+
jaxpr={ lambda ; a:f32[1]. let in (2:i32[], a) }
13121312
] a
13131313
in (b, c) }
13141314
""").strip(),

tests/state_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def body(x_ref):
361361
return []
362362
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
363363
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
364-
self.assertIn("a[] <- 2:i32", jaxpr.pretty_print(use_color=False))
364+
self.assertIn("a[] <- 2:i32[]", jaxpr.pretty_print(use_color=False))
365365

366366
def body(x_ref, val):
367367
x_ref[:, 0] = val
@@ -377,7 +377,7 @@ def body(x_ref):
377377
return [x]
378378
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
379379
wrap_init(body, 1), [shaped_array_ref((), jnp.int32)])
380-
self.assertIn("b:i32[], a[] <- a[], 2:i32", jaxpr.pretty_print(use_color=False))
380+
self.assertIn("b:i32[], a[] <- a[], 2:i32[]", jaxpr.pretty_print(use_color=False))
381381

382382
def body(x_ref, val):
383383
x = ref_swap(x_ref, (slice(None), 0), val)

0 commit comments

Comments
 (0)