Skip to content

Commit 79f815a

Browse files
authored
Merge pull request #211 from SEMCOG/stability
Fixes MNL runtime overflow and warns in the case of non-convergence
2 parents 2497039 + 9794d9e commit 79f815a

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

urbansim/urbanchoice/mnl.py

+7
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def mnl_probs(data, beta, numalts):
3535
raise Exception("Number of alternatives is zero")
3636
utilities.reshape(numalts, utilities.size() // numalts)
3737

38+
# https://stats.stackexchange.com/questions/304758/softmax-overflow
39+
utilities = utilities.subtract(utilities.max(0))
40+
3841
exponentiated_utility = utilities.exp(inplace=True)
3942
if clamp:
4043
exponentiated_utility.inftoval(1e20)
@@ -245,6 +248,10 @@ def mnl_estimate(data, chosen, numalts, GPU=False, coeffrange=(-3, 3),
245248
approx_grad=False,
246249
bounds=bounds
247250
)
251+
252+
if bfgs_result[2]['warnflag'] > 0:
253+
logger.warn("mnl did not converge correctly: %s", bfgs_result)
254+
248255
beta = bfgs_result[0]
249256
stderr = mnl_loglik(
250257
beta, data, chosen, numalts, weights, stderr=1, lcgrad=lcgrad)

urbansim/urbanchoice/pmat.py

+6
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def cumsum(self, axis):
7979
# elif self.typ == 'cuda':
8080
# return PMAT(misc.cumsum(self.mat,axis=axis))
8181

82+
def max(self, axis):
83+
if self.typ == 'numpy':
84+
return PMAT(np.max(self.mat, axis=axis))
85+
elif self.typ == 'cuda':
86+
return PMAT(self.mat.max(axis=axis))
87+
8288
def argmax(self, axis):
8389
if self.typ == 'numpy':
8490
return PMAT(np.argmax(self.mat, axis=axis))

0 commit comments

Comments
 (0)