Skip to content

Commit 3f1fb6b

Browse files
Add a non-one shape constraint to TensorType
1 parent 3ee2717 commit 3f1fb6b

File tree

8 files changed

+233
-70
lines changed

8 files changed

+233
-70
lines changed

aesara/link/jax/dispatch/shape.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,21 @@ def shape_i(x):
4646
def jax_funcify_SpecifyShape(op, **kwargs):
4747
def specifyshape(x, *shape):
4848
assert x.ndim == len(shape)
49-
assert jnp.all(x.shape == tuple(shape)), (
50-
"got shape",
51-
x.shape,
52-
"expected",
53-
shape,
54-
)
49+
for s_x, s in zip(x.shape, shape):
50+
if s == -1:
51+
assert s_x != 1, (
52+
"got shape",
53+
s_x,
54+
"expected",
55+
s,
56+
)
57+
elif s > -1:
58+
assert s_x == s, (
59+
"got shape",
60+
s_x,
61+
"expected",
62+
s,
63+
)
5564
return x
5665

5766
return specifyshape

aesara/link/numba/dispatch/basic.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
from contextlib import contextmanager
44
from functools import singledispatch
5-
from textwrap import dedent
5+
from textwrap import dedent, indent
66
from typing import Union
77

88
import numba
@@ -673,17 +673,26 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
673673
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
674674

675675
func_conditions = [
676-
f"assert x.shape[{i}] == {shape_input_names}"
677-
for i, (shape_input, shape_input_names) in enumerate(
676+
dedent(
677+
f"""
678+
if {shape_name} == -1:
679+
assert x.shape[{i}] != 1
680+
if {shape_name} > -1:
681+
assert x.shape[{i}] == {shape_name}
682+
"""
683+
)
684+
for i, (shape_input, shape_name) in enumerate(
678685
zip(shape_inputs, shape_input_names)
679686
)
680687
if shape_input is not NoneConst
681688
]
682689

690+
conditions_block = "\n".join(func_conditions)
691+
683692
func = dedent(
684693
f"""
685694
def specify_shape(x, {create_arg_string(shape_input_names)}):
686-
{"; ".join(func_conditions)}
695+
{indent(conditions_block, ' ' * 12)}
687696
return x
688697
"""
689698
)

aesara/tensor/shape.py

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import warnings
22
from numbers import Number
33
from textwrap import dedent
4-
from typing import Dict, List, Tuple, Union
4+
from typing import Dict, List, Sequence, Tuple, Union
55

66
import numpy as np
77

@@ -17,11 +17,65 @@
1717
from aesara.tensor import basic as at
1818
from aesara.tensor import get_vector_length
1919
from aesara.tensor.exceptions import NotScalarConstantError
20-
from aesara.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
20+
from aesara.tensor.type import (
21+
DenseTensorType,
22+
TensorType,
23+
int_dtypes,
24+
shape_encode,
25+
tensor,
26+
)
2127
from aesara.tensor.type_other import NoneConst
2228
from aesara.tensor.var import TensorConstant, TensorVariable
2329

2430

31+
def filter_shape_vars(
32+
ref_shape: Tuple[int, ...], shape: Sequence[Variable], shape_is_encoded: bool = True
33+
) -> Tuple[int, ...]:
34+
r"""Compute the most \"informative\" shape based on a static reference.
35+
36+
Parameters
37+
----------
38+
ref_shape
39+
A static shape reference using static shape constraint encoding.
40+
shape
41+
A symbolic shape.
42+
shape_is_encoded
43+
If ``True``, `shape` is assumed to be static shape constraint encoded.
44+
45+
Returns
46+
-------
47+
The most specific, and compatible (with `ref_shape`), static shape
48+
constraint encoded values.
49+
"""
50+
shape_bottom = shape_encode(None)
51+
type_shape = ()
52+
for i, (xts, s) in enumerate(zip(ref_shape, shape)):
53+
54+
try:
55+
# TODO FIXME: We shouldn't need to do this; let a rewrite
56+
# do constant folding and update the `TensorType`s.
57+
s_val = at.get_scalar_constant_value(s)
58+
59+
if isinstance(s_val, np.ndarray):
60+
s_val = s_val.item()
61+
62+
if shape_is_encoded or s_val is not None and s_val > 0:
63+
type_s = shape_encode(s_val)
64+
else:
65+
type_s = shape_bottom
66+
except NotScalarConstantError:
67+
type_s = shape_bottom
68+
69+
if not (xts <= -1 or type_s <= -1 or type_s == xts):
70+
raise AssertionError(
71+
f"SpecifyShape: Got shape {xts} at index {i}, expected {type_s}."
72+
)
73+
74+
type_shape += (max(type_s, xts),)
75+
76+
return type_shape
77+
78+
2579
def register_shape_c_code(type, code, version=()):
2680
"""
2781
Tell Shape Op how to generate C code for an Aesara Type.
@@ -394,7 +448,6 @@ class SpecifyShape(COp):
394448
_f16_ok = True
395449

396450
def make_node(self, x, *shape):
397-
from aesara.tensor.basic import get_scalar_constant_value
398451

399452
x = at.as_tensor_variable(x)
400453

@@ -417,18 +470,7 @@ def make_node(self, x, *shape):
417470
f"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}."
418471
)
419472

420-
type_shape = [None] * x.ndim
421-
for i, (xts, s) in enumerate(zip(x.type.shape, shape)):
422-
if xts is not None:
423-
type_shape[i] = xts
424-
else:
425-
try:
426-
type_s = get_scalar_constant_value(s)
427-
if type_s is not None:
428-
type_shape[i] = int(type_s)
429-
except NotScalarConstantError:
430-
pass
431-
473+
type_shape = filter_shape_vars(x.type.shape_encoded, shape)
432474
out_var = x.type.clone(shape=type_shape)()
433475

434476
return Apply(self, [x, *shape], [out_var])
@@ -441,10 +483,10 @@ def perform(self, node, inp, out_):
441483
raise AssertionError(
442484
f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
443485
)
444-
if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
445-
raise AssertionError(
446-
f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
447-
)
486+
for xs, s in zip(x.shape, shape):
487+
if (s == -1 and xs == 1) or (s is not None and s > -1 and not xs == s):
488+
raise AssertionError(f"SpecifyShape: Got shape {xs}, expected {s}.")
489+
448490
out[0] = x
449491

450492
def infer_shape(self, fgraph, node, shapes):
@@ -454,11 +496,11 @@ def infer_shape(self, fgraph, node, shapes):
454496
for dim in range(node.inputs[0].type.ndim):
455497
s = shape[dim]
456498
try:
457-
s = at.get_scalar_constant_value(s)
458-
# We assume that `None` shapes are always retrieved by
499+
s = shape_encode(at.get_scalar_constant_value(s))
500+
# We assume that negative shapes are always retrieved by
459501
# `get_scalar_constant_value`, and only in that case do we default to
460502
# the shape of the input variable
461-
if s is None:
503+
if s < 0:
462504
s = xshape[dim]
463505
except NotScalarConstantError:
464506
pass
@@ -502,6 +544,9 @@ def c_code(self, node, name, i_names, o_names, sub):
502544
);
503545
{fail};
504546
}}
547+
548+
npy_intp shp;
549+
npy_intp actual_shp;
505550
"""
506551
)
507552

@@ -510,9 +555,11 @@ def c_code(self, node, name, i_names, o_names, sub):
510555
continue
511556
code += dedent(
512557
f"""
513-
if (py_{shp_name} != Py_None){{
514-
dtype_{shp_name} shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0];
515-
if (PyArray_DIMS({x_name})[{i}] != shp) {{
558+
shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0];
559+
560+
if (shp > -2) {{
561+
actual_shp = PyArray_DIMS({x_name})[{i}];
562+
if (actual_shp == -1 && shp == 1 || actual_shp != shp) {{
516563
PyErr_Format(PyExc_AssertionError,
517564
"SpecifyShape: dim %d of input has shape %d, expected %d.",
518565
{i}, PyArray_DIMS({x_name})[{i}], shp
@@ -533,7 +580,7 @@ def c_code(self, node, name, i_names, o_names, sub):
533580
return code
534581

535582
def c_code_cache_version(self):
536-
return (2,)
583+
return (3,)
537584

538585

539586
_specify_shape = SpecifyShape()

0 commit comments

Comments
 (0)