-
Notifications
You must be signed in to change notification settings - Fork 638
Add support for CWT operator #4860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
mwdowski
wants to merge
25
commits into
NVIDIA:main
Choose a base branch
from
mwdowski:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
937b963
add MotherWavelet helper and WaveletGpu kernel
cf7b6a6
Cwt WIP
mwdowski 68bb330
Merge branch 'NVIDIA:main' into wavelet-computing
kubo11 9d6e0b0
Merge pull request #2 from mwdowski/wavelet-computing
mwdowski 359d79c
Merge pull request #1 from mwdowski/mwdowski
mwdowski b034619
Rename namespace
mwdowski 6bb49f5
Merge branch 'main' into mwdowski
mwdowski 5eed0c5
add WaveletArgs class
09196c6
Merge pull request #3 from mwdowski/wavelet-computing
kubo11 279e61b
Improve wavelet computing kernel
c4814f9
Optimize and remove discrete wavelets
11df6aa
Merge pull request #4 from mwdowski/wavelet-computing-improvements
kubo11 d3a8d6a
add DALIWaveletName enum
27cedd3
fix linting errors
2875c95
replace MeyerWavelet with GaussianWavelet
20d5d7e
Merge pull request #5 from mwdowski/wavelet-computing-improvements
kubo11 0efec3d
Fix wavelet exceptions
1ed22bc
Add CWT operator docstr
3c36192
Merge pull request #6 from mwdowski/wavelet-fixes
kubo11 1cdc5e7
WIP
mwdowski e99099e
Merge branch 'NVIDIA:main' into main
mwdowski 15ce332
Merge branch 'main' into mwdowski2
mwdowski 101efc4
Good size but full of zeros
mwdowski 276f87e
WIP
mwdowski 1849a30
Merge pull request #7 from mwdowski/mwdowski2
mwdowski File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
collect_headers(DALI_INST_HDRS PARENT_SCOPE) | ||
collect_sources(DALI_KERNEL_SRCS PARENT_SCOPE) | ||
collect_test_sources(DALI_KERNEL_TEST_SRCS PARENT_SCOPE) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ | ||
#define DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ | ||
|
||
#include <vector> | ||
#include "dali/operators/signal/wavelet/wavelet_name.h" | ||
|
||
namespace dali { | ||
namespace kernels { | ||
namespace signal { | ||
|
||
template <typename T = float> | ||
struct CwtArgs { | ||
std::vector<T> a; | ||
dali::DALIWaveletName wavelet; | ||
std::vector<T> wavelet_args; | ||
}; | ||
|
||
} // namespace signal | ||
} // namespace kernels | ||
} // namespace dali | ||
|
||
#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <cmath> | ||
#include <complex> | ||
#include <vector> | ||
#include "dali/core/common.h" | ||
#include "dali/core/error_handling.h" | ||
#include "dali/core/format.h" | ||
#include "dali/kernels/kernel.h" | ||
#include "dali/kernels/signal/wavelet/cwt_args.h" | ||
#include "dali/kernels/signal/wavelet/cwt_gpu.h" | ||
|
||
namespace dali { | ||
namespace kernels { | ||
namespace signal { | ||
|
||
template <typename T> | ||
struct SampleDesc { | ||
const T *in = nullptr; | ||
T *out = nullptr; | ||
int64_t size = 0; | ||
}; | ||
|
||
template <typename T> | ||
__global__ void CwtKernel(const SampleDesc<T> *sample_data) { | ||
const int64_t block_size = blockDim.y * blockDim.x; | ||
const int64_t grid_size = gridDim.x * block_size; | ||
const int sample_idx = blockIdx.y; | ||
const auto sample = sample_data[sample_idx]; | ||
const int64_t offset = block_size * blockIdx.x; | ||
const int64_t tid = threadIdx.y * blockDim.x + threadIdx.x; | ||
|
||
for (int64_t idx = offset + tid; idx < sample.size; idx += grid_size) { | ||
sample.out[idx] = sample.in[idx]; | ||
} | ||
} | ||
|
||
template <typename T> | ||
CwtGpu<T>::~CwtGpu() = default; | ||
|
||
template <typename T> | ||
KernelRequirements CwtGpu<T>::Setup(KernelContext &context, | ||
const InListGPU<T, DynamicDimensions> &in) { | ||
auto out_shape = in.shape; | ||
const size_t num_samples = in.size(); | ||
ScratchpadEstimator se; | ||
se.add<mm::memory_kind::host, SampleDesc<T>>(num_samples); | ||
se.add<mm::memory_kind::device, SampleDesc<T>>(num_samples); | ||
KernelRequirements req; | ||
req.scratch_sizes = se.sizes; | ||
req.output_shapes = {out_shape}; | ||
return req; | ||
} | ||
|
||
template <typename T> | ||
void CwtGpu<T>::Run(KernelContext &context, const OutListGPU<T, DynamicDimensions> &out, | ||
const InListGPU<T, DynamicDimensions> &in, const CwtArgs<T> &args) { | ||
auto num_samples = in.size(); | ||
auto *sample_data = context.scratchpad->AllocateHost<SampleDesc<T>>(num_samples); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment as dali/kernels/signal/wavelet/wavelet_gpu.cu |
||
|
||
for (int i = 0; i < num_samples; i++) { | ||
auto &sample = sample_data[i]; | ||
sample.out = out.tensor_data(i); | ||
sample.in = in.tensor_data(i); | ||
sample.size = volume(in.tensor_shape(i)); | ||
assert(sample.size == volume(out.tensor_shape(i))); | ||
} | ||
|
||
auto *sample_data_gpu = context.scratchpad->AllocateGPU<SampleDesc<T>>(num_samples); | ||
CUDA_CALL(cudaMemcpyAsync(sample_data_gpu, sample_data, num_samples * sizeof(SampleDesc<T>), | ||
cudaMemcpyHostToDevice, context.gpu.stream)); | ||
|
||
dim3 block(32, 32); | ||
auto blocks_per_sample = std::max(32, 1024 / num_samples); | ||
dim3 grid(blocks_per_sample, num_samples); | ||
CwtKernel<T><<<grid, block, 0, context.gpu.stream>>>(sample_data_gpu); | ||
} | ||
|
||
template class CwtGpu<float>; | ||
template class CwtGpu<double>; | ||
|
||
} // namespace signal | ||
} // namespace kernels | ||
} // namespace dali |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ | ||
#define DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ | ||
|
||
#include <memory> | ||
#include "dali/core/common.h" | ||
#include "dali/core/error_handling.h" | ||
#include "dali/core/format.h" | ||
#include "dali/core/util.h" | ||
#include "dali/kernels/kernel.h" | ||
#include "dali/kernels/signal/wavelet/cwt_args.h" | ||
|
||
namespace dali { | ||
namespace kernels { | ||
namespace signal { | ||
|
||
template <typename T = float> | ||
class DLL_PUBLIC CwtGpu { | ||
public: | ||
static_assert(std::is_floating_point<T>::value, "Only floating point types are supported"); | ||
|
||
DLL_PUBLIC ~CwtGpu(); | ||
|
||
DLL_PUBLIC KernelRequirements Setup(KernelContext &context, | ||
const InListGPU<T, DynamicDimensions> &in); | ||
|
||
DLL_PUBLIC void Run(KernelContext &context, const OutListGPU<T, DynamicDimensions> &out, | ||
const InListGPU<T, DynamicDimensions> &in, const CwtArgs<T> &args); | ||
}; | ||
|
||
} // namespace signal | ||
} // namespace kernels | ||
} // namespace dali | ||
|
||
#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <cmath> | ||
#include <vector> | ||
#include "dali/kernels/signal/wavelet/mother_wavelet.cuh" | ||
#include "dali/core/math_util.h" | ||
|
||
namespace dali { | ||
namespace kernels { | ||
namespace signal { | ||
|
||
template <typename T> | ||
HaarWavelet<T>::HaarWavelet(const std::vector<T> &args) { | ||
if (args.size() != 0) { | ||
throw std::invalid_argument("HaarWavelet doesn't accept any arguments."); | ||
} | ||
} | ||
|
||
template <typename T> | ||
__device__ T HaarWavelet<T>::operator()(const T &t) const { | ||
if (0.0 <= t && t < 0.5) { | ||
return 1.0; | ||
} | ||
if (0.5 <= t && t < 1.0) { | ||
return -1.0; | ||
} | ||
return 0.0; | ||
} | ||
|
||
template class HaarWavelet<float>; | ||
template class HaarWavelet<double>; | ||
|
||
template <typename T> | ||
GaussianWavelet<T>::GaussianWavelet(const std::vector<T> &args) { | ||
if (args.size() != 1) { | ||
throw std::invalid_argument("GaussianWavelet accepts exactly one argument - n."); | ||
} | ||
if (args[0] < 1.0 || args[0] > 8.0) { | ||
throw std::invalid_argument( | ||
"GaussianWavelet's argument n should be integer from range [1,8]."); | ||
} | ||
this->n = args[0]; | ||
} | ||
|
||
template <typename T> | ||
__device__ T GaussianWavelet<T>::operator()(const T &t) const { | ||
T expTerm = std::exp(-std::pow(t, 2.0)); | ||
T sqrtTerm = 1.2533141373155001; // std::sqrt(M_PI/2.0) | ||
switch (static_cast<int>(n)) { | ||
case 1: | ||
JanuszL marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return -2.0*t*expTerm/std::sqrt(sqrtTerm); | ||
case 2: | ||
return (-4.0*std::pow(t, 2.0)+2.0)*expTerm/std::sqrt(3.0*sqrtTerm); | ||
case 3: | ||
return (8.0*std::pow(t, 3.0)-12.0*t)*expTerm/std::sqrt(15.0*sqrtTerm); | ||
case 4: | ||
return (-48.0*std::pow(t, 2.0)+16.0*std::pow(t, 4.0)+12.0)*expTerm/std::sqrt(105.0*sqrtTerm); | ||
case 5: | ||
return (-32.0*std::pow(t, 5.0)+160.0*std::pow(t, 3.0)-120.0*t)* | ||
expTerm/std::sqrt(945.0*sqrtTerm); | ||
case 6: | ||
return (-64.0*std::pow(t, 6.0)+480.0*std::pow(t, 4.0)-720.0*std::pow(t, 2.0)+120.0)* | ||
expTerm/std::sqrt(10395.0*sqrtTerm); | ||
case 7: | ||
return (128.0*std::pow(t, 7.0)-1344.0*std::pow(t, 5.0)+3360.0*std::pow(t, 3.0)-1680.0*t)* | ||
expTerm/std::sqrt(135135.0*sqrtTerm); | ||
case 8: | ||
return (256.0*std::pow(t, 8.0)-3584.0*std::pow(t, 6.0)+13440.0*std::pow(t, 4.0)-13440.0* | ||
std::pow(t, 2.0)+1680.0)*expTerm/std::sqrt(2027025.0*sqrtTerm); | ||
} | ||
} | ||
|
||
template class GaussianWavelet<float>; | ||
template class GaussianWavelet<double>; | ||
|
||
template <typename T> | ||
MexicanHatWavelet<T>::MexicanHatWavelet(const std::vector<T> &args) { | ||
if (args.size() != 1) { | ||
throw std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma."); | ||
} | ||
this->sigma = args[0]; | ||
} | ||
|
||
template <typename T> | ||
__device__ T MexicanHatWavelet<T>::operator()(const T &t) const { | ||
return 2.0/(std::sqrt(3.0*sigma)*std::pow(M_PI, 0.25))*(1.0-std::pow(t/sigma, 2.0))* | ||
std::exp(-std::pow(t, 2.0)/(2.0*std::pow(sigma, 2.0))); | ||
} | ||
|
||
template class MexicanHatWavelet<float>; | ||
template class MexicanHatWavelet<double>; | ||
|
||
template <typename T> | ||
MorletWavelet<T>::MorletWavelet(const std::vector<T> &args) { | ||
if (args.size() != 0) { | ||
throw std::invalid_argument("MorletWavelet doesn't accept any arguments."); | ||
} | ||
} | ||
|
||
template <typename T> | ||
__device__ T MorletWavelet<T>::operator()(const T &t) const { | ||
return std::exp(-std::pow(t, 2.0) / 2.0) * std::cos(5.0 * t); | ||
} | ||
|
||
template class MorletWavelet<float>; | ||
template class MorletWavelet<double>; | ||
|
||
template <typename T> | ||
ShannonWavelet<T>::ShannonWavelet(const std::vector<T> &args) { | ||
if (args.size() != 2) { | ||
throw std::invalid_argument( | ||
"ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order."); | ||
} | ||
this->fb = args[0]; | ||
this->fc = args[1]; | ||
} | ||
|
||
template <typename T> | ||
__device__ T ShannonWavelet<T>::operator()(const T &t) const { | ||
auto res = std::cos((T)(2.0*M_PI)*fc*t)*std::sqrt(fb); | ||
return t == 0.0 ? res : res*std::sin(t*fb*(T)(M_PI))/(t*fb*(T)(M_PI)); | ||
} | ||
|
||
template class ShannonWavelet<float>; | ||
template class ShannonWavelet<double>; | ||
|
||
template <typename T> | ||
FbspWavelet<T>::FbspWavelet(const std::vector<T> &args) { | ||
if (args.size() != 3) { | ||
throw std::invalid_argument( | ||
"FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order."); | ||
} | ||
this->m = args[0]; | ||
this->fb = args[1]; | ||
this->fc = args[2]; | ||
} | ||
|
||
template <typename T> | ||
__device__ T FbspWavelet<T>::operator()(const T &t) const { | ||
auto res = std::cos((T)(2.0*M_PI)*fc*t)*std::sqrt(fb); | ||
return t == 0.0 ? res : res*std::pow(std::sin((T)(M_PI)*t*fb/m)/((T)(M_PI)*t*fb/m), m); | ||
} | ||
|
||
template class FbspWavelet<float>; | ||
template class FbspWavelet<double>; | ||
|
||
} // namespace signal | ||
} // namespace kernels | ||
} // namespace dali |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.