@@ -126,6 +126,11 @@ class Optimizer(Generic[Extra], MutableMapping[str, dr.ArrayBase]):
     # - an arbitrary sequence of additional optimizer-dependent state values
     state: Dict[str, Tuple[dr.ArrayBase, Optional[LearningRate], Extra]]
 
+    DRJIT_STRUCT = {
+        "lr": LearningRate,
+        "state": dict,
+    }
+
     def __init__(
         self,
         lr: LearningRate,
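The DRJIT_STRUCT table added above registers Optimizer with Dr.Jit's generic PyTree traversal, so operations that walk nested containers can now reach the learning rate and per-parameter state held inside an optimizer. A minimal sketch of the convention, using an illustrative Accumulator class (not from this commit) on the LLVM backend:

    # Fields listed in DRJIT_STRUCT are visited by Dr.Jit's traversal
    # machinery (e.g. dr.zeros, dr.detach, symbolic loop state capture).
    import drjit as dr
    from drjit.llvm import Float, UInt32

    class Accumulator:
        DRJIT_STRUCT = {"value": Float, "count": UInt32}

        def __init__(self, value: Float = Float(0), count: UInt32 = UInt32(0)):
            self.value = value
            self.count = count

    # dr.zeros() builds an Accumulator by traversing DRJIT_STRUCT, just as
    # generic traversal now reaches Optimizer.lr and Optimizer.state.
    acc = dr.zeros(Accumulator, 16)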
@@ -960,10 +965,15 @@ def _step(
         # Compute the step size scale, which is a product of
         # - EMA debiasing factor
         # - Adaptive/parameter-specific scaling
+        Float32 = dr.float32_array_t(dr.leaf_t(grad))
+        Float64 = dr.float64_array_t(dr.leaf_t(grad))
+        ema_factor = Float32(
+            -dr.sqrt(1 - Float64(self.beta_2) ** t) / (1 - Float64(self.beta_1) ** t)
+        )
         scale = cache.product(
             dr.leaf_t(grad),  # Desired type
             lr,
-            -dr.sqrt(1 - self.beta_2 ** t) / (1 - self.beta_1 ** t),
+            ema_factor,
         )
 
         # Optional: use maximum of second order term
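For context, the quantity hoisted into ema_factor is Adam's bias-correction term (Kingma & Ba, 2015), folded together with the learning rate into a single signed step scale; evaluating the beta powers in double precision before casting down to Float32 appears intended to limit round-off in the 1 - β^t differences. A short derivation in LaTeX (η denotes the learning rate; the ε placement in the folded form is the usual implementation simplification):

    \hat m_t = \frac{m_t}{1 - \beta_1^t}, \qquad
    \hat v_t = \frac{v_t}{1 - \beta_2^t}, \qquad
    \theta_t = \theta_{t-1} - \eta\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}
             \approx \theta_{t-1}
               \underbrace{-\,\eta\,\frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}}_{\texttt{scale}}
               \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}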
@@ -981,9 +991,11 @@ def _step(
     def _reset(self, key: str, value: dr.ArrayBase, /) -> None:
         valarr = value.array
         tp = type(valarr)
+        UInt = dr.uint32_array_t(dr.leaf_t(tp))
+        t = UInt(0)
         m_t = dr.opaque(tp, 0, valarr.shape)
         v_t = dr.opaque(tp, 0, valarr.shape)
-        self.state[key] = value, None, (0, m_t, v_t)
+        self.state[key] = value, None, (t, m_t, v_t)
 
     # Blend between the old and new versions of the optimizer extra state
     def _select(
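The matching _reset change initializes the Adam timestep as a width-1 unsigned integer array of the parameter's backend rather than a Python int, so the counter travels with the rest of the optimizer state during the PyTree traversal enabled by DRJIT_STRUCT. An illustrative sketch of the difference (backend and values are hypothetical, not from the commit):

    import drjit as dr
    from drjit.llvm import UInt32

    t_py = 0           # hypothetical old form: a Python int, invisible to traversal
    t_jit = UInt32(0)  # hypothetical new form: a width-1 JIT array, visible to it

    t_jit += 1                 # evolves as a JIT variable across optimizer steps
    print(dr.width(t_jit))     # 1 -- scalar-like, but a real array variable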