Skip to content

Commit 1c81b5b

Browse files
committed
move verbose __repr__() methods for nn.Module classes
1 parent f80af5a commit 1c81b5b

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

drjit/nn.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ def __init__(self, dtype: Optional[Type[drjit.ArrayBase]] = None):
244244
self.dtype = dtype
245245
def __call__(self, arg: CoopVec, /) -> CoopVec:
246246
return cast(arg, self.dtype)
247+
def __repr__(self):
248+
return f'Cast(dtype={self.dtype.__name__})'
247249

248250
class Linear(Module):
249251
r"""
@@ -286,7 +288,7 @@ def __init__(self, in_features: int = -1, out_features: int = -1, bias = True) -
286288
self.weights = self.bias = None
287289

288290
def __repr__(self) -> str:
289-
s = f'Linear({self.config[0]}, {self.config[1]}'
291+
s = f'Linear(in_features={self.config[0]}, out_features={self.config[1]}'
290292
if not self.config[2]:
291293
s += ', bias=False'
292294
s += ')'
@@ -391,15 +393,20 @@ class TriEncode(Module):
391393
:align: center
392394
"""
393395

396+
DRJIT_STRUCT = { 'octaves' : int, 'shift': float, 'channels': int }
397+
394398
def __init__(self, octaves: int = 0, shift: float = 0) -> None:
395399
self.octaves = octaves
396400
self.shift = shift
401+
self.channels = -1
397402

398403
def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]:
399-
return self, size * self.octaves * 2
404+
r = TriEncode(self.octaves, self.shift)
405+
r.channels = size
406+
return r, size * self.octaves * 2
400407

401408
def __repr__(self) -> str:
402-
return f'TriEncode({self.octaves})'
409+
return f'TriEncode(octaves={self.octaves}, shift={self.shift}, in_channels={self.channels}, out_features={self.channels*self.octaves*2})'
403410

404411
def __call__(self, arg: CoopVec, /) -> CoopVec:
405412
args, r = list(arg), list()
@@ -453,8 +460,11 @@ class SinEncode(Module):
453460
:align: center
454461
"""
455462

463+
DRJIT_STRUCT = { 'octaves' : int, 'shift': Union[tuple, None], 'channels': int }
464+
456465
def __init__(self, octaves: int = 0, shift: float = 0) -> None:
457466
self.octaves = octaves
467+
self.channels = -1
458468

459469
if shift == 0:
460470
self.shift = None
@@ -463,10 +473,13 @@ def __init__(self, octaves: int = 0, shift: float = 0) -> None:
463473
drjit.cos(shift * 2 * drjit.pi))
464474

465475
def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]:
466-
return self, size * self.octaves * 2
476+
r = SinEncode(self.octaves)
477+
r.channels = size
478+
r.shift = self.shift
479+
return r, size * self.octaves * 2
467480

468481
def __repr__(self) -> str:
469-
return f'SinEncode({self.octaves})'
482+
return f'SinEncode(octaves={self.octaves}, shift={self.shift}, in_channels={self.channels}, out_features={self.channels*self.octaves*2})'
470483

471484
def __call__(self, arg: CoopVec, /) -> CoopVec:
472485
args, r = list(arg), list()

0 commit comments

Comments
 (0)