Commit 8ce263f

[pre-commit.ci] auto fixes from pre-commit hooks
1 parent 26e4f60 commit 8ce263f

File tree

2 files changed (+56 lines, -49 lines)


src/diffpy/snmf/main.py

Lines changed: 24 additions & 20 deletions
@@ -1,23 +1,24 @@
-import snmf
-
 import numpy as np
 import pandas as pd
+import snmf
 
 # Define fixed feature matrix (X) with distinct, structured features
-X = np.array([
-    [10, 0, 0],  # First component dominates first feature
-    [0, 8, 0],  # Second component dominates second feature
-    [0, 0, 6],  # Third component dominates third feature
-    [4, 4, 0],  # Mixed contribution to the fourth feature
-    [3, 2, 5]  # Mixed contribution to the fifth feature
-], dtype=float)
+X = np.array(
+    [
+        [10, 0, 0],  # First component dominates first feature
+        [0, 8, 0],  # Second component dominates second feature
+        [0, 0, 6],  # Third component dominates third feature
+        [4, 4, 0],  # Mixed contribution to the fourth feature
+        [3, 2, 5],  # Mixed contribution to the fifth feature
+    ],
+    dtype=float,
+)
 
 # Define fixed coefficient matrix (Y) representing weights
-Y = np.array([
-    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
-    [2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
-    [3, 6, 9, 12, 15, 18, 21, 24, 27, 30]
-], dtype=float)
+Y = np.array(
+    [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], [3, 6, 9, 12, 15, 18, 21, 24, 27, 30]],
+    dtype=float,
+)
 
 # Compute the resulting data matrix M
 MM = np.dot(X, Y)
@@ -28,11 +29,14 @@
 MM_norm = (MM - MM.min()) / (MM.max() - MM.min())
 
 # Generate an initial guess Y0 with slightly perturbed values
-Y0 = np.array([
-    [1.5, 1.8, 2.9, 3.6, 4.8, 5.7, 7.1, 8.2, 9.4, 10.3],
-    [2.2, 4.1, 5.9, 8.1, 9.8, 11.9, 14.2, 16.5, 18.1, 19.7],
-    [2.7, 5.5, 8.8, 11.5, 14.6, 17.8, 20.5, 23.9, 26.3, 29.2]
-], dtype=float)
+Y0 = np.array(
+    [
+        [1.5, 1.8, 2.9, 3.6, 4.8, 5.7, 7.1, 8.2, 9.4, 10.3],
+        [2.2, 4.1, 5.9, 8.1, 9.8, 11.9, 14.2, 16.5, 18.1, 19.7],
+        [2.7, 5.5, 8.8, 11.5, 14.6, 17.8, 20.5, 23.9, 26.3, 29.2],
+    ],
+    dtype=float,
+)
 
 # Normalize Y0 as well
 Y0_norm = (Y0 - Y0.min()) / (Y0.max() - Y0.min())

@@ -55,4 +59,4 @@
 print(f"My final guess for X: {my_model.X}")
 print(f"My final guess for Y: {my_model.Y}")
 print(f"Compare to true X: {X_norm}")
-print(f"Compare to true Y: {Y_norm}")
\ No newline at end of file
+print(f"Compare to true Y: {Y_norm}")
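
Worth noting about this synthetic setup: rows 2 and 3 of Y (and nearly of Y0) are exact multiples of [1, 2, ..., 10], so Y is rank 1 and MM = X @ Y inherits that rank. A standalone sanity-check sketch using only the matrices defined above:

import numpy as np

X = np.array([[10, 0, 0], [0, 8, 0], [0, 0, 6], [4, 4, 0], [3, 2, 5]], dtype=float)
Y = np.array(
    [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], [3, 6, 9, 12, 15, 18, 21, 24, 27, 30]],
    dtype=float,
)

MM = np.dot(X, Y)
print(MM.shape)                   # (5, 10): five features, ten samples
print(np.linalg.matrix_rank(Y))   # 1: rows 2 and 3 are exact multiples of row 1
print(np.linalg.matrix_rank(MM))  # 1 as well, so the three components are not separable from MM alone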

src/diffpy/snmf/snmf.py

Lines changed: 32 additions & 29 deletions
@@ -1,6 +1,7 @@
 import numpy as np
-from scipy.sparse import diags, block_diag, coo_matrix, spdiags
 from scipy.optimize import minimize
+from scipy.sparse import block_diag, coo_matrix, diags, spdiags
+
 
 class SNMFOptimizer:
     def __init__(self, MM, Y0, X0=None, A=None, rho=1e18, eta=1, maxiter=300):
@@ -97,7 +98,7 @@ def apply_interpolation(self, a, x):
 
         # Ensure `a` is an array and reshape for broadcasting
         a = np.atleast_1d(np.asarray(a))  # Ensures a is at least 1D
-        #a = np.asarray(a)[None, :]  # Shape (1, M) to allow broadcasting
+        # a = np.asarray(a)[None, :]  # Shape (1, M) to allow broadcasting
 
         # Compute fractional indices, broadcasting over `a`
         ii = np.arange(N)[:, None] / a  # Shape (N, M)
@@ -125,7 +126,7 @@ def apply_interpolation(self, a, x):
         Tx = np.vstack([Tx, np.zeros((N - len(I), Tx.shape[1]))])
 
         # Compute second derivative (Hx)
-        ddi = -di / a + i * a ** -2
+        ddi = -di / a + i * a**-2
         Hx = x[I] * (-ddi) + x[np.minimum(I + 1, N - 1)] * ddi
         Hx = np.vstack([Hx, np.zeros((N - len(I), Hx.shape[1]))])
 
@@ -152,10 +153,10 @@ def get_objective_function(self, R=None, A=None):
             R = self.R
         if A is None:
             A = self.A
-        residual_term = 0.5 * np.linalg.norm(R, 'fro') ** 2
+        residual_term = 0.5 * np.linalg.norm(R, "fro") ** 2
         # original code selected indices, but for now we'll compute the norm over the whole matrix
         # residual_term = 0.5 * np.linalg.norm(self.R[index, :], 'fro') ** 2
-        regularization_term = 0.5 * self.rho * np.linalg.norm(self.P @ A.T, 'fro') ** 2
+        regularization_term = 0.5 * self.rho * np.linalg.norm(self.P @ A.T, "fro") ** 2
         sparsity_term = self.eta * np.sum(np.sqrt(self.X))  # Square root penalty
         # Final objective function value
         function = residual_term + regularization_term + sparsity_term
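
As named in these lines, the objective is a sum of three pieces: a Frobenius-norm residual, a rho-weighted regularization of the stretching matrix A through the operator P, and an eta-weighted square-root sparsity penalty on X. A minimal standalone sketch of the same arithmetic (shapes and values invented for illustration):

import numpy as np

rng = np.random.default_rng(0)
R = rng.standard_normal((5, 10))  # residual matrix (data minus reconstruction)
P = rng.standard_normal((9, 10))  # penalty operator applied to A.T
A = rng.random((5, 10))           # stretching-factor matrix
X = rng.random((5, 3))            # nonnegative component matrix
rho, eta = 1e2, 1.0               # penalty weights

residual_term = 0.5 * np.linalg.norm(R, "fro") ** 2
regularization_term = 0.5 * rho * np.linalg.norm(P @ A.T, "fro") ** 2
sparsity_term = eta * np.sum(np.sqrt(X))  # square-root penalty favors sparse X
objective = residual_term + regularization_term + sparsity_term
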
@@ -250,7 +251,7 @@ def apply_transformation_matrix(self, R=None):
         iI = ii - II
 
         # Expand row indices (MATLAB: repm = repmat(1:K, Nindex, M))
-        repm = np.tile(np.arange(self.K), (self.N, self.M)) # indexed to zero here
+        repm = np.tile(np.arange(self.K), (self.N, self.M))  # indexed to zero here
 
         # Compute transformations (MATLAB: kro = kron(R(index,:), ones(1, K)))
         kro = np.kron(R, np.ones((1, self.K)))  # Use full `R`
@@ -313,7 +314,7 @@ def objective(y):
         y0 = np.ones(K)
 
         # Solve QP
-        result = minimize(objective, y0, bounds=bounds, method='SLSQP')
+        result = minimize(objective, y0, bounds=bounds, method="SLSQP")
 
         return result.x  # Optimal solution
 
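The call being reformatted here is scipy's usual bound-constrained SLSQP pattern. A self-contained toy example with the same call shape (objective, bounds, and K invented for illustration):

import numpy as np
from scipy.optimize import minimize

K = 3
target = np.array([0.2, -0.5, 1.4])

def objective(y):
    return np.sum((y - target) ** 2)  # simple smooth quadratic

bounds = [(0, None)] * K  # one (lower, upper) pair per variable; None = unbounded
y0 = np.ones(K)           # feasible initial guess
result = minimize(objective, y0, bounds=bounds, method="SLSQP")
print(result.x)           # ~[0.2, 0.0, 1.4]; the second entry is clipped at the bound
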
@@ -339,18 +340,21 @@ def updateX(self):
         denom = np.linalg.norm(self.curX - self.preX, "fro") ** 2  # Frobenius norm squared
         L = num / denom if denom > 0 else L0  # Ensure L0 fallback
 
-        L = np.maximum(L, L0) # ensure L is positive
-        while True: # iterate updating X
+        L = np.maximum(L, L0)  # ensure L is positive
+        while True:  # iterate updating X
             X_ = self.curX - self.GraX / L
             # Solve x^3 + p*x + q = 0 for the largest real root
             # off the shelf solver did not work element-wise for matrices
             X = np.square(rooth(-X_, self.eta / (2 * L)))
             # Mask values that should be set to zero
-            mask = (X ** 2 * L / 2 - L * X * X_ + self.eta * np.sqrt(X)) < 0
+            mask = (X**2 * L / 2 - L * X * X_ + self.eta * np.sqrt(X)) < 0
             X[mask] = 0
             # Check if objective function improves
-            if self.objective_history[-1] - self.get_objective_function(
-                    self.get_residual_matrix(np.maximum(0, X), self.Y, self.A), self.A) > 0:
+            if (
+                self.objective_history[-1]
+                - self.get_objective_function(self.get_residual_matrix(np.maximum(0, X), self.Y, self.A), self.A)
+                > 0
+            ):
                 break
             # Increase L
             L *= 2
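
The surrounding loop is a backtracking scheme on the step size: take a step of length 1/L against the gradient, and double L (halving the step) until the objective actually decreases. The same skeleton on a scalar toy problem, as a sketch only (function and constants invented):

import numpy as np

def f(x):
    return (x - 3.0) ** 2  # toy objective with minimum at x = 3

def grad(x):
    return 2.0 * (x - 3.0)

L0 = 1e-3
x, L = 0.0, L0
for _ in range(100):
    L = max(L, L0)  # ensure L stays positive
    while True:
        x_new = x - grad(x) / L  # gradient step of size 1/L
        if f(x) - f(x_new) > 0:  # accept only if the objective improves
            break
        L *= 2  # otherwise halve the step and retry
    x = x_new
print(x)  # converges toward 3.0
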
@@ -405,29 +409,29 @@ def regularize_function(self, A=None):
         fun = self.get_objective_function(RA, A)
 
         # Compute gradient (removed index filtering)
-        gra = np.reshape(
-            np.sum(TX * np.tile(RA, (1, K)), axis=0), (M, K)
-        ).T + self.rho * A @ self.P.T @ self.P  # Gradient matrix
+        gra = (
+            np.reshape(np.sum(TX * np.tile(RA, (1, K)), axis=0), (M, K)).T + self.rho * A @ self.P.T @ self.P
+        )  # Gradient matrix
 
         # Compute Hessian (removed index filtering)
         hess = np.zeros((M * K, M * K))
 
         for m in range(M):
             Tx = TX[:, m + M * np.arange(K)]  # Now using all rows
-            hess[m * K: (m + 1) * K, m * K: (m + 1) * K] = Tx.T @ Tx
+            hess[m * K : (m + 1) * K, m * K : (m + 1) * K] = Tx.T @ Tx
 
         hess = (
-                hess
-                + spdiags(
-                    np.reshape(
-                        np.reshape(np.sum(HX * np.tile(RA, (1, K)), axis=0), (M, K)).T,
-                        (M * K,),  # ✅ Ensure 1D instead of (M*K,1)
-                    ),
-                    0,  # Diagonal index
-                    M * K,  # Number of rows
-                    M * K,  # Number of columns
-                ).toarray()
-                + self.rho * self.PPPP
+            hess
+            + spdiags(
+                np.reshape(
+                    np.reshape(np.sum(HX * np.tile(RA, (1, K)), axis=0), (M, K)).T,
+                    (M * K,),  # ✅ Ensure 1D instead of (M*K,1)
+                ),
+                0,  # Diagonal index
+                M * K,  # Number of rows
+                M * K,  # Number of columns
+            ).toarray()
+            + self.rho * self.PPPP
         )
 
         return fun, gra, hess
@@ -467,7 +471,7 @@ def rooth(p, q):
     Solves x^3 + p*x + q = 0 element-wise for matrices, returning the largest real root.
     """
    # Handle special case where q == 0
-    y = np.where(q == 0, np.maximum(0, -p)**0.5, np.zeros_like(p))  # q=0 case
+    y = np.where(q == 0, np.maximum(0, -p) ** 0.5, np.zeros_like(p))  # q=0 case
 
     # Compute discriminant
     delta = (q / 2) ** 2 + (p / 3) ** 3
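
As a scalar cross-check of the largest-real-root convention used by rooth, numpy's generic polynomial root finder gives a reference value (standalone sketch; coefficients invented for illustration):

import numpy as np

p, q = -7.0, 6.0  # x^3 - 7x + 6 = (x - 1)(x - 2)(x + 3)

roots = np.roots([1.0, 0.0, p, q])         # all roots of x^3 + p*x + q
real_roots = roots[np.isreal(roots)].real  # keep the real ones
print(real_roots.max())                    # 2.0, the largest real root
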
@@ -490,4 +494,3 @@ def rooth(p, q):
     y = np.max(real_roots, axis=0) * (delta < 0)  # Keep only real roots when delta < 0
 
     return y
-