import numpy as np
-from scipy.sparse import diags, block_diag, coo_matrix, spdiags
from scipy.optimize import minimize
+from scipy.sparse import block_diag, coo_matrix, diags, spdiags
+


class SNMFOptimizer:
    def __init__(self, MM, Y0, X0=None, A=None, rho=1e18, eta=1, maxiter=300):
@@ -97,7 +98,7 @@ def apply_interpolation(self, a, x):
        # Ensure `a` is an array and reshape for broadcasting
        a = np.atleast_1d(np.asarray(a))  # Ensures a is at least 1D
-        #a = np.asarray(a)[None, :]  # Shape (1, M) to allow broadcasting
+        # a = np.asarray(a)[None, :]  # Shape (1, M) to allow broadcasting

        # Compute fractional indices, broadcasting over `a`
        ii = np.arange(N)[:, None] / a  # Shape (N, M)
@@ -125,7 +126,7 @@ def apply_interpolation(self, a, x):
        Tx = np.vstack([Tx, np.zeros((N - len(I), Tx.shape[1]))])

        # Compute second derivative (Hx)
-        ddi = -di / a + i * a ** -2
+        ddi = -di / a + i * a**-2
        Hx = x[I] * (-ddi) + x[np.minimum(I + 1, N - 1)] * ddi
        Hx = np.vstack([Hx, np.zeros((N - len(I), Hx.shape[1]))])
@@ -152,10 +153,10 @@ def get_objective_function(self, R=None, A=None):
            R = self.R
        if A is None:
            A = self.A
-        residual_term = 0.5 * np.linalg.norm(R, 'fro') ** 2
+        residual_term = 0.5 * np.linalg.norm(R, "fro") ** 2
        # original code selected indices, but for now we'll compute the norm over the whole matrix
        # residual_term = 0.5 * np.linalg.norm(self.R[index, :], 'fro') ** 2
-        regularization_term = 0.5 * self.rho * np.linalg.norm(self.P @ A.T, 'fro') ** 2
+        regularization_term = 0.5 * self.rho * np.linalg.norm(self.P @ A.T, "fro") ** 2
        sparsity_term = self.eta * np.sum(np.sqrt(self.X))  # Square root penalty
        # Final objective function value
        function = residual_term + regularization_term + sparsity_term
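(Aside for reviewers: the three terms assembled above are, written out with `R`, `P`, `rho`, `eta`, and `X` as defined elsewhere in the class,

$$
F = \tfrac{1}{2}\lVert R \rVert_F^2 + \tfrac{\rho}{2}\lVert P A^\top \rVert_F^2 + \eta \sum_{i,j} \sqrt{X_{ij}},
$$

i.e. a Frobenius residual, a stretch-regularization penalty on `A`, and a square-root sparsity penalty on `X`.)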
@@ -250,7 +251,7 @@ def apply_transformation_matrix(self, R=None):
        iI = ii - II

        # Expand row indices (MATLAB: repm = repmat(1:K, Nindex, M))
-        repm = np.tile(np.arange(self.K), (self.N, self.M)) # indexed to zero here
+        repm = np.tile(np.arange(self.K), (self.N, self.M))  # indexed to zero here

        # Compute transformations (MATLAB: kro = kron(R(index,:), ones(1, K)))
        kro = np.kron(R, np.ones((1, self.K)))  # Use full `R`
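As a quick aside for reviewers (toy matrix, not part of the PR): the `np.kron(R, np.ones((1, self.K)))` call above simply repeats each column of `R` K times, matching MATLAB's `kron(R, ones(1, K))`:

```python
import numpy as np

R = np.array([[1.0, 2.0],
              [3.0, 4.0]])
K = 3

# Each column of R is repeated K times side by side
print(np.kron(R, np.ones((1, K))))
# [[1. 1. 1. 2. 2. 2.]
#  [3. 3. 3. 4. 4. 4.]]
```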
@@ -313,7 +314,7 @@ def objective(y):
        y0 = np.ones(K)

        # Solve QP
-        result = minimize(objective, y0, bounds=bounds, method='SLSQP')
+        result = minimize(objective, y0, bounds=bounds, method="SLSQP")

        return result.x  # Optimal solution

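For context, the call above is scipy's bound-constrained SLSQP applied to a small quadratic subproblem. A standalone sketch of the same call pattern, with a toy objective and assumed non-negativity bounds (nothing here comes from the PR itself):

```python
import numpy as np
from scipy.optimize import minimize

K = 3
Q = np.diag([2.0, 1.0, 0.5])    # toy positive-definite quadratic term
c = np.array([-1.0, 0.5, 0.2])  # toy linear term


def objective(y):
    # Simple quadratic objective, standing in for the real QP
    return 0.5 * y @ Q @ y + c @ y


bounds = [(0.0, None)] * K  # assumed non-negativity bounds
result = minimize(objective, np.ones(K), bounds=bounds, method="SLSQP")
print(result.x)
```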
@@ -339,18 +340,21 @@ def updateX(self):
        denom = np.linalg.norm(self.curX - self.preX, "fro") ** 2  # Frobenius norm squared
        L = num / denom if denom > 0 else L0  # Ensure L0 fallback

-        L = np.maximum(L, L0) # ensure L is positive
-        while True: # iterate updating X
+        L = np.maximum(L, L0)  # ensure L is positive
+        while True:  # iterate updating X
            X_ = self.curX - self.GraX / L
            # Solve x^3 + p*x + q = 0 for the largest real root
            # off the shelf solver did not work element-wise for matrices
            X = np.square(rooth(-X_, self.eta / (2 * L)))
            # Mask values that should be set to zero
-            mask = (X ** 2 * L / 2 - L * X * X_ + self.eta * np.sqrt(X)) < 0
+            mask = (X**2 * L / 2 - L * X * X_ + self.eta * np.sqrt(X)) < 0
            X[mask] = 0
            # Check if objective function improves
-            if self.objective_history[-1] - self.get_objective_function(
-                self.get_residual_matrix(np.maximum(0, X), self.Y, self.A), self.A) > 0:
+            if (
+                self.objective_history[-1]
+                - self.get_objective_function(self.get_residual_matrix(np.maximum(0, X), self.Y, self.A), self.A)
+                > 0
+            ):
                break
            # Increase L
            L *= 2
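A way to sanity-check the update above (toy scalar values, not part of the PR): minimizing L/2·(x − x_)² + η·√x over x ≥ 0 with the substitution x = t² leads to the depressed cubic t³ − x_·t + η/(2L) = 0, which is what `rooth(-X_, self.eta / (2 * L))` solves before the result is squared. The sketch below checks that closed form against a brute-force grid search, for values where the interior minimum beats x = 0 (i.e. where the masking step does not fire):

```python
import numpy as np

L, eta, x_ = 2.0, 0.5, 1.3  # arbitrary test values

# Largest real root of t^3 + p*t + q = 0 with p = -x_, q = eta / (2 * L)
roots = np.roots([1.0, 0.0, -x_, eta / (2 * L)])
t = max(r.real for r in roots if abs(r.imag) < 1e-12)
x_closed_form = t**2

# Brute-force minimum of L/2*(x - x_)^2 + eta*sqrt(x) on a fine grid
grid = np.linspace(0.0, 2.0 * x_, 200_001)
f = L / 2 * (grid - x_) ** 2 + eta * np.sqrt(grid)
x_brute = grid[np.argmin(f)]

print(x_closed_form, x_brute)  # should agree up to the grid resolution
```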
@@ -405,29 +409,29 @@ def regularize_function(self, A=None):
        fun = self.get_objective_function(RA, A)

        # Compute gradient (removed index filtering)
-        gra = np.reshape(
-            np.sum(TX * np.tile(RA, (1, K)), axis=0), (M, K)
-        ).T + self.rho * A @ self.P.T @ self.P  # Gradient matrix
+        gra = (
+            np.reshape(np.sum(TX * np.tile(RA, (1, K)), axis=0), (M, K)).T + self.rho * A @ self.P.T @ self.P
+        )  # Gradient matrix

        # Compute Hessian (removed index filtering)
        hess = np.zeros((M * K, M * K))

        for m in range(M):
            Tx = TX[:, m + M * np.arange(K)]  # Now using all rows
-            hess[m * K:(m + 1) * K, m * K:(m + 1) * K] = Tx.T @ Tx
+            hess[m * K : (m + 1) * K, m * K : (m + 1) * K] = Tx.T @ Tx

        hess = (
-                hess
-                + spdiags(
-                    np.reshape(
-                        np.reshape(np.sum(HX * np.tile(RA, (1, K)), axis=0), (M, K)).T,
-                        (M * K,),  # ✅ Ensure 1D instead of (M*K,1)
-                    ),
-                    0,  # Diagonal index
-                    M * K,  # Number of rows
-                    M * K,  # Number of columns
-                ).toarray()
-                + self.rho * self.PPPP
+            hess
+            + spdiags(
+                np.reshape(
+                    np.reshape(np.sum(HX * np.tile(RA, (1, K)), axis=0), (M, K)).T,
+                    (M * K,),  # ✅ Ensure 1D instead of (M*K,1)
+                ),
+                0,  # Diagonal index
+                M * K,  # Number of rows
+                M * K,  # Number of columns
+            ).toarray()
+            + self.rho * self.PPPP
        )

        return fun, gra, hess
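Since the ✅ comment above is about getting a 1-D vector into `spdiags`, here is a tiny standalone illustration (toy vector, not from the PR) of the call pattern used for the diagonal part of the Hessian: with offset 0, the data vector lands on the main diagonal.

```python
import numpy as np
from scipy.sparse import spdiags

v = np.array([1.0, 2.0, 3.0, 4.0])  # stand-in for the reshaped (M*K,) vector
D = spdiags(v, 0, 4, 4).toarray()   # offset 0 -> main diagonal, 4x4 dense result
print(D)  # diag(1, 2, 3, 4)
```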
@@ -467,7 +471,7 @@ def rooth(p, q):
    Solves x^3 + p*x + q = 0 element-wise for matrices, returning the largest real root.
    """
    # Handle special case where q == 0
-    y = np.where(q == 0, np.maximum(0, -p)** 0.5, np.zeros_like(p))  # q=0 case
+    y = np.where(q == 0, np.maximum(0, -p) ** 0.5, np.zeros_like(p))  # q=0 case

    # Compute discriminant
    delta = (q / 2) ** 2 + (p / 3) ** 3
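(Aside for reviewers: `delta` is the usual discriminant-type quantity for the depressed cubic t³ + p·t + q = 0 — a single real root when delta > 0 and three real roots when delta < 0, which is why the final hunk keeps `np.max(real_roots, axis=0)` only where delta < 0.)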
@@ -490,4 +494,3 @@ def rooth(p, q):
    y = np.max(real_roots, axis=0) * (delta < 0)  # Keep only real roots when delta < 0

    return y
-