@@ -244,6 +244,8 @@ def __init__(self, dtype: Optional[Type[drjit.ArrayBase]] = None):
244
244
self .dtype = dtype
245
245
def __call__ (self , arg : CoopVec , / ) -> CoopVec :
246
246
return cast (arg , self .dtype )
247
+ def __repr__ (self ):
248
+ return f'Cast(dtype={ self .dtype .__name__ } )'
247
249
248
250
class Linear (Module ):
249
251
r"""
@@ -286,7 +288,7 @@ def __init__(self, in_features: int = -1, out_features: int = -1, bias = True) -
286
288
self .weights = self .bias = None
287
289
288
290
def __repr__ (self ) -> str :
289
- s = f'Linear({ self .config [0 ]} , { self .config [1 ]} '
291
+ s = f'Linear(in_features= { self .config [0 ]} , out_features= { self .config [1 ]} '
290
292
if not self .config [2 ]:
291
293
s += ', bias=False'
292
294
s += ')'
@@ -391,15 +393,20 @@ class TriEncode(Module):
391
393
:align: center
392
394
"""
393
395
396
+ DRJIT_STRUCT = { 'octaves' : int , 'shift' : float , 'channels' : int }
397
+
394
398
def __init__ (self , octaves : int = 0 , shift : float = 0 ) -> None :
395
399
self .octaves = octaves
396
400
self .shift = shift
401
+ self .channels = - 1
397
402
398
403
def _alloc (self , dtype : Type [drjit .ArrayBase ], size : int = - 1 , / ) -> Tuple [Module , int ]:
399
- return self , size * self .octaves * 2
404
+ r = TriEncode (self .octaves , self .shift )
405
+ r .channels = size
406
+ return r , size * self .octaves * 2
400
407
401
408
def __repr__ (self ) -> str :
402
- return f'TriEncode({ self .octaves } )'
409
+ return f'TriEncode(octaves= { self .octaves } , shift= { self . shift } , in_channels= { self . channels } , out_features= { self . channels * self . octaves * 2 } )'
403
410
404
411
def __call__ (self , arg : CoopVec , / ) -> CoopVec :
405
412
args , r = list (arg ), list ()
@@ -453,8 +460,11 @@ class SinEncode(Module):
453
460
:align: center
454
461
"""
455
462
463
+ DRJIT_STRUCT = { 'octaves' : int , 'shift' : Union [tuple , None ], 'channels' : int }
464
+
456
465
def __init__ (self , octaves : int = 0 , shift : float = 0 ) -> None :
457
466
self .octaves = octaves
467
+ self .channels = - 1
458
468
459
469
if shift == 0 :
460
470
self .shift = None
@@ -463,10 +473,13 @@ def __init__(self, octaves: int = 0, shift: float = 0) -> None:
463
473
drjit .cos (shift * 2 * drjit .pi ))
464
474
465
475
def _alloc (self , dtype : Type [drjit .ArrayBase ], size : int = - 1 , / ) -> Tuple [Module , int ]:
466
- return self , size * self .octaves * 2
476
+ r = SinEncode (self .octaves )
477
+ r .channels = size
478
+ r .shift = self .shift
479
+ return r , size * self .octaves * 2
467
480
468
481
def __repr__ (self ) -> str :
469
- return f'SinEncode({ self .octaves } )'
482
+ return f'SinEncode(octaves= { self .octaves } , shift= { self . shift } , in_channels= { self . channels } , out_features= { self . channels * self . octaves * 2 } )'
470
483
471
484
def __call__ (self , arg : CoopVec , / ) -> CoopVec :
472
485
args , r = list (arg ), list ()
0 commit comments