@@ -322,7 +322,7 @@ class BitmapTexture final : public Texture<Float, Spectrum> {
322
322
m_wrap_mode,
323
323
m_raw,
324
324
m_accel,
325
- tensor);
325
+ std::move ( tensor) );
326
326
}
327
327
328
328
private:
@@ -367,35 +367,37 @@ class BitmapTextureImpl : public Texture<Float, Spectrum> {
367
367
using StoredTensorXf = dr::replace_scalar_t <TensorXf, StoredScalar>;
368
368
using StoredTexture2f = dr::Texture<StoredType, 2 >;
369
369
370
+ template <typename Tensor>
370
371
BitmapTextureImpl (const Properties &props,
371
- const std::string& name,
372
- const ScalarTransform3f& transform,
373
- dr::FilterMode filter_mode,
374
- dr::WrapMode wrap_mode,
375
- bool raw,
376
- bool accel,
377
- StoredTensorXf & tensor) :
372
+ const std::string& name,
373
+ const ScalarTransform3f& transform,
374
+ dr::FilterMode filter_mode,
375
+ dr::WrapMode wrap_mode,
376
+ bool raw,
377
+ bool accel,
378
+ Tensor& & tensor) :
378
379
Texture (props),
379
380
m_name (name),
380
381
m_transform (transform),
381
382
m_accel (accel),
382
- m_raw (raw),
383
- m_texture (tensor, accel, accel, filter_mode, wrap_mode) {
383
+ m_raw (raw) {
384
384
385
385
/* Compute mean without migrating texture data
386
386
i.e. Avoid call to m_texture.tensor() that triggers migration.
387
387
For CUDA-variants, ideally want to solely keep data as CUDA texture
388
388
*/
389
389
rebuild_internals (tensor, true , false );
390
+
391
+ m_texture = StoredTexture2f (std::forward<Tensor>(tensor), accel, accel,
392
+ filter_mode, wrap_mode);
390
393
}
391
394
392
395
void traverse (TraversalCallback *callback) override {
393
396
callback->put_parameter (" data" , m_texture.tensor (), +ParamFlags::Differentiable);
394
397
callback->put_parameter (" to_uv" , m_transform, +ParamFlags::NonDifferentiable);
395
398
}
396
399
397
- void
398
- parameters_changed (const std::vector<std::string> &keys = {}) override {
400
+ void parameters_changed (const std::vector<std::string> &keys = {}) override {
399
401
if (keys.empty () || string::contains (keys, " data" )) {
400
402
const size_t channels = m_texture.shape ()[2 ];
401
403
if (channels != 1 && channels != 3 )
@@ -801,13 +803,14 @@ class BitmapTextureImpl : public Texture<Float, Spectrum> {
801
803
if (m_transform != ScalarTransform3f ())
802
804
dr::make_opaque (m_transform);
803
805
804
- size_t pixel_count = ( size_t ) dr::prod ( resolution () );
805
- const size_t channels = m_texture. shape ()[ 2 ];
806
- bool range_issue = false ;
806
+ const dr::vector< size_t > &shape = tensor. shape ( );
807
+ size_t pixel_count = shape[ 0 ] * shape[ 1 ],
808
+ channels = shape[ 2 ] ;
807
809
810
+ bool range_issue = false ;
808
811
using FloatStorage = DynamicBuffer<Float>;
809
812
using StoredTypeArray= DynamicBuffer<StoredType>;
810
- FloatStorage values = dr::empty<FloatStorage>(pixel_count) ;
813
+ FloatStorage values;
811
814
812
815
if (channels == 3 ) {
813
816
if constexpr (dr::is_jit_v<Float>) {
@@ -824,7 +827,11 @@ class BitmapTextureImpl : public Texture<Float, Spectrum> {
824
827
values = luminance (colors_fl);
825
828
} else {
826
829
StoredScalar* ptr = (StoredScalar*) tensor.data ();
827
- ScalarFloat *out = values.data (), mean = 0 ;
830
+ ScalarFloat *out = nullptr , mean = 0 ;
831
+ if (init_distr) {
832
+ values = dr::empty<FloatStorage>(pixel_count);
833
+ out = values.data ();
834
+ }
828
835
829
836
for (size_t i = 0 ; i < pixel_count; ++i) {
830
837
Color3f col (ptr[0 ], ptr[1 ], ptr[2 ]);
@@ -836,7 +843,8 @@ class BitmapTextureImpl : public Texture<Float, Spectrum> {
836
843
else
837
844
lum = luminance (col);
838
845
839
- *out++ = lum;
846
+ if (init_distr)
847
+ *out++ = lum;
840
848
mean += lum;
841
849
range_issue |= lum < 0 || lum > 1 ;
842
850
}
@@ -848,11 +856,16 @@ class BitmapTextureImpl : public Texture<Float, Spectrum> {
848
856
values = tensor.array ();
849
857
} else {
850
858
StoredScalar* ptr = (StoredScalar*) tensor.data ();
851
- ScalarFloat *out = values.data (), mean = 0 ;
859
+ ScalarFloat *out = nullptr , mean = 0 ;
860
+ if (init_distr) {
861
+ values = dr::empty<FloatStorage>(pixel_count);
862
+ out = values.data ();
863
+ }
852
864
for (size_t i = 0 ; i < pixel_count; ++i) {
853
865
ScalarFloat value = ptr[i];
854
- *out++ = value;
855
- m_mean += value;
866
+ if (init_distr)
867
+ *out++ = value;
868
+ mean += value;
856
869
range_issue |= value < 0 || value > 1 ;
857
870
}
858
871
m_mean = mean / pixel_count;
0 commit comments