Skip to content

Commit 48daace

Browse files
Added optimizer freezing tests
1 parent bec6588 commit 48daace

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

tests/test_freeze.py

+41
Original file line numberDiff line numberDiff line change
@@ -2822,3 +2822,44 @@ def func(x: t, lit: MyHolder):
28222822
assert dr.allclose(ref, res)
28232823

28242824
assert frozen.n_recordings == 1
2825+
2826+
@pytest.test_arrays("float32, jit, diff, shape=(*)")
2827+
@pytest.mark.parametrize("optimizer", ["sdg", "rmsprop", "adam"])
2828+
def test77_optimizers(t, optimizer):
2829+
n = 10
2830+
2831+
def func(y, opt):
2832+
loss = dr.mean(dr.square(opt["x"] - y))
2833+
2834+
dr.backward(loss)
2835+
2836+
opt.step()
2837+
2838+
return opt["x"], loss
2839+
2840+
def init_optimizer():
2841+
if optimizer == "sdg":
2842+
opt = dr.opt.SGD(lr = 0.001, momentum = 0.9)
2843+
elif optimizer == "rmsprop":
2844+
opt = dr.opt.RMSProp(lr = 0.001)
2845+
elif optimizer == "adam":
2846+
opt = dr.opt.Adam(lr = 0.001)
2847+
return opt
2848+
2849+
frozen = dr.freeze(func)
2850+
2851+
opt_func = init_optimizer()
2852+
opt_frozen = init_optimizer()
2853+
2854+
for i in range(n):
2855+
x = dr.full(t, 1, 10)
2856+
y = dr.full(t, 0, 10)
2857+
2858+
opt_func["x"] = x
2859+
opt_frozen["x"] = x
2860+
2861+
res_x, res_loss = frozen(y, opt_frozen)
2862+
ref_x, ref_loss = func(y, opt_func)
2863+
2864+
assert dr.allclose(res_x, ref_x)
2865+
assert dr.allclose(res_loss, ref_loss)

0 commit comments

Comments
 (0)