@@ -80,17 +80,47 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, maxiter=300,
            )

            # Convergence check: stop if the objective decrease is small and at least 20 iterations have passed
-            print(self.objective_difference, " < ", self.objective_function * 1e-6)
-            if self.objective_difference < self.objective_function * 1e-6 and outiter >= 20:
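+            # note: this is a relative tolerance, scaled by the current objective value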
+            # MATLAB uses 1e-6, but it also converges faster; the tighter 5e-7 makes up that difference
+            print(self.objective_difference, " < ", self.objective_function * 5e-7)
+            if self.objective_difference < self.objective_function * 5e-7 and outiter >= 20:
                break

        # Normalize our results
+        # TODO make this much cleaner
        Y_row_max = np.max(self.Y, axis=1, keepdims=True)
        self.Y = self.Y / Y_row_max
        A_row_max = np.max(self.A, axis=1, keepdims=True)
        self.A = self.A / A_row_max
-        # TODO loop to normalize X (currently not normalized)
+        # loop to normalize X
        # effectively re-runs the solve with the normalized Y/A and the not-yet-normalized X as inputs, updating only X
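+        # rescaling the rows of Y and A changes their products with X, so X alone is
+        # re-optimized below to absorb the scale factors (see the sketch after this diff)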
+        # reset difference trackers and initialize
+        self.preX = self.X.copy()  # previously stored X (like X0 for now)
+        self.GraX = np.zeros_like(self.X)  # gradient of X (zeros for now)
+        self.preGraX = np.zeros_like(self.X)  # previous gradient of X (zeros for now)
+        self.R = self.get_residual_matrix()
+        self.objective_function = self.get_objective_function()
+        self.objective_difference = None
+        self.objective_history = [self.objective_function]
+        self.outiter = 0
+        self.iter = 0
+        for outiter in range(100):
+            if outiter == 1:
+                self.iter = 1  # so the step size can adapt without an inner loop
+            self.updateX()
+            self.R = self.get_residual_matrix()
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after normX: {self.objective_function:.5e}")
+            self.objective_history.append(self.objective_function)
+            self.objective_difference = self.objective_history[-2] - self.objective_history[-1]
+            if self.objective_difference < self.objective_function * 5e-7 and outiter >= 20:
+                break
+        # end of normalization (and program)
+        # note that the objective function does not fully recover after normalization;
+        # it stays higher than the pre-normalization value, but that is okay and matches MATLAB
+        print("Finished optimization.")


    def outer_loop(self):
        # This inner loop runs up to four times per outer loop, making updates to X, Y
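Why the refit of X is needed: scaling each row of Y and A to a unit maximum changes their products with X, so the factorization no longer matches MM until X absorbs the scale factors. Below is a minimal standalone sketch of the idea, assuming a plain two-factor model MM ≈ W @ H; the names W, H, MM and the closed-form rescaling are illustrative assumptions, not this repo's API.

```python
import numpy as np

# Hypothetical two-factor model: MM ≈ W @ H.
rng = np.random.default_rng(0)
W = rng.random((6, 2))
H = rng.random((2, 8))
MM = W @ H

# Scale each row of H to a maximum of 1, as the diff does for Y and A ...
H_row_max = np.max(H, axis=1, keepdims=True)
H_normed = H / H_row_max

# ... and fold the scale factors into W so the product is unchanged.
W_scaled = W * H_row_max.T
assert np.allclose(W_scaled @ H_normed, MM)
```

In the diff itself, a closed-form compensation like `W * H_row_max.T` is not available (three factors plus regularization are involved), so the code re-runs the X update loop until the objective stops improving; that is also why, per the comments above, the objective settles slightly above its pre-normalization value.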