Skip to content

Commit

Permalink
added channel strides
Browse files Browse the repository at this point in the history
  • Loading branch information
dusty-nv committed Jun 13, 2022
1 parent 4968712 commit 82840af
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
33 changes: 17 additions & 16 deletions c/tensorConvert.cu
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,15 @@ cudaError_t cudaTensorNormBGR( void* input, imageFormat format, size_t inputWidt

// gpuTensorNormMean
template<typename T, bool isBGR>
__global__ void gpuTensorNormMean( T* input, int iWidth, float* output, int oWidth, int oHeight, float2 scale, float multiplier, float min_value, const float3 mean, const float3 stdDev )
__global__ void gpuTensorNormMean( T* input, int iWidth, float* output, int oWidth, int oHeight, int stride, float2 scale, float multiplier, float min_value, const float3 mean, const float3 stdDev )
{
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;

if( x >= oWidth || y >= oHeight )
return;

const int n = oWidth * oHeight;
const int m = y * oWidth + x;

const int m = y * oWidth + x;
const int dx = ((float)x * scale.x);
const int dy = ((float)y * scale.y);

Expand All @@ -196,23 +194,26 @@ __global__ void gpuTensorNormMean( T* input, int iWidth, float* output, int oWid
const float3 rgb = isBGR ? make_float3(px.z, px.y, px.x)
: make_float3(px.x, px.y, px.z);

output[n * 0 + m] = ((rgb.x * multiplier + min_value) - mean.x) / stdDev.x;
output[n * 1 + m] = ((rgb.y * multiplier + min_value) - mean.y) / stdDev.y;
output[n * 2 + m] = ((rgb.z * multiplier + min_value) - mean.z) / stdDev.z;
output[stride * 0 + m] = ((rgb.x * multiplier + min_value) - mean.x) / stdDev.x;
output[stride * 1 + m] = ((rgb.y * multiplier + min_value) - mean.y) / stdDev.y;
output[stride * 2 + m] = ((rgb.z * multiplier + min_value) - mean.z) / stdDev.z;
}

template<bool isBGR>
cudaError_t launchTensorNormMean( void* input, imageFormat format, size_t inputWidth, size_t inputHeight,
float* output, size_t outputWidth, size_t outputHeight,
const float2& range, const float3& mean, const float3& stdDev,
cudaStream_t stream )
cudaStream_t stream, size_t channelStride )
{
if( !input || !output )
return cudaErrorInvalidDevicePointer;

if( inputWidth == 0 || outputWidth == 0 || inputHeight == 0 || outputHeight == 0 )
return cudaErrorInvalidValue;

if( channelStride == 0 )
channelStride = outputWidth * outputHeight;

const float2 scale = make_float2( float(inputWidth) / float(outputWidth),
float(inputHeight) / float(outputHeight) );

Expand All @@ -223,13 +224,13 @@ cudaError_t launchTensorNormMean( void* input, imageFormat format, size_t inputW
const dim3 gridDim(iDivUp(outputWidth,blockDim.x), iDivUp(outputHeight,blockDim.y));

if( format == IMAGE_RGB8 )
gpuTensorNormMean<uchar3, isBGR><<<gridDim, blockDim, 0, stream>>>((uchar3*)input, inputWidth, output, outputWidth, outputHeight, scale, multiplier, range.x, mean, stdDev);
gpuTensorNormMean<uchar3, isBGR><<<gridDim, blockDim, 0, stream>>>((uchar3*)input, inputWidth, output, outputWidth, outputHeight, channelStride, scale, multiplier, range.x, mean, stdDev);
else if( format == IMAGE_RGBA8 )
gpuTensorNormMean<uchar4, isBGR><<<gridDim, blockDim, 0, stream>>>((uchar4*)input, inputWidth, output, outputWidth, outputHeight, scale, multiplier, range.x, mean, stdDev);
gpuTensorNormMean<uchar4, isBGR><<<gridDim, blockDim, 0, stream>>>((uchar4*)input, inputWidth, output, outputWidth, outputHeight, channelStride, scale, multiplier, range.x, mean, stdDev);
else if( format == IMAGE_RGB32F )
gpuTensorNormMean<float3, isBGR><<<gridDim, blockDim, 0, stream>>>((float3*)input, inputWidth, output, outputWidth, outputHeight, scale, multiplier, range.x, mean, stdDev);
gpuTensorNormMean<float3, isBGR><<<gridDim, blockDim, 0, stream>>>((float3*)input, inputWidth, output, outputWidth, outputHeight, channelStride, scale, multiplier, range.x, mean, stdDev);
else if( format == IMAGE_RGBA32F )
gpuTensorNormMean<float4, isBGR><<<gridDim, blockDim, 0, stream>>>((float4*)input, inputWidth, output, outputWidth, outputHeight, scale, multiplier, range.x, mean, stdDev);
gpuTensorNormMean<float4, isBGR><<<gridDim, blockDim, 0, stream>>>((float4*)input, inputWidth, output, outputWidth, outputHeight, channelStride, scale, multiplier, range.x, mean, stdDev);
else
return cudaErrorInvalidValue;

Expand All @@ -240,18 +241,18 @@ cudaError_t launchTensorNormMean( void* input, imageFormat format, size_t inputW
cudaError_t cudaTensorNormMeanRGB( void* input, imageFormat format, size_t inputWidth, size_t inputHeight,
float* output, size_t outputWidth, size_t outputHeight,
const float2& range, const float3& mean, const float3& stdDev,
cudaStream_t stream )
cudaStream_t stream, size_t channelStride )
{
return launchTensorNormMean<false>(input, format, inputWidth, inputHeight, output, outputWidth, outputHeight, range, mean, stdDev, stream);
return launchTensorNormMean<false>(input, format, inputWidth, inputHeight, output, outputWidth, outputHeight, range, mean, stdDev, stream, channelStride );
}

// cudaTensorNormMeanBGR
cudaError_t cudaTensorNormMeanBGR( void* input, imageFormat format, size_t inputWidth, size_t inputHeight,
float* output, size_t outputWidth, size_t outputHeight,
const float2& range, const float3& mean, const float3& stdDev,
cudaStream_t stream )
cudaStream_t stream, size_t channelStride )
{
return launchTensorNormMean<true>(input, format, inputWidth, inputHeight, output, outputWidth, outputHeight, range, mean, stdDev, stream);
return launchTensorNormMean<true>(input, format, inputWidth, inputHeight, output, outputWidth, outputHeight, range, mean, stdDev, stream, channelStride);
}


4 changes: 2 additions & 2 deletions c/tensorConvert.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ cudaError_t cudaTensorNormBGR( void* input, imageFormat format, size_t inputWidt
/*
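* Downsample and apply pixel normalization, mean pixel subtraction, and division by the standard deviation, writing planar NCHW output.
* channelStride is the offset (in float elements) between consecutive output channel planes; passing 0 selects outputWidth * outputHeight.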
*/
cudaError_t cudaTensorNormMeanRGB( void* input, imageFormat format, size_t inputWidth, size_t inputHeight, float* output, size_t outputWidth, size_t outputHeight, const float2& range, const float3& mean, const float3& stdDev, cudaStream_t stream );
cudaError_t cudaTensorNormMeanBGR( void* input, imageFormat format, size_t inputWidth, size_t inputHeight, float* output, size_t outputWidth, size_t outputHeight, const float2& range, const float3& mean, const float3& stdDev, cudaStream_t stream );
cudaError_t cudaTensorNormMeanRGB( void* input, imageFormat format, size_t inputWidth, size_t inputHeight, float* output, size_t outputWidth, size_t outputHeight, const float2& range, const float3& mean, const float3& stdDev, cudaStream_t stream, size_t channelStride=0 );
cudaError_t cudaTensorNormMeanBGR( void* input, imageFormat format, size_t inputWidth, size_t inputHeight, float* output, size_t outputWidth, size_t outputHeight, const float2& range, const float3& mean, const float3& stdDev, cudaStream_t stream, size_t channelStride=0 );


#endif
Expand Down

0 comments on commit 82840af

Please sign in to comment.