47 changes: 47 additions & 0 deletions drjit/nn.py
@@ -232,6 +232,53 @@ class Tanh(Module):
DRJIT_STRUCT = { }
def __call__(self, arg: CoopVec, /) -> CoopVec:
return drjit.tanh(arg)


class Sigmoid(Module):
r"""
Sigmoid activation function.

.. math::
\mathrm{Sigmoid}(x) = \frac{1}{1 + e^{-x}} = 0.5 + 0.5 \cdot \tanh(x/2)
"""
DRJIT_STRUCT = {}

def __call__(self, arg: CoopVec, /) -> CoopVec:
# Use the identity: sigmoid(x) = 0.5 + 0.5 * tanh(x/2)
half_x = arg * 0.5
tanh_half_x = drjit.tanh(half_x)
return drjit.fma(0.5, tanh_half_x, 0.5) # 0.5 * tanh + 0.5

class SiLU(Module):
r"""
SiLU activation function. Also known as the "swish" function.
@wjakob (Member) commented on Sep 25, 2025: Could you document the expression here as well?

.. math::
\mathrm{SiLU}(x) = x \cdot \mathrm{Sigmoid}(x)
= \frac{x}{1 + e^{-x}}
"""
DRJIT_STRUCT = {}

def __call__(self, arg: CoopVec, /) -> CoopVec:
# Use the identity: sigmoid(x) = 0.5 + 0.5 * tanh(x/2)
half_x = arg * 0.5
tanh_half_x = drjit.tanh(half_x)
sigmoid = drjit.fma(0.5, tanh_half_x, 0.5) # 0.5 * tanh + 0.5
return arg * sigmoid

class Softplus(Module):
r"""
Softplus activation function.

.. math::
\mathrm{Softplus}(x) = \log(1 + e^x)
"""
DRJIT_STRUCT = {}

def __call__(self, arg: CoopVec, /) -> CoopVec:
        # Cooperative vectors natively provide exp2/log2, so evaluate the
        # softplus via the identity log(1 + e^x) = ln(2) * log2(1 + 2^(x / ln 2)).
        # Note that 2^(x / ln 2) can overflow to +inf for large positive inputs.
x_log2 = arg * (1 / drjit.log(2))
return drjit.log(2) * drjit.log2(1.0 + drjit.exp2(x_log2))

class ScaleAdd(Module):
r"""
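As a sanity check of the activations added above, the identities they rely on can be verified with plain Python math. The sketch below is standalone and independent of Dr.Jit: sigmoid(x) = 0.5 + 0.5 * tanh(x/2) is used by Sigmoid and SiLU, and log(1 + e^x) = ln(2) * log2(1 + 2^(x/ln 2)) by Softplus.

```python
import math

# Verify the identities used by the Sigmoid/SiLU and Softplus modules above
for x in (-4.0, -1.0, 0.0, 0.5, 3.0):
    sigmoid_direct = 1.0 / (1.0 + math.exp(-x))
    sigmoid_tanh = 0.5 + 0.5 * math.tanh(0.5 * x)
    softplus_direct = math.log1p(math.exp(x))
    softplus_exp2 = math.log(2) * math.log2(1.0 + 2.0 ** (x / math.log(2)))
    assert abs(sigmoid_direct - sigmoid_tanh) < 1e-12
    assert abs(softplus_direct - softplus_exp2) < 1e-12
```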
94 changes: 84 additions & 10 deletions drjit/opt.py
@@ -1136,7 +1136,7 @@ class AdamW(Adam):
"""

# Weight decay coefficient
weight_decay: float
global_weight_decay: float

def __init__(
self,
@@ -1146,7 +1146,7 @@ def __init__(
beta_1: float = 0.9,
beta_2: float = 0.999,
epsilon: float = 1e-8,
weight_decay: float = 0.01,
        weight_decay: Union[float, Mapping[str, float], None] = None,
mask_updates: bool = False,
promote_fp16: bool = True,
uniform: bool = False,
@@ -1172,10 +1172,12 @@ def __init__(
will cause past gradients to persist for a longer amount of
time.

weight_decay (float):
            weight_decay (float | Mapping[str, float] | None):
                Weight decay coefficient for L2 regularization. Unlike Adam,
                this decay is applied directly to the parameters rather than
                to the gradients, which provides better regularization in
                combination with adaptive learning rates. You may also provide
                a dictionary mapping parameter names to individual weight
                decay values. The default ``None`` disables weight decay.

uniform (bool):
If enabled, the optimizer will use the *UniformAdam* variant of
@@ -1195,7 +1197,18 @@ def __init__(
Optional dictionary-like object containing an initial set of
parameters.
"""
# Store per-parameter weight decay
self.weight_decay_dict: Mapping[str, float] = {}

        if (weight_decay is None) or isinstance(weight_decay, Mapping):
            self.global_weight_decay = 0.0
        elif isinstance(weight_decay, (int, float)):
            self.global_weight_decay = float(weight_decay)
        else:
            raise TypeError(
                "'weight_decay' must be None, a float, or a mapping from "
                "parameter names to floats"
            )

super().__init__(
lr,
params,
@@ -1206,11 +1219,48 @@ def __init__(
promote_fp16=promote_fp16,
uniform=uniform
)

self.set_weight_decay(weight_decay)

if weight_decay < 0:
raise RuntimeError("'weight_decay' must be >= 0")
    def set_weight_decay(self, value: Union[float, Mapping[str, float], None] = None):
        """
        Set the weight decay coefficient globally or per parameter.

        Args:
            value: ``None`` to disable weight decay, a float to set a global
                coefficient, or a mapping from parameter names to individual
                weight decay values.
        """
        if value is None:
            # Disable (global) weight decay
            self.global_weight_decay = 0.0
        elif isinstance(value, (int, float)):
            # Global weight decay
            if value < 0:
                raise ValueError("'weight_decay' must be >= 0")
            self.global_weight_decay = float(value)
        elif isinstance(value, Mapping):
            # Per-parameter overrides
            for k, wd in value.items():
                if wd < 0:
                    raise ValueError("'weight_decay' values must be >= 0")
                self.weight_decay_dict[k] = wd
        else:
            raise TypeError(
                "'weight_decay' must be None, a float, or a mapping from "
                "parameter names to floats"
            )

for k in self.state:
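            # The (possibly per-parameter) decay ends up as the fourth element
            # of this parameter's 'extra' state tuple, where _step() reads it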
decay = self.weight_decay_dict.get(k, self.global_weight_decay)

Float = dr.float32_array_t(dr.leaf_t(self.state[k][0]))
self.state[k] = self.state[k][0], self.state[k][1], self.state[k][2], (
self.state[k][3][0],
self.state[k][3][1],
self.state[k][3][2],
Float(decay),
)

self.weight_decay = weight_decay
def _reset(self, key: str, value: dr.ArrayBase, promoted: bool, /) -> None:
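        # Same state layout as the parent optimizer (t, m_t, v_t), extended
        # with the weight decay coefficient that _step() reads back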
valarr = value.array
tp = type(valarr)
UInt = dr.uint32_array_t(dr.leaf_t(tp))
Float = dr.float32_array_t(dr.leaf_t(tp))
t = UInt(0)
m_t = dr.opaque(tp, 0, valarr.shape)
v_t = dr.opaque(tp, 0, valarr.shape)
decay = self.weight_decay_dict.get(key, self.global_weight_decay)
self.state[key] = value, promoted, None, (t, m_t, v_t, Float(decay))

def _step(
self,
@@ -1221,9 +1271,33 @@ def _step(
extra: Tuple[int, dr.ArrayBase, dr.ArrayBase],
/,
) -> Tuple[dr.ArrayBase, Tuple[int, dr.ArrayBase, dr.ArrayBase]]:
new_value, new_extra = super()._step(cache, value, grad, lr, extra)
scaled_value = dr.fma(value, -lr * self.weight_decay, new_value)

decay = extra[3]
        # Take the Adam step using (t, m_t, v_t) = extra[:3]
new_value, new_extra = super()._step(cache, value, grad, lr, extra[:3])

# Take weight decay step
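        # (decoupled: the parameter itself is shrunk by lr * decay,
        # independently of the gradient-based Adam update above)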
scaled_value = dr.fma(value, -lr * decay, new_value)

new_extra = (new_extra[0], new_extra[1], new_extra[2], decay)
return scaled_value, new_extra

    def _select(
        self,
        mask: dr.ArrayBase,
        extra: Tuple[int, dr.ArrayBase, dr.ArrayBase, dr.ArrayBase],
        new_extra: Tuple[int, dr.ArrayBase, dr.ArrayBase, dr.ArrayBase],
        /,
    ) -> Tuple[int, dr.ArrayBase, dr.ArrayBase, dr.ArrayBase]:
# Known issue: we don't mask the update to 't' here. That would
# require moving this parameter to the GPU, with a whole bunch
# of downsides. It is only relevant for AMP training. Oh well.
return (
new_extra[0],
dr.select(mask, extra[1], new_extra[1]),
dr.select(mask, extra[2], new_extra[2]),
new_extra[3]
)

def __repr__(self):
"""Return a human-readable string representation"""
@@ -1242,7 +1316,7 @@ def __repr__(self):
" lr = %s,\\n"
" beta = (%g, %g),\\n"
" epsilon = %g,\\n"
" weight_decay = %g\\n"
" weight_decay = %s\\n"
"]"
% (
list(self.keys()),
@@ -1251,7 +1325,7 @@ def __repr__(self):
self.beta_1,
self.beta_2,
self.epsilon,
self.weight_decay,
self.weight_decay_dict,
)
)

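For reference, here is a hypothetical usage sketch of the per-parameter weight decay introduced in this diff. The parameter names and values are made up, and it assumes the `drjit.auto.ad` array namespace together with the optimizer's dict-style parameter interface:

```python
import drjit as dr
from drjit.auto.ad import Float
from drjit.opt import AdamW

# Global weight decay of 1e-2 for every parameter ...
opt = AdamW(lr=1e-3, weight_decay=1e-2)
opt['weights'] = Float(0.5, -0.25, 1.0)
opt['bias'] = Float(0.0, 0.0)

# ... except 'bias', which should not be regularized
opt.set_weight_decay({'bias': 0.0})

# One optimization step on a toy quadratic loss
loss = dr.sum(dr.square(opt['weights'])) + dr.sum(dr.square(opt['bias'] - 1.0))
dr.backward(loss)
opt.step()
```

After the Adam update, each parameter is additionally shrunk by `lr * decay`, where `decay` comes from the per-parameter mapping if present and from the global coefficient otherwise.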