1
1
import warnings
2
2
from numbers import Number
3
3
from textwrap import dedent
4
- from typing import Dict , List , Tuple , Union
4
+ from typing import Dict , List , Sequence , Tuple , Union
5
5
6
6
import numpy as np
7
7
17
17
from aesara .tensor import basic as at
18
18
from aesara .tensor import get_vector_length
19
19
from aesara .tensor .exceptions import NotScalarConstantError
20
- from aesara .tensor .type import DenseTensorType , TensorType , int_dtypes , tensor
20
+ from aesara .tensor .type import (
21
+ DenseTensorType ,
22
+ TensorType ,
23
+ int_dtypes ,
24
+ shape_encode ,
25
+ tensor ,
26
+ )
21
27
from aesara .tensor .type_other import NoneConst
22
28
from aesara .tensor .var import TensorConstant , TensorVariable
23
29
24
30
31
def filter_shape_vars(
    ref_shape: Tuple[int, ...], shape: Sequence[Variable], shape_is_encoded: bool = True
) -> Tuple[int, ...]:
    # FIX: the original used a raw string (r"""…\"informative\"…"""), so the
    # backslash escapes rendered literally in the displayed docstring.
    """Compute the most "informative" shape based on a static reference.

    Parameters
    ----------
    ref_shape
        A static shape reference using static shape constraint encoding.
    shape
        A symbolic shape.
    shape_is_encoded
        If ``True``, `shape` is assumed to be static shape constraint encoded.

    Returns
    -------
    The most specific, and compatible (with `ref_shape`), static shape
    constraint encoded values.

    Raises
    ------
    AssertionError
        If an entry in `shape` is a known constant that conflicts with the
        corresponding entry in `ref_shape`.
    """
    # Encoded value meaning "no static information" (presumably a negative
    # sentinel -- TODO confirm against `shape_encode`'s contract).
    shape_bottom = shape_encode(None)
    type_shape = ()
    for i, (xts, s) in enumerate(zip(ref_shape, shape)):
        try:
            # TODO FIXME: We shouldn't need to do this; let a rewrite
            # do constant folding and update the `TensorType`s.
            s_val = at.get_scalar_constant_value(s)

            if isinstance(s_val, np.ndarray):
                s_val = s_val.item()

            # Parentheses make the original's implicit precedence explicit:
            # encoded inputs are always re-encoded; otherwise only strictly
            # positive constants carry usable static information.
            if shape_is_encoded or (s_val is not None and s_val > 0):
                type_s = shape_encode(s_val)
            else:
                type_s = shape_bottom
        except NotScalarConstantError:
            # Non-constant symbolic entries provide no static constraint.
            type_s = shape_bottom

        # Values <= -1 are unconstrained in the encoding; two constrained
        # values must match exactly, otherwise the shapes are incompatible.
        if not (xts <= -1 or type_s <= -1 or type_s == xts):
            raise AssertionError(
                f"SpecifyShape: Got shape {xts} at index {i}, expected {type_s}."
            )

        # `max` selects the more specific (larger-encoded) of the two
        # compatible constraints for this dimension.
        type_shape += (max(type_s, xts),)

    return type_shape
77
+
78
+
25
79
def register_shape_c_code (type , code , version = ()):
26
80
"""
27
81
Tell Shape Op how to generate C code for an Aesara Type.
@@ -394,7 +448,6 @@ class SpecifyShape(COp):
394
448
_f16_ok = True
395
449
396
450
def make_node (self , x , * shape ):
397
- from aesara .tensor .basic import get_scalar_constant_value
398
451
399
452
x = at .as_tensor_variable (x )
400
453
@@ -417,18 +470,7 @@ def make_node(self, x, *shape):
417
470
f"Input `x` is { x .type .ndim } -dimensional and will never match a shape of length { len (shape )} ."
418
471
)
419
472
420
- type_shape = [None ] * x .ndim
421
- for i , (xts , s ) in enumerate (zip (x .type .shape , shape )):
422
- if xts is not None :
423
- type_shape [i ] = xts
424
- else :
425
- try :
426
- type_s = get_scalar_constant_value (s )
427
- if type_s is not None :
428
- type_shape [i ] = int (type_s )
429
- except NotScalarConstantError :
430
- pass
431
-
473
+ type_shape = filter_shape_vars (x .type .shape_encoded , shape )
432
474
out_var = x .type .clone (shape = type_shape )()
433
475
434
476
return Apply (self , [x , * shape ], [out_var ])
@@ -441,10 +483,10 @@ def perform(self, node, inp, out_):
441
483
raise AssertionError (
442
484
f"SpecifyShape: Got { x .ndim } dimensions (shape { x .shape } ), expected { ndim } dimensions with shape { tuple (shape )} ."
443
485
)
444
- if not all ( xs == s for xs , s in zip (x .shape , shape ) if s is not None ):
445
- raise AssertionError (
446
- f"SpecifyShape: Got shape { x . shape } , expected { tuple ( int ( s ) if s is not None else None for s in shape ) } ."
447
- )
486
+ for xs , s in zip (x .shape , shape ):
487
+ if ( s == - 1 and xs == 1 ) or ( s is not None and s > - 1 and not xs == s ):
488
+ raise AssertionError ( f"SpecifyShape: Got shape { xs } , expected { s } ." )
489
+
448
490
out [0 ] = x
449
491
450
492
def infer_shape (self , fgraph , node , shapes ):
@@ -454,11 +496,11 @@ def infer_shape(self, fgraph, node, shapes):
454
496
for dim in range (node .inputs [0 ].type .ndim ):
455
497
s = shape [dim ]
456
498
try :
457
- s = at .get_scalar_constant_value (s )
458
- # We assume that `None` shapes are always retrieved by
499
+ s = shape_encode ( at .get_scalar_constant_value (s ) )
500
+ # We assume that negative shapes are always retrieved by
459
501
# `get_scalar_constant_value`, and only in that case do we default to
460
502
# the shape of the input variable
461
- if s is None :
503
+ if s < 0 :
462
504
s = xshape [dim ]
463
505
except NotScalarConstantError :
464
506
pass
@@ -502,6 +544,9 @@ def c_code(self, node, name, i_names, o_names, sub):
502
544
);
503
545
{ fail } ;
504
546
}}
547
+
548
+ npy_intp shp;
549
+ npy_intp actual_shp;
505
550
"""
506
551
)
507
552
@@ -510,9 +555,11 @@ def c_code(self, node, name, i_names, o_names, sub):
510
555
continue
511
556
code += dedent (
512
557
f"""
513
- if (py_{ shp_name } != Py_None){{
514
- dtype_{ shp_name } shp = ((dtype_{ shp_name } *)PyArray_GETPTR1({ shp_name } , 0))[0];
515
- if (PyArray_DIMS({ x_name } )[{ i } ] != shp) {{
558
+ shp = ((dtype_{ shp_name } *)PyArray_GETPTR1({ shp_name } , 0))[0];
559
+
560
+ if (shp > -2) {{
561
+ actual_shp = PyArray_DIMS({ x_name } )[{ i } ];
562
+ if (actual_shp == -1 && shp == 1 || actual_shp != shp) {{
516
563
PyErr_Format(PyExc_AssertionError,
517
564
"SpecifyShape: dim %d of input has shape %d, expected %d.",
518
565
{ i } , PyArray_DIMS({ x_name } )[{ i } ], shp
@@ -533,7 +580,7 @@ def c_code(self, node, name, i_names, o_names, sub):
533
580
return code
534
581
535
582
def c_code_cache_version (self ):
536
- return (2 ,)
583
+ return (3 ,)
537
584
538
585
539
586
_specify_shape = SpecifyShape ()
0 commit comments