@@ -2822,3 +2822,44 @@ def func(x: t, lit: MyHolder):
2822
2822
assert dr .allclose (ref , res )
2823
2823
2824
2824
assert frozen .n_recordings == 1
2825
+
2826
+ @pytest .test_arrays ("float32, jit, diff, shape=(*)" )
2827
+ @pytest .mark .parametrize ("optimizer" , ["sdg" , "rmsprop" , "adam" ])
2828
+ def test77_optimizers (t , optimizer ):
2829
+ n = 10
2830
+
2831
+ def func (y , opt ):
2832
+ loss = dr .mean (dr .square (opt ["x" ] - y ))
2833
+
2834
+ dr .backward (loss )
2835
+
2836
+ opt .step ()
2837
+
2838
+ return opt ["x" ], loss
2839
+
2840
+ def init_optimizer ():
2841
+ if optimizer == "sdg" :
2842
+ opt = dr .opt .SGD (lr = 0.001 , momentum = 0.9 )
2843
+ elif optimizer == "rmsprop" :
2844
+ opt = dr .opt .RMSProp (lr = 0.001 )
2845
+ elif optimizer == "adam" :
2846
+ opt = dr .opt .Adam (lr = 0.001 )
2847
+ return opt
2848
+
2849
+ frozen = dr .freeze (func )
2850
+
2851
+ opt_func = init_optimizer ()
2852
+ opt_frozen = init_optimizer ()
2853
+
2854
+ for i in range (n ):
2855
+ x = dr .full (t , 1 , 10 )
2856
+ y = dr .full (t , 0 , 10 )
2857
+
2858
+ opt_func ["x" ] = x
2859
+ opt_frozen ["x" ] = x
2860
+
2861
+ res_x , res_loss = frozen (y , opt_frozen )
2862
+ ref_x , ref_loss = func (y , opt_func )
2863
+
2864
+ assert dr .allclose (res_x , ref_x )
2865
+ assert dr .allclose (res_loss , ref_loss )
0 commit comments