Skip to content

Commit 627a8dd

Browse files
authored
Deprecate block_diag from math module in favor of PyTensor (pymc-devs#7132)
1 parent 8745974 commit 627a8dd

File tree

1 file changed

+10
-53
lines changed

1 file changed

+10
-53
lines changed

pymc/math.py

+10-53
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
import pytensor.sparse
2323
import pytensor.tensor as pt
2424
import pytensor.tensor.slinalg
25-
import scipy as sp
26-
import scipy.sparse
2725

2826
from pytensor.graph.basic import Apply
2927
from pytensor.graph.op import Op
@@ -93,9 +91,8 @@
9391
from pytensor.tensor.linalg import solve_triangular
9492
from pytensor.tensor.nlinalg import matrix_inverse
9593
from pytensor.tensor.special import log_softmax, softmax
96-
from scipy.linalg import block_diag as scipy_block_diag
9794

98-
from pymc.pytensorf import floatX, ix_, largest_common_dtype
95+
from pymc.pytensorf import floatX
9996

10097
__all__ = [
10198
"abs",
@@ -513,55 +510,9 @@ def batched_diag(C):
513510
raise ValueError("Input should be 2 or 3 dimensional")
514511

515512

516-
class BlockDiagonalMatrix(Op):
517-
__props__ = ("sparse", "format")
518-
519-
def __init__(self, sparse=False, format="csr"):
520-
if format not in ("csr", "csc"):
521-
raise ValueError(f"format must be one of: 'csr', 'csc', got {format}")
522-
self.sparse = sparse
523-
self.format = format
524-
525-
def make_node(self, *matrices):
526-
if not matrices:
527-
raise ValueError("no matrices to allocate")
528-
matrices = list(map(pt.as_tensor, matrices))
529-
if any(mat.type.ndim != 2 for mat in matrices):
530-
raise TypeError("all data arguments must be matrices")
531-
if self.sparse:
532-
out_type = pytensor.sparse.matrix(self.format, dtype=largest_common_dtype(matrices))
533-
else:
534-
out_type = pytensor.tensor.matrix(dtype=largest_common_dtype(matrices))
535-
return Apply(self, matrices, [out_type])
536-
537-
def perform(self, node, inputs, output_storage, params=None):
538-
dtype = largest_common_dtype(inputs)
539-
if self.sparse:
540-
output_storage[0][0] = sp.sparse.block_diag(inputs, self.format, dtype)
541-
else:
542-
output_storage[0][0] = scipy_block_diag(*inputs).astype(dtype)
543-
544-
def grad(self, inputs, gout):
545-
shapes = pt.stack([i.shape for i in inputs])
546-
index_end = shapes.cumsum(0)
547-
index_begin = index_end - shapes
548-
slices = [
549-
ix_(
550-
pt.arange(index_begin[i, 0], index_end[i, 0]),
551-
pt.arange(index_begin[i, 1], index_end[i, 1]),
552-
)
553-
for i in range(len(inputs))
554-
]
555-
return [gout[0][slc] for slc in slices]
556-
557-
def infer_shape(self, fgraph, nodes, shapes):
558-
first, second = zip(*shapes)
559-
return [(pt.add(*first), pt.add(*second))]
560-
561-
562513
def block_diagonal(matrices, sparse=False, format="csr"):
563-
r"""See scipy.sparse.block_diag or
564-
scipy.linalg.block_diag for reference
514+
r"""See pt.slinalg.block_diag or
515+
pytensor.sparse.basic.block_diag for reference
565516
566517
Parameters
567518
----------
@@ -575,6 +526,12 @@ def block_diagonal(matrices, sparse=False, format="csr"):
575526
-------
576527
matrix
577528
"""
529+
warnings.warn(
530+
"pymc.math.block_diagonal is deprecated in favor of `pytensor.tensor.linalg.block_diag` and `pytensor.sparse.block_diag` functions. This function will be removed in a future release",
531+
)
578532
if len(matrices) == 1: # graph optimization
579533
return matrices[0]
580-
return BlockDiagonalMatrix(sparse=sparse, format=format)(*matrices)
534+
if sparse:
535+
return pytensor.sparse.basic.block_diag(*matrices, format=format)
536+
else:
537+
return pt.slinalg.block_diag(*matrices)

0 commit comments

Comments
 (0)