Skip to content

Commit 7de6d1c

Browse files
Fixed freezing drjit optimizers
1 parent ade88b0 commit 7de6d1c

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

drjit/opt.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ class Optimizer(Generic[Extra], MutableMapping[str, dr.ArrayBase]):
126126
# - an arbitrary sequence of additional optimizer-dependent state values
127127
state: Dict[str, Tuple[dr.ArrayBase, Optional[LearningRate], Extra]]
128128

129+
DRJIT_STRUCT = {
130+
"lr": LearningRate,
131+
"state": dict,
132+
}
133+
129134
def __init__(
130135
self,
131136
lr: LearningRate,
@@ -981,9 +986,11 @@ def _step(
981986
def _reset(self, key: str, value: dr.ArrayBase, /) -> None:
982987
valarr = value.array
983988
tp = type(valarr)
989+
UInt = dr.uint_array_t(dr.leaf_t(tp))
990+
t = UInt(0)
984991
m_t = dr.opaque(tp, 0, valarr.shape)
985992
v_t = dr.opaque(tp, 0, valarr.shape)
986-
self.state[key] = value, None, (0, m_t, v_t)
993+
self.state[key] = value, None, (t, m_t, v_t)
987994

988995
# Blend between the old and new versions of the optimizer extra state
989996
def _select(

0 commit comments

Comments
 (0)