
Commit bec6588
Fixed freezing drjit optimizers
1 parent: ade88b0

File tree

1 file changed (+14 -2 lines)


drjit/opt.py (+14 -2)
@@ -126,6 +126,11 @@ class Optimizer(Generic[Extra], MutableMapping[str, dr.ArrayBase]):
     # - an arbitrary sequence of additional optimizer-dependent state values
     state: Dict[str, Tuple[dr.ArrayBase, Optional[LearningRate], Extra]]
 
+    DRJIT_STRUCT = {
+        "lr": LearningRate,
+        "state": dict,
+    }
+
     def __init__(
         self,
         lr: LearningRate,
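
The DRJIT_STRUCT annotation is the heart of the fix: Dr.Jit traverses any class that declares such a field-to-type mapping, so the learning rate and the per-parameter state become visible to operations like dr.freeze instead of remaining opaque Python attributes. A minimal sketch of the same mechanism on a hypothetical container type (the class and field names are illustrative, not part of this commit):

import drjit as dr
from drjit.auto import Float

class Checkpoint:
    # Fields listed in DRJIT_STRUCT are discovered when Dr.Jit traverses
    # the object, e.g. while recording a frozen function, so their array
    # contents count as kernel inputs rather than baked-in constants.
    DRJIT_STRUCT = {
        "params": Float,
        "velocity": Float,
    }

    def __init__(self, params: Float, velocity: Float) -> None:
        self.params = params
        self.velocity = velocity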
@@ -960,10 +965,15 @@ def _step(
         # Compute the step size scale, which is a product of
         # - EMA debiasing factor
         # - Adaptive/parameter-specific scaling
+        Float32 = dr.float32_array_t(dr.leaf_t(grad))
+        Float64 = dr.float64_array_t(dr.leaf_t(grad))
+        ema_factor = Float32(
+            -dr.sqrt(1 - Float64(self.beta_2) ** t) / (1 - Float64(self.beta_1) ** t)
+        )
         scale = cache.product(
             dr.leaf_t(grad),  # Desired type
             lr,
-            -dr.sqrt(1 - self.beta_2**t) / (1 - self.beta_1**t),
+            ema_factor,
         )
 
         # Optional: use maximum of second order term
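
Hoisting the debiasing factor out of cache.product() is what makes this step freezable: the old expression combined Python scalars with the Python integer step count t, yielding a fresh literal on every step, so a frozen training step could never be replayed. The new code evaluates the factor in double precision (via float64_array_t) for accuracy and casts the result to a 32-bit array, turning it into traced arithmetic. A standalone sketch of the same computation, with illustrative values and the backend chosen via drjit.auto:

import drjit as dr
from drjit.auto import Float, UInt32

beta_1, beta_2 = 0.9, 0.999

# Holding the step count as a Dr.Jit array makes the bias-correction
# factor traced arithmetic rather than a Python float whose changing
# value would force a retrace at every iteration.
t = Float(UInt32(4))
ema_factor = -dr.sqrt(1 - Float(beta_2) ** t) / (1 - Float(beta_1) ** t)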
@@ -981,9 +991,11 @@ def _step(
     def _reset(self, key: str, value: dr.ArrayBase, /) -> None:
         valarr = value.array
         tp = type(valarr)
+        UInt = dr.uint32_array_t(dr.leaf_t(tp))
+        t = UInt(0)
         m_t = dr.opaque(tp, 0, valarr.shape)
         v_t = dr.opaque(tp, 0, valarr.shape)
-        self.state[key] = value, None, (0, m_t, v_t)
+        self.state[key] = value, None, (t, m_t, v_t)
 
     # Blend between the old and new versions of the optimizer extra state
     def _select(
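
With this last change, the whole Adam state tuple (t, m_t, v_t) consists of Dr.Jit arrays, so the step counter updates in place instead of living as a Python integer outside the traced program. A hypothetical end-to-end use, assuming the Adam class from this file accepts an (lr, params) constructor and enables gradients on registered parameters:

import drjit as dr
from drjit.auto.ad import Float
from drjit.opt import Adam

x = Float([1.0, 2.0, 3.0])
opt = Adam(lr=1e-2, params={"x": x})

# Recorded once, replayed thereafter: every piece of optimizer state,
# including the step counter, is traversable through DRJIT_STRUCT.
@dr.freeze
def train_step(opt):
    loss = dr.mean(dr.square(opt["x"]))
    dr.backward(loss)
    opt.step()
    return loss

for _ in range(10):
    train_step(opt)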
