@@ -126,6 +126,11 @@ class Optimizer(Generic[Extra], MutableMapping[str, dr.ArrayBase]):
     # - an arbitrary sequence of additional optimizer-dependent state values
     state: Dict[str, Tuple[dr.ArrayBase, Optional[LearningRate], Extra]]
 
+    DRJIT_STRUCT = {
+        "lr": LearningRate,
+        "state": dict,
+    }
+
     def __init__(
         self,
         lr: LearningRate,
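
The `DRJIT_STRUCT` annotation registers `lr` and `state` with Dr.Jit's PyTree traversal, so the optimizer object as a whole can participate in operations that walk nested state. A minimal sketch of the mechanism, assuming the LLVM backend and a made-up `MyState` class (neither is part of this commit):

```python
# Minimal DRJIT_STRUCT sketch (illustrative, not from this commit).
# Fields listed in DRJIT_STRUCT become visible to Dr.Jit's PyTree
# traversal, so whole objects can be scheduled, detached, or copied.
import drjit as dr
from drjit.llvm import Float

class MyState:
    DRJIT_STRUCT = {"x": Float, "weight": float}

    def __init__(self, x: Float = Float(), weight: float = 1.0):
        self.x = x
        self.weight = weight

s = MyState(dr.arange(Float, 3), weight=0.5)
dr.schedule(s)  # finds and schedules 's.x' through the DRJIT_STRUCT entry
```
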
@@ -960,10 +965,15 @@ def _step(
         # Compute the step size scale, which is a product of
         # - EMA debiasing factor
         # - Adaptive/parameter-specific scaling
+        Float32 = dr.float32_array_t(dr.leaf_t(grad))
+        Float64 = dr.float64_array_t(dr.leaf_t(grad))
+        ema_factor = Float32(
+            -dr.sqrt(1 - Float64(self.beta_2) ** t) / (1 - Float64(self.beta_1) ** t)
+        )
         scale = cache.product(
             dr.leaf_t(grad),  # Desired type
             lr,
-            -dr.sqrt(1 - self.beta_2 ** t) / (1 - self.beta_1 ** t),
+            ema_factor,
         )
 
         # Optional: use maximum of second order term
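
The hoisted `ema_factor` is Adam's folded bias correction: with `m_hat = m / (1 - beta_1**t)` and `v_hat = v / (1 - beta_2**t)`, the update `-m_hat / sqrt(v_hat)` equals `-sqrt(1 - beta_2**t) / (1 - beta_1**t) * m / sqrt(v)`, and evaluating the powers in `Float64` before casting back avoids single-precision error once `t` grows large and `beta_2` is close to 1. A standalone plain-Python check of that algebra (illustrative values; the learning-rate factor enters separately via `cache.product`, and the small `epsilon` stabilizer is ignored here):

```python
# Standalone check of the folded bias correction (plain Python, not from
# the commit). Applying Adam's two moment corrections separately matches
# folding them into the single factor -sqrt(1 - beta_2^t) / (1 - beta_1^t).
import math

beta_1, beta_2, t = 0.9, 0.999, 10_000
m, v = 0.02, 0.0005  # example first/second moment estimates

ema_factor = -math.sqrt(1 - beta_2 ** t) / (1 - beta_1 ** t)
step_folded = ema_factor * m / math.sqrt(v)

m_hat = m / (1 - beta_1 ** t)   # debiased first moment
v_hat = v / (1 - beta_2 ** t)   # debiased second moment
step_direct = -m_hat / math.sqrt(v_hat)

assert math.isclose(step_folded, step_direct)
```
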
@@ -981,9 +991,11 @@ def _step(
     def _reset(self, key: str, value: dr.ArrayBase, /) -> None:
         valarr = value.array
         tp = type(valarr)
+        UInt = dr.uint32_array_t(dr.leaf_t(tp))
+        t = UInt(0)
         m_t = dr.opaque(tp, 0, valarr.shape)
         v_t = dr.opaque(tp, 0, valarr.shape)
-        self.state[key] = value, None, (0, m_t, v_t)
+        self.state[key] = value, None, (t, m_t, v_t)
 
     # Blend between the old and new versions of the optimizer extra state
     def _select(
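
Replacing the Python literal `0` with `UInt(0)` turns the per-parameter step counter into a Dr.Jit array on the same backend as the parameter itself, so it travels with the state tuple now that the optimizer is a `DRJIT_STRUCT`. A hedged sketch of the type derivation (LLVM AD backend assumed; any Dr.Jit backend works the same way):

```python
# Sketch of the type derivation in _reset (illustrative, not from the
# commit). 'tp' stands in for type(value.array); dr.leaf_t reduces it to
# the leaf array type, and dr.uint32_array_t yields the matching 32-bit
# unsigned array type on the same backend.
import drjit as dr
from drjit.llvm.ad import Float

tp = Float                                # stand-in for type(value.array)
UInt = dr.uint32_array_t(dr.leaf_t(tp))   # e.g. drjit.llvm.ad.UInt
t = UInt(0)  # a traversable Dr.Jit array, unlike the old Python int 0
```
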