22
22
import pytensor .sparse
23
23
import pytensor .tensor as pt
24
24
import pytensor .tensor .slinalg
25
- import scipy as sp
26
- import scipy .sparse
27
25
28
26
from pytensor .graph .basic import Apply
29
27
from pytensor .graph .op import Op
93
91
from pytensor .tensor .linalg import solve_triangular
94
92
from pytensor .tensor .nlinalg import matrix_inverse
95
93
from pytensor .tensor .special import log_softmax , softmax
96
- from scipy .linalg import block_diag as scipy_block_diag
97
94
98
- from pymc .pytensorf import floatX , ix_ , largest_common_dtype
95
+ from pymc .pytensorf import floatX
99
96
100
97
__all__ = [
101
98
"abs" ,
@@ -513,55 +510,9 @@ def batched_diag(C):
513
510
raise ValueError ("Input should be 2 or 3 dimensional" )
514
511
515
512
516
- class BlockDiagonalMatrix (Op ):
517
- __props__ = ("sparse" , "format" )
518
-
519
- def __init__ (self , sparse = False , format = "csr" ):
520
- if format not in ("csr" , "csc" ):
521
- raise ValueError (f"format must be one of: 'csr', 'csc', got { format } " )
522
- self .sparse = sparse
523
- self .format = format
524
-
525
- def make_node (self , * matrices ):
526
- if not matrices :
527
- raise ValueError ("no matrices to allocate" )
528
- matrices = list (map (pt .as_tensor , matrices ))
529
- if any (mat .type .ndim != 2 for mat in matrices ):
530
- raise TypeError ("all data arguments must be matrices" )
531
- if self .sparse :
532
- out_type = pytensor .sparse .matrix (self .format , dtype = largest_common_dtype (matrices ))
533
- else :
534
- out_type = pytensor .tensor .matrix (dtype = largest_common_dtype (matrices ))
535
- return Apply (self , matrices , [out_type ])
536
-
537
- def perform (self , node , inputs , output_storage , params = None ):
538
- dtype = largest_common_dtype (inputs )
539
- if self .sparse :
540
- output_storage [0 ][0 ] = sp .sparse .block_diag (inputs , self .format , dtype )
541
- else :
542
- output_storage [0 ][0 ] = scipy_block_diag (* inputs ).astype (dtype )
543
-
544
- def grad (self , inputs , gout ):
545
- shapes = pt .stack ([i .shape for i in inputs ])
546
- index_end = shapes .cumsum (0 )
547
- index_begin = index_end - shapes
548
- slices = [
549
- ix_ (
550
- pt .arange (index_begin [i , 0 ], index_end [i , 0 ]),
551
- pt .arange (index_begin [i , 1 ], index_end [i , 1 ]),
552
- )
553
- for i in range (len (inputs ))
554
- ]
555
- return [gout [0 ][slc ] for slc in slices ]
556
-
557
- def infer_shape (self , fgraph , nodes , shapes ):
558
- first , second = zip (* shapes )
559
- return [(pt .add (* first ), pt .add (* second ))]
560
-
561
-
562
513
def block_diagonal (matrices , sparse = False , format = "csr" ):
563
- r"""See scipy.sparse .block_diag or
564
- scipy.linalg .block_diag for reference
514
+ r"""See pt.slinalg .block_diag or
515
+ pytensor.sparse.basic .block_diag for reference
565
516
566
517
Parameters
567
518
----------
@@ -575,6 +526,12 @@ def block_diagonal(matrices, sparse=False, format="csr"):
575
526
-------
576
527
matrix
577
528
"""
529
+ warnings .warn (
530
+ "pymc.math.block_diagonal is deprecated in favor of `pytensor.tensor.linalg.block_diag` and `pytensor.sparse.block_diag` functions. This function will be removed in a future release" ,
531
+ )
578
532
if len (matrices ) == 1 : # graph optimization
579
533
return matrices [0 ]
580
- return BlockDiagonalMatrix (sparse = sparse , format = format )(* matrices )
534
+ if sparse :
535
+ return pytensor .sparse .basic .block_diag (* matrices , format = format )
536
+ else :
537
+ return pt .slinalg .block_diag (* matrices )
0 commit comments