diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 7e39a9e16..8e4efac5a 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,5 +1,7 @@ steps: - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + command: + - echo 'CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"' >> test/Project.toml plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" @@ -14,15 +16,16 @@ steps: queue: "juliagpu" cuda: "*" env: + JULIA_NUM_THREADS: 4 NNLIB_TEST_CUDA: "true" - NNLIB_TEST_CPU: "false" + NNLIB_TEST_CPU: "true" # Could be useful to uncover multithreading related issues + # Buildkite workers have more threads. if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 180 matrix: setup: julia: - # - "1.9" # uncomment when 1.10 is out - - "1" + - "1.10" - "nightly" adjustments: - with: @@ -32,8 +35,9 @@ steps: - label: ":julia: Julia 1 + AMD GPU" plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" - JuliaCI/julia-test#v1: + test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true dirs: @@ -49,8 +53,7 @@ steps: JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" NNLIB_TEST_AMDGPU: "true" - NNLIB_TEST_CPU: "true" # Could be useful to uncover multithreading related issues - # Buildkite workers have more threads. + NNLIB_TEST_CPU: "false" JULIA_NUM_THREADS: 4 - label: "Benchmarks" diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index c8e63286b..9d18ec920 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -5,6 +5,11 @@ on: tags: [v*] pull_request: +# needed to allow julia-actions/cache to delete old caches that it has created +permissions: + actions: write + contents: read + jobs: test: name: ${{ matrix.package.repo }}/${{ matrix.package.group }} @@ -19,24 +24,14 @@ jobs: package: - {user: FluxML, repo: Flux.jl, group: All} - {user: FluxML, repo: Tracker.jl, group: All} - - {user: denizyuret, repo: Knet.jl, group: All} - - {user: dfdx, repo: Avalon.jl, group: All} - - {user: JuliaOptimalTransport, repo: OptimalTransport.jl, group: All} - - {user: avik-pal, repo: Lux.jl, group: All} + - {user: LuxDL, repo: Lux.jl, group: All} steps: - uses: actions/checkout@v3 - # for OptimalTransport.jl - - name: Install python - uses: actions/setup-python@v4 - with: - python-version: '3.9' - architecture: ${{ matrix.arch }} - # for OptimalTransport.jl - - run: python -m pip install pot - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 + - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@latest - name: Clone Downstream uses: actions/checkout@v3 @@ -60,3 +55,7 @@ jobs: @info "Not compatible with this release. No problem." 
exception=err exit(0) # Exit immediately, as a success end + env: + RETESTITEMS_NWORKERS: 4 + BACKEND_GROUP: CPU # for Lux.jl + diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dac08e745..f2cf0328b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,6 +9,11 @@ on: tags: '*' pull_request: +# needed to allow julia-actions/cache to delete old caches that it has created +permissions: + actions: write + contents: read + defaults: run: shell: bash @@ -24,8 +29,9 @@ jobs: matrix: version: # - '1.9' # uncomment when julia 1.10 is out + - 'lts' - '1' # automatically expands to the latest stable 1.x release of Julia - - 'nightly' + - 'pre' os: - ubuntu-latest # - macOS-latest @@ -51,20 +57,11 @@ jobs: steps: - uses: actions/checkout@v3 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v3 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- + - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 - name: "Run test without coverage" @@ -88,22 +85,15 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: '1.9' + - uses: julia-actions/cache@v1 - run: | julia --project=docs -e ' using Pkg Pkg.develop(PackageSpec(path=pwd())) Pkg.instantiate()' - - run: | - julia --color=yes --project=docs/ -e ' - using NNlib - # using Pkg; Pkg.activate("docs") - using Documenter - using Documenter: doctest - DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive=true) - doctest(NNlib)' - run: julia --project=docs docs/make.jl env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/Project.toml b/Project.toml index b33b90564..03989643a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NNlib" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.6" +version = "0.9.30" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -9,49 +9,43 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" +ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] NNlibAMDGPUExt = "AMDGPU" -NNlibCUDAExt = "CUDA" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] +NNlibCUDAExt = "CUDA" +NNlibEnzymeCoreExt = "EnzymeCore" +NNlibFFTWExt = "FFTW" +NNlibForwardDiffExt = "ForwardDiff" +NNlibSpecialFunctionsExt = "SpecialFunctions" [compat] -AMDGPU = "0.5, 0.6" -Adapt = "3.2" -Atomix = "0.1" -ChainRulesCore = "1.13" +AMDGPU = "1" +Adapt = "3.2, 4" 
+Atomix = "0.1, 1" CUDA = "4, 5" -cuDNN = "1" -GPUArraysCore = "0.1" +ChainRulesCore = "1.25" +EnzymeCore = "0.5, 0.6, 0.7, 0.8" +FFTW = "1.8.0" +ForwardDiff = "0.10.36, 1" +GPUArraysCore = "0.1, 0.2" KernelAbstractions = "0.9.2" -Requires = "1.0" +LinearAlgebra = "<0.0.1, 1" +Random = "<0.0.1, 1" +ScopedValues = "1.3.0" +SpecialFunctions = "2" +Statistics = "1" +cuDNN = "1" julia = "1.9" - -[extras] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[targets] -test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", - "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", - "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN"] diff --git a/README.md b/README.md index 3892a2cc1..1f29d25c4 100644 --- a/README.md +++ b/README.md @@ -25,4 +25,3 @@ for CUDA support, or using NNlib, AMDGPU ``` for AMDGPU support. - diff --git a/benchmark/perf_report.jl b/benchmark/perf_report.jl index 5c06515eb..9b861e869 100644 --- a/benchmark/perf_report.jl +++ b/benchmark/perf_report.jl @@ -37,10 +37,6 @@ for rank in (2,), (NNlib.depthwiseconv_im2col!, NNlib.∇depthwiseconv_data_im2col!, NNlib.∇depthwiseconv_filter_im2col!, DepthwiseConvDims, "im2col"), ] - if NNlib.is_nnpack_available() - push!(benchmark_items, (NNlib.conv_nnpack!, NNlib.∇conv_data_nnpack!, NNlib.∇conv_filter_nnpack!, DenseConvDims, "nnpack")) - end - for (conv!, ∇conv_data!, ∇conv_filter!, cT, backend) in benchmark_items x = zeros(Float32, repeat([N], rank)..., C_in, 1) @@ -105,15 +101,4 @@ for rank in (2,), @show(pdims) @save "results.jld2" results end - - if NNlib.is_nnpack_available() - if NNlib.nnpack_supported_operation(pdims) - t_fwd = @benchmark NNlib.maxpool_nnpack!($y, $x, $pdims) - - add_result(t_fwd, "maxpool2d", "nnpack", pdims) - - @show(pdims) - @save "results.jld2" results - end - end end diff --git a/docs/.gitignore b/docs/.gitignore index a303fff20..b71a83fb6 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,2 +1,3 @@ build/ site/ +Manifest.toml diff --git a/docs/Project.toml b/docs/Project.toml index 3a52a5db2..9de5539c2 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,5 +1,9 @@ [deps] +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" - -[compat] -Documenter = "0.27" +FLAC = "abae9e3b-a9a0-4778-b5c6-ca109b507d99" +FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" diff --git a/docs/make.jl b/docs/make.jl index a12937ca0..4bffca944 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,18 +1,21 @@ using Documenter, NNlib -DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib); recursive = true) +DocMeta.setdocmeta!(NNlib, :DocTestSetup, + :(using FFTW, NNlib, UnicodePlots); recursive = true) makedocs(modules = [NNlib], - 
sitename = "NNlib.jl", - doctest = false, - pages = ["Home" => "index.md", - "Reference" => "reference.md"], - format = Documenter.HTML( - canonical = "https://fluxml.ai/NNlib.jl/stable/", - # analytics = "UA-36890222-9", - assets = ["assets/flux.css"], - prettyurls = get(ENV, "CI", nothing) == "true"), - ) + sitename = "NNlib.jl", + doctest = true, + pages = ["Home" => "index.md", + "Reference" => "reference.md", + "Audio" => "audio.md"], + format = Documenter.HTML( + canonical = "https://fluxml.ai/NNlib.jl/stable/", + # analytics = "UA-36890222-9", + assets = ["assets/flux.css"], + prettyurls = get(ENV, "CI", nothing) == "true"), + warnonly=[:missing_docs,] +) deploydocs(repo = "github.com/FluxML/NNlib.jl.git", target = "build", diff --git a/docs/src/assets/jfk.flac b/docs/src/assets/jfk.flac new file mode 100644 index 000000000..24841d55a Binary files /dev/null and b/docs/src/assets/jfk.flac differ diff --git a/docs/src/audio.md b/docs/src/audio.md new file mode 100644 index 000000000..e56a5bf43 --- /dev/null +++ b/docs/src/audio.md @@ -0,0 +1,61 @@ +# Reference + +!!! note + Spectral functions require importing `FFTW` package to enable them. + +## Window functions + +```@docs +hann_window +hamming_window +``` + +## Spectral + +```@docs +stft +istft +NNlib.power_to_db +NNlib.db_to_power +``` + +## Spectrogram + +```@docs +melscale_filterbanks +spectrogram +``` + +Example: + +```@example 1 +using FFTW # <- required for STFT support. +using NNlib +using FileIO +using Makie, CairoMakie +CairoMakie.activate!() + +waveform, sampling_rate = load("./assets/jfk.flac") +fig = lines(reshape(waveform, :)) +save("waveform.png", fig) + +# Spectrogram. + +n_fft = 1024 +spec = spectrogram(waveform; n_fft, hop_length=n_fft ÷ 4, window=hann_window(n_fft)) +fig = heatmap(transpose(NNlib.power_to_db(spec)[:, :, 1])) +save("spectrogram.png", fig) + +# Mel-scale spectrogram. + +n_freqs = n_fft ÷ 2 + 1 +fb = melscale_filterbanks(; n_freqs, n_mels=128, sample_rate=Int(sampling_rate)) +mel_spec = permutedims(spec, (2, 1, 3)) ⊠ fb # (time, n_mels) +fig = heatmap(NNlib.power_to_db(mel_spec)[:, :, 1]) +save("mel-spectrogram.png", fig) +nothing # hide +``` + +|Waveform|Spectrogram|Mel Spectrogram| +|:---:|:---:|:---:| +|![](waveform.png)|![](spectrogram.png)|![](mel-spectrogram.png)| diff --git a/docs/src/index.md b/docs/src/index.md index 46958da1b..78d09df4e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -12,4 +12,13 @@ for CUDA support, or ```julia using NNlib, AMDGPU ``` -for AMDGPU support. \ No newline at end of file +for AMDGPU support. + +## Threading + +Various `NNlib` functions utilize available julia threads on divisible workloads. To disable this use +the `ScopedValue`-backed switch `NNlib.@disallow_spawns` +i.e. +```julia +NNlib.@disallow_spawns function_that_uses_nnlib() +``` diff --git a/docs/src/reference.md b/docs/src/reference.md index c01db6b24..1b1f7827d 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -10,6 +10,8 @@ Non-linearities that go between layers of your model. Note that, unless otherwis celu elu gelu +gelu_tanh +gelu_erf hardsigmoid sigmoid_fast hardtanh @@ -78,10 +80,11 @@ pad_zeros `NNlib.conv` supports complex datatypes on CPU and CUDA devices. -!!! AMDGPU MIOpen supports only cross-correlation (flipkernel=true). - Therefore for every regular convolution (flipkernel=false) +!!! note "AMDGPU MIOpen supports only cross-correlation (`flipkernel=true`)." + + Therefore for every regular convolution (`flipkernel=false`) kernel is flipped before calculation. 
- For better performance, use cross-correlation (flipkernel=true) + For better performance, use cross-correlation (`flipkernel=true`) and manually flip the kernel before `NNlib.conv` call. `Flux` handles this automatically, this is only required for direct calls. @@ -111,6 +114,14 @@ upsample_trilinear pixel_shuffle ``` +## Rotation +Rotate images in the first two dimensions of an array. + +```@docs +imrotate +∇imrotate +``` + ## Batched Operations `Flux`'s `Bilinear` layer uses `NNlib.batched_mul` internally. diff --git a/ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl b/ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl index c9f78add1..88a4e26d5 100644 --- a/ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl +++ b/ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl @@ -53,7 +53,6 @@ include("batched_mul.jl") include("conv.jl") include("pool.jl") - include("softmax.jl") include("activations.jl") else @warn """ diff --git a/ext/NNlibAMDGPUExt/activations.jl b/ext/NNlibAMDGPUExt/activations.jl index 1563bb45e..498cc8a8a 100644 --- a/ext/NNlibAMDGPUExt/activations.jl +++ b/ext/NNlibAMDGPUExt/activations.jl @@ -1,13 +1,13 @@ for (f, op) in [ - NNlib.relu => MIOpen.relu, - NNlib.relu6 => x -> MIOpen.clippedrelu(x, 6), - NNlib.softplus => MIOpen.softrelu, - NNlib.σ => MIOpen.sigmoid, - Base.tanh => MIOpen.tanh, - # TODO define for leakyrelu, elu, etc.? -] + NNlib.relu => MIOpen.relu, + NNlib.relu6 => x -> MIOpen.clippedrelu(x, 6), + NNlib.softplus => MIOpen.softrelu, + NNlib.σ => MIOpen.sigmoid, + Base.tanh => MIOpen.tanh, + # TODO define for leakyrelu, elu, etc.? + ], N in 1:5 @eval function Base.materialize( - bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{ROCArray{<:MIOPENFloat}}} + bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{ROCArray{<:MIOPENFloat,$N}}} ) return $op(bc.args[1]) end diff --git a/ext/NNlibAMDGPUExt/softmax.jl b/ext/NNlibAMDGPUExt/softmax.jl deleted file mode 100644 index de75f9748..000000000 --- a/ext/NNlibAMDGPUExt/softmax.jl +++ /dev/null @@ -1,11 +0,0 @@ -for fname in (:softmax, :logsoftmax) - @eval function NNlib.$(fname)(x::ROCArray{T}; dims = 1) where T <: MIOPENFloat - MIOpen.$(fname)(x; dims) - end - - @eval function NNlib.$(Symbol("∇$(fname)"))( - dy::ROCArray{T, N}, x::ROCArray{T, N}, y::ROCArray{T, N}; dims = 1, - ) where {T <: MIOPENFloat, N} - MIOpen.$(Symbol("∇$(fname)!"))(dy, y; dims) - end -end diff --git a/ext/NNlibCUDACUDNNExt/batchnorm.jl b/ext/NNlibCUDACUDNNExt/batchnorm.jl index 2c38f009e..d74fb3ad8 100644 --- a/ext/NNlibCUDACUDNNExt/batchnorm.jl +++ b/ext/NNlibCUDACUDNNExt/batchnorm.jl @@ -84,7 +84,15 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray cache.ivar = ivar end else - cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps) + if track_stats + cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps) + else + # cudnnBatchNormalizationForwardInference does not accept CV_NULL for running_mean + # and running_var. We could calculate mean and var of `x` here, but instead use + # cudnnBatchNormalizationFowardTraining. cudnnBatchNormalizationForwardTraining does + # accept CV_NULL and will calculate mean and var itself. 
+ cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, CU_NULL, CU_NULL, eps, CU_NULL, CU_NULL) + end end return y end diff --git a/ext/NNlibCUDACUDNNExt/conv.jl b/ext/NNlibCUDACUDNNExt/conv.jl index c51ac6dd6..d6947fc48 100644 --- a/ext/NNlibCUDACUDNNExt/conv.jl +++ b/ext/NNlibCUDACUDNNExt/conv.jl @@ -129,7 +129,7 @@ function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{ activation = (σ == NNlib.relu ? CUDNN_ACTIVATION_RELU : CUDNN_ACTIVATION_IDENTITY) cudnnConvolutionForward!(y, w, x, d; z, bias, activation, alpha, beta) if activation === CUDNN_ACTIVATION_IDENTITY && σ ∉ (nothing, identity) - y = σ.(y) + @. y = σ(y) end return y end diff --git a/ext/NNlibCUDAExt/NNlibCUDAExt.jl b/ext/NNlibCUDAExt/NNlibCUDAExt.jl index 876481886..7939b8086 100644 --- a/ext/NNlibCUDAExt/NNlibCUDAExt.jl +++ b/ext/NNlibCUDAExt/NNlibCUDAExt.jl @@ -9,7 +9,6 @@ include("activations.jl") include("batchedadjtrans.jl") include("batchedmul.jl") include("ctc.jl") -include("fold.jl") include("scatter.jl") include("utils.jl") diff --git a/ext/NNlibCUDAExt/ctc.jl b/ext/NNlibCUDAExt/ctc.jl index 84a319ba8..f51a04a19 100644 --- a/ext/NNlibCUDAExt/ctc.jl +++ b/ext/NNlibCUDAExt/ctc.jl @@ -14,7 +14,7 @@ import NNlib: ctc_loss, ctc_alpha, ∇ctc_loss const MAX_THREADS = 256 -function log_plus_f(p1, p2) +@inline function log_plus_f(p1, p2) isinf(p1) && return p2 isinf(p2) && return p1 if p1 < p2 @@ -229,4 +229,4 @@ function ∇ctc_loss(ŷ::CuArray, y, out) accum = CUDA.fill(log(typed_zero), size(ŷ)) @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss) return grads -end \ No newline at end of file +end diff --git a/ext/NNlibCUDAExt/fold.jl b/ext/NNlibCUDAExt/fold.jl deleted file mode 100644 index 240ed10fc..000000000 --- a/ext/NNlibCUDAExt/fold.jl +++ /dev/null @@ -1,111 +0,0 @@ - -function unfold_kernel!(col::AbstractArray{T}, x, col_size, input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx) where {T} - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices - w, h, d = CartesianIndices(output_size)[i].I # x indices - - # project - w, h, d = @. ((w, h, d) - 1)*stride - pad_lo + 1 + ((kw, kh, kd) - 1)*dilation - - if !flipkernel - kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1 - end - - # check out of bounds - if !all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d))) - col[i, kw, kh, kd, c, b] = T(0) - return nothing - end - - xval::T = x[w, h, d, c, b] - col[i, kw, kh, kd, c, b] = xval - end - - return nothing -end - -function fold_kernel!(x::AbstractArray{T}, col, col_size, input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx) where {T} - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices - w, h, d = CartesianIndices(output_size)[i].I # x indices - - # project - w, h, d = @. 
((w, h, d) - 1)*stride - pad_lo + 1 + ((kw, kh, kd) - 1)*dilation - - # check out of bounds - if !all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d))) - return nothing - end - - if !flipkernel - kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1 - end - - cval::T = col[i, kw, kh, kd, c, b] - CUDA.@atomic x[w, h, d, c, b] += cval - end - - return nothing -end - -function NNlib.unfold!(col::AnyCuArray{cT,3}, x::AnyCuArray{xT,5}, cdims::NNlib.DenseConvDims) where {cT, xT} - if NNlib.spatial_dims(cdims) != 3 - throw(DimensionMismatch("unfold!() only accepts 3d convoluitional inputs")) - end - - input_size = NNlib.input_size(cdims) - C_in = NNlib.channels_in(cdims) - kernel_size = NNlib.kernel_size(cdims) - pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = NNlib.padding(cdims) - pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo) - dilation = NNlib.dilation(cdims) - stride = NNlib.stride(cdims) - output_size = NNlib.output_size(cdims) - flipkernel = NNlib.flipkernel(cdims) - - col_reshaped = reshape(col, (prod(output_size), kernel_size..., C_in, :)) - - max_idx = prod(size(col)) - args = col_reshaped, x, size(col_reshaped), input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx - kernel = @cuda launch=false unfold_kernel!(args...) - config = launch_configuration(kernel.fun; max_threads=256) - threads = min(max_idx, config.threads) - blocks = cld(max_idx, threads) - kernel(args...; threads=threads, blocks=blocks) - return col -end - -function NNlib.fold!(x::AnyCuArray{xT,5}, col::AnyCuArray{cT,3}, cdims::NNlib.DenseConvDims) where {xT, cT} - if NNlib.spatial_dims(cdims) != 3 - throw(DimensionMismatch("fold!() only accepts 3d convoluitional inputs")) - end - - # going to accumulate into x - fill!(x, xT(0)) - - input_size = NNlib.input_size(cdims) - C_in = NNlib.channels_in(cdims) - kernel_size = NNlib.kernel_size(cdims) - pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = NNlib.padding(cdims) - pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo) - dilation = NNlib.dilation(cdims) - stride = NNlib.stride(cdims) - output_size = NNlib.output_size(cdims) - flipkernel = NNlib.flipkernel(cdims) - - col_reshaped = reshape(col, (prod(output_size), kernel_size..., C_in, :)) - - max_idx = prod(size(col)) - args = x, col_reshaped, size(col_reshaped), input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx - kernel = @cuda launch=false fold_kernel!(args...) 
- config = launch_configuration(kernel.fun; max_threads=256) - threads = min(max_idx, config.threads) - blocks = cld(max_idx, threads) - kernel(args...; threads=threads, blocks=blocks) - return x -end - diff --git a/ext/NNlibCUDAExt/sampling.jl b/ext/NNlibCUDAExt/sampling.jl index 9f7db2f78..1264a1ed1 100644 --- a/ext/NNlibCUDAExt/sampling.jl +++ b/ext/NNlibCUDAExt/sampling.jl @@ -2,7 +2,7 @@ @inbounds CUDA.@atomic dx[ix, iy, c, n] += value end -function grid_sample_kernel!(n_elem, output, input, grid, padding_mode) +function grid_sample_kernel!(n_elem, output::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V} index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x if index < n_elem iW, iH, iC, _ = size(input) @@ -16,7 +16,7 @@ function grid_sample_kernel!(n_elem, output, input, grid, padding_mode) nothing end -function ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, input, grid, padding_mode) +function ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 4}, dgrid::AbstractArray{V, 4}, Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V} index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x if index < n_elem iW, iH, iC, _ = size(input) @@ -59,3 +59,74 @@ function NNlib.∇grid_sample(Δ::CuArray{T, 4}, x::CuArray{T, 4}, grid::CuArray kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks) dx, dgrid end + + +@inline function NNlib._safe_add!(dx::CuDeviceArray{T, 5}, value, ix, iy, iz, c, n) where T + @inbounds CUDA.@atomic dx[ix, iy, iz, c, n] += value +end + +function grid_sample_kernel!(n_elem, output::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V} + index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x + if index < n_elem + iW, iH,iD, iC, _ = size(input) + _, gW, gH, gD, _ = size(grid) + + w = index % gW + 1 + h = (index ÷ gW) % gH + 1 + d = (index ÷ (gW * gH)) % gD + 1 + n = index ÷ (gW * gH * gD) + 1 + # n = index ÷ (gW * gH) + 1 + # d = (index ÷ (gW * gH * n)) + 1 + + NNlib._grid_sample_kernel!(output, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC) + end + nothing +end + +function ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 5}, dgrid::AbstractArray{V, 5}, Δ::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V} + index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x + if index < n_elem + iW, iH, iD, iC, _ = size(input) + _, gW, gH, gD, _ = size(grid) + + w = index % gW + 1 + h = (index ÷ gW) % gH + 1 + d = (index ÷ (gW * gH)) % gD + 1 + n = index ÷ (gW * gH * gD) + 1 + # n = index ÷ (gW * gH) + 1 + # d = (index ÷ (gW * gH * n)) + 1 + + NNlib._∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC) + end + nothing +end + +function NNlib.grid_sample(x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V} + pad = Val(padding_mode) + _, _, _, xC, xN = size(x) + _, gW, gH, gD, _ = size(grid) + n_elem = gW * gH * gD * xN + y = similar(x, T, (gW, gH, gD, xC, xN)) + + kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad) + config = launch_configuration(kernel.fun; max_threads=256) + threads = min(n_elem, config.threads) + blocks = cld(n_elem, threads) + kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks) + y +end + +function NNlib.∇grid_sample(Δ::CuArray{T, 5}, x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V} + pad = 
Val(padding_mode) + xN = size(x, 5) + _, gW, gH, gD, _ = size(grid) + n_elem = gW * gH * gD * xN + dx, dgrid = CUDA.zeros(T, size(x)), similar(grid) + + kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad) + config = launch_configuration(kernel.fun; max_threads=256) + threads = min(n_elem, config.threads) + blocks = cld(n_elem, threads) + kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks) + dx, dgrid +end \ No newline at end of file diff --git a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl new file mode 100644 index 000000000..1894a585b --- /dev/null +++ b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl @@ -0,0 +1,386 @@ +module NNlibEnzymeCoreExt + +using NNlib +import EnzymeCore +using Random + +using EnzymeCore.EnzymeRules + +for (name, dataname, filtername) in ( + (typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!), + (typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!), + (typeof(NNlib.∇conv_data!), NNlib.conv!, NNlib.∇conv_filter!), + (typeof(NNlib.∇conv_filter!), NNlib.∇conv_data!, NNlib.conv!), + ) + @eval begin + + function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, + y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, + x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, + w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, + cdims; kwargs...) where {RT, yT, xT, wT, N} + + if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated + func.val(y.val, x.val, w.val, cdims.val; kwargs...) + end + + primal = if EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + # Cache x if its overwritten and w is active (and thus required) + cache_x = ( EnzymeRules.overwritten(config)[3] + && !(typeof(w) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(x.val) : nothing + + # Cache w if its overwritten and x is active (and thus required) + cache_w = ( EnzymeRules.overwritten(config)[4] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(w.val) : nothing + + cache = (cache_x, cache_w) + + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + + function EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, + y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, + x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, + w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, + cdims; kwargs...) where {RT, yT, xT, wT, N} + cache_x, cache_w = cache + + # Don't cache x if not overwritten and w is active (and thus required) + if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + # Don't cache w if not overwritten and x is active (and thus required) + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[4] + cache_w = w.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval + + if EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + dws = (dws,) + end + + for (dy, dx, dw) in zip(dys, dxs, dws) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + # dx += grad wrt x.val + $dataname(dx, $(name != typeof(NNlib.∇conv_filter!) ? 
:dy : :cache_w), $(name != typeof(NNlib.∇conv_filter!) ? :cache_w : :dy), cdims.val; alpha=xT(1), beta=xT(1), kwargs...) + end + if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val + # dw += grad wrt w.val + $filtername(dw, $(name != typeof(NNlib.∇conv_data!) ? :cache_x : :dy), $(name != typeof(NNlib.∇conv_data!) ? :dy : :cache_x), cdims.val; alpha=wT(1), beta=wT(1), kwargs...) + end + + dy .= 0 + end + end + + return (nothing, nothing, nothing, nothing) + end + +end +end + +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(dst.val, src.val, idx.val) + end + + primal = if EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) +end + +function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + # Don't cache idx if not overwritten + if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[4] + cache_idx = idx.val + end + end + + ddsts = dst.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval + + if EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + NNlib.scatter!(+, dsrc, ddst, cache_idx) + end + + ddst .= 0 + end + end + + return (nothing, nothing, nothing) +end + + + +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + @assert !(OutType <: EnzymeCore.Const) + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(op.val, dst.val, src.val, idx.val) + end + + primal = if EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) +end + +function EnzymeRules.reverse(config, + func::EnzymeCore.Const{typeof(NNlib.scatter!)}, + ::Type{RT}, + cache_idx, + op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, + src, + idx::EnzymeCore.Const) where {OutType, RT} + + # Don't cache idx if not overwritten + if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[4] + cache_idx = idx.val + end + end + + ddsts = dst.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? 
ddsts : src.dval + + if EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + + if eltype(typeof(op)) == typeof(+) + dsrc .+= NNlib.gather(ddst, cache_idx) + else + @assert eltype(typeof(op)) == typeof(-) + dsrc .-= NNlib.gather(ddst, cache_idx) + end + end + + end + end + + return (nothing, nothing, nothing, nothing) +end + + + +for pool in [:maxpool, :meanpool, :lpnormpool] + pool! = Symbol(pool, :!) + ∇pool = Symbol(:∇, pool, :!) + + @eval begin + +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT} + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(y.val, x.val, dims.val; kwargs...) + end + + primal = if EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + cache_y = ( EnzymeRules.overwritten(config)[2] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(y.val) : nothing + + cache_x = ( EnzymeRules.overwritten(config)[3] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(x.val) : nothing + + cache = (cache_y, cache_x) + + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT} + cache_y, cache_x = cache + + # Don't cache y if not overwritten + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[2] + cache_y = y.val + end + end + + # Don't cache x if not overwritten + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + + if EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + end + + for (dy, dx) in zip(dys, dxs) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims.val; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...) + end + + dy .= 0 + end + end + + return (nothing, nothing, nothing) +end + +end +end + +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT} + + T = float(real(eltype(dst.val))) + val = convert(T, 1/(1-p.val)) + keep = if dims.val isa Colon + similar(dst.val, T, size(dst.val)) + else + similar(dst.val, T, ntuple(d -> d in dims.val ? 
size(dst.val,d) : 1, ndims(dst.val))) + end + rand!(rng.val, keep) + + keep = keep .> p.val + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + dst.val .= (keep .* val) .* src.val + end + + primal = if EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const + keep = nothing + end + + return EnzymeRules.AugmentedReturn(primal, shadow, keep) +end + +function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT} + T = float(real(eltype(dst.val))) + val = convert(T, 1/(1-p.val)) + + ddsts = dst.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval + + if EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + dsrc .+= (keep .* val) .* ddst + end + + ddst .= 0 + end + end + + dp = if typeof(p) <: EnzymeCore.Active + typeof(p.val)(0) + else + nothing + end + + return (nothing, nothing, nothing, dp, nothing) +end + + +end diff --git a/ext/NNlibFFTWExt/NNlibFFTWExt.jl b/ext/NNlibFFTWExt/NNlibFFTWExt.jl new file mode 100644 index 000000000..ee314cd51 --- /dev/null +++ b/ext/NNlibFFTWExt/NNlibFFTWExt.jl @@ -0,0 +1,9 @@ +module NNlibFFTWExt + +using FFTW +using NNlib +using KernelAbstractions + +include("stft.jl") + +end diff --git a/ext/NNlibFFTWExt/stft.jl b/ext/NNlibFFTWExt/stft.jl new file mode 100644 index 000000000..b5ae00e3e --- /dev/null +++ b/ext/NNlibFFTWExt/stft.jl @@ -0,0 +1,127 @@ +function NNlib.stft(x; + n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing, + center::Bool = true, normalized::Bool = false, +) + kab = get_backend(x) + use_window = !isnothing(window) + + use_window && kab != get_backend(window) && throw(ArgumentError( + "`window` must be on the same device as stft input `x` ($kab), \ + instead: `$(get_backend(window))`.")) + use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError( + "Expected `0 < length(window) ≤ n_fft=$n_fft`, \ + but got `length(window)=$(length(window))`.")) + hop_length < 0 && throw(ArgumentError( + "Expected `hop_length > 0`, but got `hop_length=$hop_length`.")) + + # Pad window on both sides with `0` to `n_fft` length if needed. + if use_window && length(window) < n_fft + left = ((n_fft - length(window)) ÷ 2) + 1 + tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft) + tmp[left:left + length(window) - 1] .= window + window = tmp + end + + if center + pad_amount = n_fft ÷ 2 + x = pad_reflect(x, pad_amount; dims=1) + end + + n = size(x, 1) + (0 < n_fft ≤ n) || throw(ArgumentError( + "Expected `0 < n_fft ≤ size(x, 1)=$n`, but got `n_fft=$n_fft`.")) + + n_frames = 1 + (n - n_fft) ÷ hop_length + + # time2col. + # Reshape `x` to (n_fft, n_frames, B) if needed. + # Each row in `n_frames` is shifted by `hop_length`. + if n_frames > 1 + # TODO can be more efficient if we support something like torch.as_strided + ids = [ + row + hop_length * col + for row in 1:n_fft, col in 0:(n_frames - 1)] + x = @inbounds x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...] + end + + region = 1 + use_window && (x = x .* window;) + y = eltype(x) <: Complex ? 
fft(x, region) : rfft(x, region) + + normalized && (y = y .* eltype(y)(n_fft^-0.5);) + return y +end + +function NNlib.istft(y; + n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing, + center::Bool = true, normalized::Bool = false, + return_complex::Bool = false, + original_length::Union{Nothing, Int} = nothing, +) + kab = get_backend(y) + use_window = !isnothing(window) + + use_window && kab != get_backend(window) && throw(ArgumentError( + "`window` must be on the same device as istft input `y` ($kab), \ + instead: `$(get_backend(window))`.")) + use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError( + "Expected `0 < length(window) ≤ n_fft=$n_fft`, \ + but got `length(window)=$(length(window))`.")) + hop_length < 0 && throw(ArgumentError( + "Expected `hop_length > 0`, but got `hop_length=$hop_length`.")) + + # TODO check `y` eltype is complex + + n_frames = size(y, 2) + + # Pad window on both sides with `0` to `n_fft` length if needed. + if use_window && length(window) < n_fft + left = ((n_fft - length(window)) ÷ 2) + 1 + tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft) + tmp[left:left + length(window) - 1] .= window + window = tmp + end + + # Denormalize. + normalized && (y = y .* eltype(y)(n_fft^0.5);) + + region = 1 + x = return_complex ? ifft(y, region) : irfft(y, n_fft, region) + + # De-apply window. + use_window && (x = x ./ window;) + + # col2time. + expected_output_len = n_fft + hop_length * (n_frames - 1) + + ids = Vector{Int}(undef, expected_output_len) + in_idx, out_idx = 0, 0 + prev_e, v = 0, 0 + + for col in 0:(n_frames - 1) + for row in 1:n_fft + in_idx += 1 + v = row + hop_length * col + v > prev_e || continue + + out_idx += 1 + ids[out_idx] = in_idx + end + prev_e = v + end + + # In case of batched input, reshaped it (n_fft, n_frames, batch) -> (:, batch). + nd = ntuple(_ -> Colon(), ndims(x) - 2) + ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));) + x = @inbounds x[ids, nd...] + + # Trim padding. + left = center ? (n_fft ÷ 2 + 1) : 1 + right = if isnothing(original_length) + center ? (size(x, 1) - n_fft ÷ 2) : expected_output_len + else + left + original_length - 1 + end + x = x[left:right, nd...] 
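+    # `x` is now trimmed back to the unpadded signal length: either the
+    # reflection padding added by `stft(...; center=true)` is removed, or the
+    # output is cropped to `original_length` when that keyword is given.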
+ return x +end diff --git a/ext/NNlibForwardDiffExt.jl b/ext/NNlibForwardDiffExt.jl new file mode 100644 index 000000000..84351bc17 --- /dev/null +++ b/ext/NNlibForwardDiffExt.jl @@ -0,0 +1,9 @@ +module NNlibForwardDiffExt + +using ForwardDiff: ForwardDiff +using NNlib: NNlib + +NNlib.within_gradient(x::ForwardDiff.Dual) = true +NNlib.within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true + +end diff --git a/ext/NNlibSpecialFunctionsExt.jl b/ext/NNlibSpecialFunctionsExt.jl new file mode 100644 index 000000000..b50abb288 --- /dev/null +++ b/ext/NNlibSpecialFunctionsExt.jl @@ -0,0 +1,15 @@ +module NNlibSpecialFunctionsExt + +using NNlib: NNlib, oftf +using SpecialFunctions: erf + +# Full gelu (gelu_erf) +NNlib.gelu_erf(x) = x/2*(1 + erf(x/sqrt(oftf(x,2)))) + +function NNlib.deriv_gelu_erf(x) + SQRT2 = sqrt(oftf(x,2)) + Φ = (1 + erf(x/SQRT2))/2 + Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π)) +end + +end \ No newline at end of file diff --git a/src/NNlib.jl b/src/NNlib.jl index 8450a0261..5ea907783 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -12,38 +12,36 @@ using KernelAbstractions: @atomic using LinearAlgebra using LinearAlgebra.BLAS: @blasfunc, BlasInt using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose -using Pkg using Random -using Requires +using ScopedValues using Statistics using Statistics: mean -const libblas = Base.libblas_name - const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number} +# internal. TODO: change to an approach where amount of threading is controlled, not just on/off +const ALLOW_SPAWNS = ScopedValue(true) +should_use_spawn() = Threads.nthreads(:default) > 1 && ALLOW_SPAWNS[] +""" + @disallow_spawns ex + +Disallow NNlib to use `@spawn` on divisible workloads. i.e. within `conv` etc. +""" +macro disallow_spawns(ex) + quote + @with ALLOW_SPAWNS => false $(esc(ex)) + end +end + # Include APIs include("dim_helpers.jl") export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims -is_nnpack_available() = false - -@init @require NNPACK_jll="a6bfbf70-4841-5cb9-aa18-3a8ad3c413ee" begin - if isdefined(NNPACK_jll, :libnnpack) - include("nnpack/NNPACK.jl") - else - @warn "NNPACK not available for your platform: " * - "$( Pkg.BinaryPlatforms.platform_name(Pkg.BinaryPlatforms.platform_key_abi()))" * - "($( Pkg.BinaryPlatforms.triplet(Pkg.BinaryPlatforms.platform_key_abi()))) - You will be able to use only the default Julia NNlib backend" - end -end - include("activations.jl") for f in ACTIVATIONS @eval export $(f) end -export sigmoid, hardsigmoid, logsigmoid, thresholdrelu # Aliases +export sigmoid, hardsigmoid, logsigmoid, thresholdrelu, gelu # Aliases include("attention.jl") export dot_product_attention, dot_product_attention_scores, make_causal_mask @@ -52,7 +50,7 @@ include("dropout.jl") export dropout, dropout! include("softmax.jl") -export softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax, +export softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!, logsumexp include("batched/batchedadjtrans.jl") @@ -64,9 +62,9 @@ include("gemm.jl") export grid_sample, ∇grid_sample include("conv.jl") -export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, - ∇conv_filter!, depthwiseconv, depthwiseconv!, - ∇depthwiseconv_data, ∇depthwiseconv_data!, +export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, + ∇conv_filter!, depthwiseconv, depthwiseconv!, + ∇depthwiseconv_data, ∇depthwiseconv_data!, ∇depthwiseconv_filter, ∇depthwiseconv_filter! 
include("conv_bias_act.jl") @@ -97,11 +95,6 @@ export upsample_nearest, ∇upsample_nearest, include("gather.jl") include("scatter.jl") include("utils.jl") -@init @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin - using .ForwardDiff - within_gradient(x::ForwardDiff.Dual) = true - within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true -end include("sampling.jl") include("functions.jl") @@ -123,4 +116,12 @@ include("impl/depthwiseconv_im2col.jl") include("impl/pooling_direct.jl") include("deprecations.jl") +include("rotation.jl") +export imrotate, ∇imrotate + +include("audio/stft.jl") +include("audio/spectrogram.jl") +include("audio/mel.jl") +export stft, istft, hann_window, hamming_window, spectrogram, melscale_filterbanks + end # module NNlib diff --git a/src/activations.jl b/src/activations.jl index a034586a8..ba3646998 100644 --- a/src/activations.jl +++ b/src/activations.jl @@ -5,7 +5,7 @@ ACTIVATIONS = [ :σ, :hardσ, :hardtanh, :relu, - :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :hardswish, :selu, + :leakyrelu, :relu6, :rrelu, :elu, :gelu_tanh, :gelu_erf, :swish, :hardswish, :selu, :celu, :softplus, :softsign, :logσ, :logcosh, :mish, :tanhshrink, :softshrink, :trelu, :lisht, :tanh_fast, :sigmoid_fast, @@ -14,6 +14,13 @@ ACTIVATIONS = [ # of type float (to allow for integer inputs) oftf(x, y) = oftype(float(x), y) +# oftype contains control flow on 1.10+, which can lead to type instabilities under AD +function rrule(::typeof(oftf), x, y) + proj_y = ChainRulesCore.ProjectTo(y) + oftf_pullback(Δ) = (NoTangent(), NoTangent(), proj_y(Δ)) + return oftf(x, y), oftf_pullback +end + """ σ(x) = 1 / (1 + exp(-x)) @@ -24,7 +31,7 @@ The ascii name `sigmoid` is also exported. See also [`sigmoid_fast`](@ref). -``` +```julia-repl julia> using UnicodePlots julia> lineplot(sigmoid, -5, 5, height=7) @@ -56,7 +63,7 @@ const sigmoid = σ Piecewise linear approximation of [`sigmoid`](@ref). -``` +```julia-repl julia> lineplot(hardsigmoid, -5, 5, height=7) ┌────────────────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋⠉⠉⠉⠉⠉⠉⠉⠉│ hardσ(x) @@ -95,7 +102,7 @@ const hardsigmoid = hardσ Return `log(σ(x))` which is computed in a numerically stable way. -``` +```julia-repl julia> lineplot(logsigmoid, -5, 5, height=7) ┌────────────────────────────────────────┐ 0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡧⠤⠔⠒⠒⠒⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│ logσ(x) @@ -121,7 +128,7 @@ Segment-wise linear approximation of `tanh`, much cheaper to compute. See ["Large Scale Machine Learning"](https://ronan.collobert.com/pub/matos/2004_phdthesis_lip6.pdf). See also [`tanh_fast`](@ref). -``` +```julia-repl julia> lineplot(hardtanh, -2, 2, height=7) ┌────────────────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⠔⠋⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│ hardtanh(x) @@ -157,7 +164,7 @@ hardtanh(x) = clamp(x, oftype(x, -1), oftype(x, 1)) # clamp(x, -1, 1) is type-s [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) activation function. -``` +```julia-repl julia> lineplot(relu, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠋│ relu(x) @@ -181,7 +188,7 @@ Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_ne activation function. You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`. 
-```julia +```julia-repl julia> lineplot(x -> leakyrelu(x, 0.5), -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ #42(x) @@ -213,7 +220,7 @@ const leakyrelu_a = 0.01 # also used in gradient below activation function capped at 6. See ["Convolutional Deep Belief Networks"](https://www.cs.toronto.edu/~kriz/conv-cifar10-aug2010.pdf) from CIFAR-10. -``` +```julia-repl julia> lineplot(relu6, -10, 10, height=7) ┌────────────────────────────────────────┐ 6 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠎⠉⠉⠉⠉⠉⠉⠉⠉│ relu6(x) @@ -238,7 +245,7 @@ Randomized Leaky Rectified Linear Unit activation function. See ["Empirical Evaluation of Rectified Activations"](https://arxiv.org/abs/1505.00853) You can also specify the bound explicitly, e.g. `rrelu(x, 0.0, 1.0)`. -```julia +```julia-repl julia> lineplot(rrelu, -20, 10, height=7) ┌────────────────────────────────────────┐ 10 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋│ rrelu(x) @@ -268,7 +275,7 @@ Exponential Linear Unit activation function. See ["Fast and Accurate Deep Network Learning by Exponential Linear Units"](https://arxiv.org/abs/1511.07289). You can also specify the coefficient explicitly, e.g. `elu(x, 1)`. -``` +```julia-repl julia> lineplot(elu, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ elu(x) @@ -294,14 +301,14 @@ elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath oftf(x, α) * (exp(x) - 1)) deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + oftype(Ω, α)) """ - gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3))) + gelu_tanh(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3))) -Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415). +Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) using tanh approximation. -``` -julia> lineplot(gelu, -2, 2, height=7) +```julia-repl +julia> lineplot(gelu_tanh, -2, 2, height=7) ┌────────────────────────────────────────┐ - 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu(x) + 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu_tanh(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ @@ -312,11 +319,11 @@ julia> lineplot(gelu, -2, 2, height=7) ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ -julia> lineplot(gelu, -5, 0, height=7); +julia> lineplot(gelu_tanh, -5, 0, height=7); julia> lineplot!(ans, swish) ┌────────────────────────────────────────┐ - 0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu(x) + 0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu_tanh(x) │⠑⠒⠢⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│ swish(x) │⠀⠀⠀⠀⠀⠈⠉⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣸⠁│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢄⠀⠀⠀⠀⠀⠀⠀⠀⢠⡇⠀│ @@ -328,7 +335,7 @@ julia> lineplot!(ans, swish) ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ -function gelu(x) +function gelu_tanh(x) α = oftf(x, 0.044715) # λ = oftf(x, gelu_λ) # x/2 * (1 + tanh(λ * (x + α * x^3))) # Standard implementation, for reference @@ -339,7 +346,7 @@ end const gelu_λ = √(2 / π) const gelu_2λ = √(8 / π) -function deriv_gelu(x) +function deriv_gelu_tanh(x) α = oftf(x, 0.044715) α2 = oftf(x, 0.08943) λλ = oftf(x, gelu_2λ) @@ -350,13 +357,34 @@ function deriv_gelu(x) muladd(dσ * λλ * muladd(x2, α2, t), x, Ω) end +""" + gelu_erf(x) = xΦ(x) = 0.5x * (1 + erf(x/√2)) + +Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) without approximation. 
+The SpecialFunctions.jl package needs to be loaded to use this function. +""" +function gelu_erf end +function deriv_gelu_erf end + +""" + gelu(x) = gelu_tanh(x) + +Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415). +See [`gelu_tanh`](@ref). +""" +const gelu = gelu_tanh +# Need to alias the type as well to ensure serialization libraries still work +# See https://github.com/FluxML/NNlib.jl/issues/631 +const var"#gelu" = typeof(gelu_tanh) +const deriv_gelu = deriv_gelu_tanh + """ swish(x) = x * σ(x) Self-gated activation function. See ["Swish: a Self-Gated Activation Function"](https://arxiv.org/abs/1710.05941). -``` +```julia-repl julia> lineplot(swish, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤│ swish(x) @@ -379,7 +407,7 @@ julia> lineplot(swish, -2, 2, height=7) Hard-Swish activation function. See ["Searching for MobileNetV3"](https://arxiv.org/abs/1905.02244). -``` +```julia-repl julia> lineplot(hardswish, -2, 5, height = 7) ┌────────────────────────────────────────┐ 5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠔⠒⠉│ hardswish(x) @@ -423,7 +451,7 @@ deriv_hardswish(x) = ifelse(x < -3, oftf(x,0), ifelse(x > 3, oftf(x,1), x/3 + of Activation function from ["LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent ..."](https://arxiv.org/abs/1901.05894) -``` +```julia-repl julia> lineplot(lisht, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔│ lisht(x) @@ -462,7 +490,7 @@ lisht(x) = x * tanh_fast(x) Scaled exponential linear units. See ["Self-Normalizing Neural Networks"](https://arxiv.org/abs/1706.02515). -``` +```julia-repl julia> lineplot(selu, -3, 2, height=7) ┌────────────────────────────────────────┐ 3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ selu(x) @@ -500,7 +528,7 @@ end Activation function from ["Continuously Differentiable Exponential Linear Units"](https://arxiv.org/abs/1704.07483). -``` +```julia-repl julia> lineplot(celu, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ celu(x) @@ -528,7 +556,7 @@ deriv_celu(Ω, α=1) = ifelse(Ω > 0, oftf(Ω, 1), Ω / oftf(Ω, α) + 1) Threshold gated rectified linear activation function. See ["Zero-bias autoencoders and the benefits of co-adapting features"](https://arxiv.org/abs/1402.3337) -``` +```julia-repl julia> lineplot(trelu, -2, 4, height=7) ┌────────────────────────────────────────┐ 4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋│ trelu(x) @@ -552,7 +580,7 @@ const thresholdrelu = trelu See ["Quadratic Polynomials Learn Better Image Features"](http://www.iro.umontreal.ca/~lisa/publications2/index.php/attachments/single/205) (2009). -``` +```julia-repl julia> lineplot(softsign, -5, 5, height=7) ┌────────────────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⣀⣀⠤⠤⠤⠤⠤│ softsign(x) @@ -595,7 +623,7 @@ deriv_softsign(x) = 1 / (1 + abs(x))^2 See ["Deep Sparse Rectifier Neural Networks"](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf), JMLR 2011. -``` +```julia-repl julia> lineplot(softplus, -3, 3, height=7) ┌────────────────────────────────────────┐ 4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ softplus(x) @@ -633,7 +661,7 @@ softplus(x) = log1p(exp(-abs(x))) + relu(x) Return `log(cosh(x))` which is computed in a numerically stable way. 
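+Note that `log(cosh(x)) = x + softplus(-2x) - log(2)` exactly; this form avoids overflow of `cosh` for large `|x|`.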
-``` +```julia-repl julia> lineplot(logcosh, -5, 5, height=7) ┌────────────────────────────────────────┐ 5 │⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ logcosh(x) @@ -657,7 +685,7 @@ const log2 = log(2) Activation function from ["Mish: A Self Regularized Non-Monotonic Neural Activation Function"](https://arxiv.org/abs/1908.08681). -``` +```julia-repl julia> lineplot(mish, -5, 5, height=7) ┌────────────────────────────────────────┐ 5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋│ mish(x) @@ -679,7 +707,7 @@ mish(x) = x * tanh(softplus(x)) See ["Tanhshrink Activation Function"](https://www.gabormelli.com/RKB/Tanhshrink_Activation_Function). -``` +```julia-repl julia> lineplot(tanhshrink, -3, 3, height=7) ┌────────────────────────────────────────┐ 3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ tanhshrink(x) @@ -705,7 +733,7 @@ tanhshrink(x) = x - tanh_fast(x) See ["Softshrink Activation Function"](https://www.gabormelli.com/RKB/Softshrink_Activation_Function). -``` +```julia-repl julia> lineplot(softshrink, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀│ softshrink(x) @@ -763,7 +791,7 @@ For any other number types, it just calls `tanh`. See also [`sigmoid_fast`](@ref). -``` +```julia-repl julia> tanh(0.5f0) 0.46211717f0 @@ -801,11 +829,11 @@ tanh_fast(x::Number) = Base.tanh(x) sigmoid_fast(x) This is a faster, and very slightly less accurate, version of `sigmoid`. -For `x::Float32, perhaps 3 times faster, and maximum errors 2 eps instead of 1. +For `x::Float32`, perhaps 3 times faster, and maximum errors 2 eps instead of 1. See also [`tanh_fast`](@ref). -``` +```julia-repl julia> sigmoid(0.2f0) 0.54983395f0 @@ -816,7 +844,10 @@ julia> hardσ(0.2f0) 0.53333336f0 ``` """ -@inline function sigmoid_fast(x::Real) +function sigmoid_fast(x::Real) + @static if VERSION ≥ v"1.11-" + @inline + end t = @fastmath exp(-abs(x)) y = ifelse(x ≥ 0, inv(1 + t), t / (1 + t)) ifelse(x > 40, one(y), ifelse(x < -80, zero(y), y)) @@ -864,7 +895,8 @@ UNARY_ACTS = [ # f, dfdx (:relu6, :((Ω>0) & (Ω<6))), # rrelu is random, can't write a rule. (:elu, :(deriv_elu(Ω))), - (:gelu, :(deriv_gelu(x))), + (:gelu_tanh, :(deriv_gelu_tanh(x))), + (:gelu_erf, :(deriv_gelu_erf(x))), (:swish, :(Ω + sigmoid_fast(x) * (1 - Ω))), (:hardswish, :(deriv_hardswish(x))), # lisht diff --git a/src/attention.jl b/src/attention.jl index 7357dd362..9e22063a7 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -56,8 +56,8 @@ function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing; First dimension in query, key and value must be divisible by `nheads`. Instead: - size(q): $(size(q)) - - size(k): $(size(q)) - - size(v): $(size(q)) + - size(k): $(size(k)) + - size(v): $(size(v)) - nheads: $nheads """)) (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError(""" diff --git a/src/audio/mel.jl b/src/audio/mel.jl new file mode 100644 index 000000000..4181c23eb --- /dev/null +++ b/src/audio/mel.jl @@ -0,0 +1,102 @@ +""" + melscale_filterbanks(; + n_freqs::Int, n_mels::Int, sample_rate::Int, + fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2)) + +Create triangular Mel scale filter banks +(ref: [Mel scale - Wikipedia](https://en.wikipedia.org/wiki/Mel_scale)). +Each column is a filterbank that highlights its own frequency. + +# Arguments: + +- `n_freqs::Int`: Number of frequencies to highlight. +- `n_mels::Int`: Number of mel filterbanks. +- `sample_rate::Int`: Sample rate of the audio waveform. +- `fmin::Float32`: Minimum frequency in Hz. 
+- `fmax::Float32`: Maximum frequency in Hz. + +# Returns: + +Filterbank matrix of shape `(n_freqs, n_mels)` where each column is a filterbank. + +```jldoctest +julia> n_mels = 8; + +julia> fb = melscale_filterbanks(; n_freqs=200, n_mels, sample_rate=16000); + +julia> plot = lineplot(fb[:, 1]); + +julia> for i in 2:n_mels + lineplot!(plot, fb[:, i]) + end + +julia> plot + ┌────────────────────────────────────────┐ + 1 │⠀⡀⢸⠀⢸⠀⠀⣧⠀⠀⢸⡄⠀⠀⠀⣷⠀⠀⠀⠀⠀⣷⠀⠀⠀⠀⠀⠀⢀⣿⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⡇⢸⡆⢸⡇⠀⣿⠀⠀⡜⡇⠀⠀⢰⠋⡆⠀⠀⠀⢰⠁⡇⠀⠀⠀⠀⠀⡸⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⣿⢸⡇⡇⡇⢰⠹⡄⠀⡇⢱⠀⠀⢸⠀⢣⠀⠀⠀⡜⠀⢸⡀⠀⠀⠀⢀⠇⠀⠈⡇⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⣿⡇⡇⡇⡇⢸⠀⡇⢀⠇⠸⡀⠀⡇⠀⠸⡀⠀⢀⠇⠀⠀⢇⠀⠀⠀⡸⠀⠀⠀⠸⡄⠀⠀⠀⠀⠀⠀⠀│ + │⢠⢻⡇⡇⡇⢱⢸⠀⢇⢸⠀⠀⡇⢀⠇⠀⠀⡇⠀⢸⠀⠀⠀⠸⡀⠀⢠⠇⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀│ + │⢸⢸⡇⢱⡇⢸⡇⠀⢸⢸⠀⠀⢣⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⢇⠀⡜⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀⠀│ + │⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⡎⠀⠀⠀⠈⣶⠁⠀⠀⠀⠀⠸⣤⠃⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀⠀⠀│ + │⢸⠀⡇⢸⠀⠀⡇⠀⠀⡇⠀⠀⠀⡇⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⠀⠀⢱⡀⠀⠀⠀⠀│ + │⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⢇⠀⠀⠀⢀⠿⡀⠀⠀⠀⠀⢰⠛⡄⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀⠀⠀│ + │⢸⢸⡇⡸⡇⢸⡇⠀⢸⢸⠀⠀⡜⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⡎⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀│ + │⢸⢸⡇⡇⡇⡸⢸⠀⡎⢸⠀⠀⡇⠈⡆⠀⠀⡇⠀⢸⠀⠀⠀⢰⠁⠀⠘⡆⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠀⠀│ + │⡇⢸⡇⡇⡇⡇⢸⠀⡇⠈⡆⢰⠁⠀⡇⠀⢰⠁⠀⠈⡆⠀⠀⡎⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀│ + │⡇⢸⢸⡇⡇⡇⠸⣰⠃⠀⡇⡸⠀⠀⢸⠀⡜⠀⠀⠀⢣⠀⢸⠁⠀⠀⠀⠈⡆⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀│ + │⡇⡇⢸⠇⢸⡇⠀⣿⠀⠀⢣⡇⠀⠀⠸⣄⠇⠀⠀⠀⠸⡀⡇⠀⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄│ + 0 │⣇⣇⣸⣀⣸⣀⣀⣟⣀⣀⣸⣃⣀⣀⣀⣿⣀⣀⣀⣀⣀⣿⣀⣀⣀⣀⣀⣀⣈⣇⣀⣀⣀⣀⣀⣀⣀⣀⣀⣱│ + └────────────────────────────────────────┘ + ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀200⠀ +``` +""" +function melscale_filterbanks(; + n_freqs::Int, n_mels::Int, sample_rate::Int, + fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2), +) + mel_min, mel_max = _hz_to_mel(fmin), _hz_to_mel(fmax) + mel_points = range(mel_min, mel_max; length=n_mels + 2) + + all_freqs = collect(range(0f0, Float32(sample_rate ÷ 2); length=n_freqs)) + freq_points = _mel_to_hz.(mel_points) + filter_banks = _triangular_filterbanks(freq_points, all_freqs) + + if any(maximum(filter_banks; dims=1) .≈ 0f0) + @warn """At least one mel filterbank has all zero values. + The value for `n_mels=$n_mels` may be set too high. + Or the value for `n_freqs=$n_freqs` may be set too low. + """ + end + return filter_banks +end + +_hz_to_mel(freq::T) where T = T(2595) * log10(T(1) + (freq / T(700))) + +_mel_to_hz(mel::T) where T = T(700) * (T(10)^(mel / T(2595)) - T(1)) + +""" + _triangular_filterbanks( + freq_points::Vector{Float32}, all_freqs::Vector{Float32}) + +Create triangular filter banks. + +# Arguments: + +- `freq_points::Vector{Float32}`: Filter midpoints of size `n_filters`. +- `all_freqs::Vector{Float32}`: Frequency points of size `n_freqs`. + +# Returns: + +Array of size `(n_freqs, n_filters)`. +""" +function _triangular_filterbanks( + freq_points::Vector{Float32}, all_freqs::Vector{Float32}, +) + diff = @view(freq_points[2:end]) .- @view(freq_points[1:end - 1]) + slopes = transpose(reshape(freq_points, :, 1) .- reshape(all_freqs, 1, :)) + + down_slopes = -(@view(slopes[:, 1:end - 2]) ./ reshape(@view(diff[1:end - 1]), 1, :)) + up_slopes = @view(slopes[:, 3:end]) ./ reshape(@view(diff[2:end]), 1, :) + return max.(0f0, min.(down_slopes, up_slopes)) +end diff --git a/src/audio/spectrogram.jl b/src/audio/spectrogram.jl new file mode 100644 index 000000000..efee1d114 --- /dev/null +++ b/src/audio/spectrogram.jl @@ -0,0 +1,79 @@ +""" + spectrogram(waveform; + pad::Int = 0, n_fft::Int, hop_length::Int, window, + center::Bool = true, power::Real = 2.0, + normalized::Bool = false, window_normalized::Bool = false, + ) + +Create a spectrogram or a batch of spectrograms from a raw audio signal. + +# Arguments + +- `pad::Int`: + Then amount of padding to apply on both sides. +- `window_normalized::Bool`: + Whether to normalize the waveform by the window’s L2 energy. 
+- `power::Real`: + Exponent for the magnitude spectrogram (must be ≥ 0) + e.g., `1` for magnitude, `2` for power, etc. + If `0`, complex spectrum is returned instead. + +See [`stft`](@ref) for other arguments. + +# Returns + +Spectrogram in the shape `(T, F, B)`, where +`T` is the number of window hops and `F = n_fft ÷ 2 + 1`. +""" +function spectrogram(waveform::AbstractArray{T}; + pad::Int = 0, n_fft::Int, hop_length::Int, window, + center::Bool = true, power::Real = 2.0, + normalized::Bool = false, window_normalized::Bool = false, +) where T + pad > 0 && (waveform = pad_zeros(waveform, pad; dims=1);) + + # Pack batch dimensions. + sz = size(waveform) + spec_ = stft(reshape(waveform, (sz[1], :)); + n_fft, hop_length, window, center, normalized) + # Unpack batch dimensions. + spec = reshape(spec_, (size(spec_)[1:2]..., sz[2:end]...)) + window_normalized && (spec = spec .* inv(norm(window));) + + if power > 0 + p = T(power) + spec = abs.(spec .+ eps(T)).^p + end + return spec +end + +""" + power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0) + +Convert a power spectrogram (amplitude squared) to decibel (dB) units. + +# Arguments + +- `s`: Input power. +- `ref`: Scalar w.r.t. which the input is scaled. +- `amin`: Minimum threshold for `s`. +- `top_db`: Threshold the output at `top_db` below the peak: + `max.(s_db, maximum(s_db) - top_db)`. + +# Returns + +`s_db ~= 10 * log10(s) - 10 * log10(ref)` +""" +function power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0) + log_spec = 10f0 .* (log10.(max.(amin, s)) .- log10.(max.(amin, ref))) + return max.(log_spec, maximum(log_spec) - top_db) +end + +""" + db_to_power(s_db; ref::Real = 1f0) + +Inverse of [`power_to_db`](@ref). +""" +function db_to_power(s_db; ref::Real = 1f0) + return ref .* 10f0.^(s_db .* 0.1f0) +end diff --git a/src/audio/stft.jl b/src/audio/stft.jl new file mode 100644 index 000000000..c90e7f49c --- /dev/null +++ b/src/audio/stft.jl @@ -0,0 +1,206 @@ +""" + hamming_window( + window_length::Int, ::Type{T} = Float32; periodic::Bool = true, + α::T = T(0.54), β::T = T(0.46), + ) where T <: Real + +Hamming window function +(ref: [Window function § Hann and Hamming windows - Wikipedia](https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows)). +Generalized version of `hann_window`. + +``w[n] = \\alpha - \\beta \\cos(\\frac{2 \\pi n}{N - 1})`` + +Where ``N`` is the window length. + +```julia-repl +julia> lineplot(hamming_window(100); width=30, height=10) + ┌──────────────────────────────┐ + 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠚⠉⠉⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠁⠀⠀⠀⠀⠀⠈⢢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⢰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⣠⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⡀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⢰⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⡰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀│ + │⠀⠀⠀⢀⠴⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀│ + │⠀⢀⡠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⣀⠀│ + 0 │⠉⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉│ + └──────────────────────────────┘ + ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀ +``` + +# Arguments: + +- `window_length::Int`: Size of the window. +- `::Type{T}`: Elemet type of the window. + +# Keyword Arguments: + +- `periodic::Bool`: If `true` (default), returns a window to be used as + periodic function. If `false`, return a symmetric window. + + Following always holds: + +```jldoctest +julia> N = 256; + +julia> hamming_window(N; periodic=true) ≈ hamming_window(N + 1; periodic=false)[1:end - 1] +true +``` +- `α::Real`: Coefficient α in the equation above. +- `β::Real`: Coefficient β in the equation above. 
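A quick numeric check of the formula above, using the default coefficients `α = 0.54f0` and `β = 0.46f0` (a hypothetical snippet, not part of this diff):

```julia
w = hamming_window(101; periodic=false)

w[1]  ≈ 0.54f0 - 0.46f0    # true: endpoints of the symmetric window equal α - β = 0.08
w[51] ≈ 0.54f0 + 0.46f0    # true: the centre sample reaches α + β = 1
```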
+ +# Returns: + +Vector of length `window_length` and eltype `T`. +""" +function hamming_window( + window_length::Int, ::Type{T} = Float32; periodic::Bool = true, + α::T = T(0.54), β::T = T(0.46), +) where T <: Real + window_length < 1 && throw(ArgumentError( + "`window_length` must be > 0, instead: `$window_length`.")) + + n::T = ifelse(periodic, window_length, window_length - 1) + scale = T(2) * π / n + return [α - β * cos(scale * T(k)) for k in 0:(window_length - 1)] +end + +""" + hann_window( + window_length::Int, ::Type{T} = Float32; periodic::Bool = true, + ) where T <: Real + +Hann window function +(ref: [Window function § Hann and Hamming windows - Wikipedia](https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows)). + +``w[n] = \\frac{1}{2}[1 - \\cos(\\frac{2 \\pi n}{N - 1})]`` + +Where ``N`` is the window length. + +```julia-repl +julia> lineplot(hann_window(100); width=30, height=10) + ┌──────────────────────────────┐ + 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠚⠉⠉⠉⠢⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡔⠁⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⢀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢣⠀⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⠀⠀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⠀⢀⡜⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀│ + │⠀⠀⠀⠀⢀⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀│ + │⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠣⡀⠀⠀│ + 0 │⣀⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢤⣀│ + └──────────────────────────────┘ + ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀ +``` + +# Arguments: + +- `window_length::Int`: Size of the window. +- `::Type{T}`: Elemet type of the window. + +# Keyword Arguments: + +- `periodic::Bool`: If `true` (default), returns a window to be used as + periodic function. If `false`, return a symmetric window. + + Following always holds: + +```jldoctest +julia> N = 256; + +julia> hann_window(N; periodic=true) ≈ hann_window(N + 1; periodic=false)[1:end - 1] +true + +julia> hann_window(N) ≈ hamming_window(N; α=0.5f0, β=0.5f0) +true +``` + +# Returns: + +Vector of length `window_length` and eltype `T`. +""" +function hann_window( + window_length::Int, ::Type{T} = Float32; periodic::Bool = true, +) where T <: Real + hamming_window(window_length, T; periodic, α=T(0.5), β=T(0.5)) +end + +""" + stft(x; + n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing, + center::Bool = true, normalized::Bool = false, + ) + +Short-time Fourier transform (STFT). + +The STFT computes the Fourier transform of short overlapping windows of the input, +giving frequency components of the signal as they change over time. + +``Y[\\omega, m] = \\sum_{k = 0}^{N - 1} \\text{window}[k] \\text{input}[m \\times \\text{hop length} + k] \\exp(-j \\frac{2 \\pi \\omega k}{\\text{n fft}})`` + +where ``N`` is the window length, +``\\omega`` is the frequency ``0 \\le \\omega < \\text{n fft}`` +and ``m`` is the index of the sliding window. + +# Arguments: + +- `x`: Input, must be either a 1D time sequence (`(L,)` shape) + or a 2D batch of time sequence (`(L, B)` shape). + +# Keyword Arguments: + +- `n_fft::Int`: Size of Fourier transform. +- `hop_length::Int`: Distance between neighboring sliding window frames. +- `window`: Optional window function to apply. + Must be 1D vector `0 < length(window) ≤ n_fft`. + If window is shorter than `n_fft`, it is padded with zeros on both sides. + If `nothing` (default), then no window is applied. +- `center::Bool`: Whether to pad input on both sides so that ``t``-th frame + is centered at time ``t \\times \\text{hop length}``. + Padding is done with `pad_reflect` function. +- `normalized::Bool`: Whether to return normalized STFT, + i.e. 
multiplied with ``\\text{n fft}^{-0.5}``.
+
+# Returns:
+
+Complex array of shape `(n_fft, n_frames, B)`,
+where `B` is the optional batch dimension.
+"""
+function stft end
+
+"""
+    istft(y;
+        n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
+        center::Bool = true, normalized::Bool = false,
+        return_complex::Bool = false,
+        original_length::Union{Nothing, Int} = nothing,
+    )
+
+Inverse Short-time Fourier Transform.
+
+Return the least-squares estimate of the original signal.
+
+# Arguments:
+
+- `y`: Input complex array in the `(n_fft, n_frames, B)` shape,
+    where `B` is the optional batch dimension.
+
+# Keyword Arguments:
+
+- `n_fft::Int`: Size of Fourier transform.
+- `hop_length::Int`: Distance between neighboring sliding window frames.
+- `window`: Window function that was applied to the input of `stft`.
+    If `nothing` (default), then no window was applied.
+- `center::Bool`: Whether input to `stft` was padded on both sides
+    so that ``t``-th frame is centered at time ``t \\times \\text{hop length}``.
+    Padding is done with the `pad_reflect` function.
+- `normalized::Bool`: Whether input to `stft` was normalized.
+- `return_complex::Bool`: Whether the output should be complex,
+    or if the input should be assumed to derive from a real signal and window.
+- `original_length::Union{Nothing, Int}`: Optional size of the first dimension
+    of the input to `stft`. Helps restore the exact `stft` input size.
+    Otherwise, the reconstructed signal may be slightly shorter.
+"""
+function istft end
diff --git a/src/batched/batchedadjtrans.jl b/src/batched/batchedadjtrans.jl
index 130f039bc..9423a6fb6 100644
--- a/src/batched/batchedadjtrans.jl
+++ b/src/batched/batchedadjtrans.jl
@@ -87,6 +87,7 @@ function Base.stride(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}}, d::Inte
     Base.stride(A.parent, d)
 end
 
+Base.pointer(A::BatchedAdjOrTrans) = pointer(parent(A))
 Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} =
     Base.unsafe_convert(Ptr{T}, parent(A))
 
diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl
index a3b7efc74..ccd9b0e82 100644
--- a/src/batched/batchedmul.jl
+++ b/src/batched/batchedmul.jl
@@ -203,7 +203,7 @@ In-place batched matrix multiplication, equivalent to
 If `size(B,3) == 1` then every batch uses `B[:,:,1]` instead.
 
 This will call `batched_gemm!` whenever possible. For real arrays this means that,
-for `X ∈ [A,B,C]`, either `strides(X,1)==1` or `strides(X,2)==1`, the latter may
+for `X ∈ [A,B,C]`, either `stride(X,1)==1` or `stride(X,2)==1`, the latter may
 be caused by `batched_transpose` or by for instance `PermutedDimsArray(::Array, (3,1,2))`.
 Unlike `batched_mul` this will never make a copy.
 
diff --git a/src/bias_act.jl b/src/bias_act.jl
index ef7fb29d9..935a50239 100644
--- a/src/bias_act.jl
+++ b/src/bias_act.jl
@@ -8,7 +8,7 @@ const RCR = RuleConfig{>:HasReverseMode}
 @inline only_derivative(y,f::F,x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x)))
 
 # This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
-# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
+# is independent of `x`, as `return_type` says `Union{}` when calling is an error.
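The trick behind that comment is easy to miss: `NotaNumber` (the marker type defined on the next line) deliberately has no arithmetic methods, so inference can only return a concrete type when the activation's derivative is written purely in terms of the output `Ω`. A small self-contained sketch with hypothetical stand-ins:

```julia
struct NotANumberDemo <: Real end            # stand-in for NotaNumber; no methods on purpose

d_relu(Ω, x)   = ifelse(Ω > 0, one(Ω), zero(Ω))   # derivative expressed via Ω only
d_square(Ω, x) = 2x                               # derivative that genuinely needs x

Core.Compiler.return_type(d_relu,   Tuple{Float64, NotANumberDemo})   # Float64 — x not needed, safe to overwrite
Core.Compiler.return_type(d_square, Tuple{Float64, NotANumberDemo})   # Union{} — every call would throw, x still needed
```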
struct NotaNumber <: Real end """ @@ -57,7 +57,7 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA end # Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ - if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + if isconcretetype(Core.Compiler.return_type(only_derivative, Tuple{T, F, NotaNumber})) Ω = bias_act!(σ, x, b) # now x === Ω, when x isa StridedArray{<:AbstractFloat} function bias_act!_fastback(Δ) # Tempting to overwrite x again, but only safe if you call pullback at most once, @@ -70,7 +70,7 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA # # Slower path: can't overwrite x, but can use derivatives_given_output # # This case is WRONG and tests fail, but not sure why - # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) + # elseif isconcretetype(Core.Compiler.return_type(only_derivative, Tuple{T, F, T})) # Ω2 = fast_act(σ, x).(x) .+ b # @show σ b # function bias_act!_back2(Δ) diff --git a/src/conv.jl b/src/conv.jl index 3fecb9151..26b3d0362 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -76,7 +76,7 @@ end # Let's generate auto-allocating versions of all our functions, for all backends. # We `@timeit` these methods separately, as we want to know how much time is spent in # allocation. :P -for backend in (Symbol(), :_direct, :_im2col, :_nnpack) +for backend in (Symbol(), :_direct, :_im2col) # First make auto-allocating versions of the conv()-like calls: for name in (:conv, :depthwiseconv) @eval begin @@ -134,7 +134,7 @@ end # since we can specialize on sizes. for front_name in (:conv, :∇conv_data, :∇conv_filter, :depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter) - for backend in (Symbol(), :_direct, :_im2col) ## NNPACK is only for 2d conv + for backend in (Symbol(), :_direct, :_im2col) for N in (3, 4) @eval begin function $(Symbol("$(front_name)$(backend)!"))( @@ -181,7 +181,7 @@ for (front_name, backend, signature) in ( ) # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution @eval begin - + function $(Symbol("$(front_name)!"))( out::AbstractArray{$(signature[1][1]), $(signature[1][2])}, in1::AbstractArray{$(signature[2][1]), $(signature[1][2])}, @@ -202,11 +202,21 @@ for (front_name, backend, signature) in ( C_in = channels_in(cdims) ÷ groupcount(cdims), C_out = channels_out(cdims) ÷ groupcount(cdims)) - Threads.@sync for (xc, wc) in zip(x_cs, w_cs) + function conv_group(xc, wc) x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...] w = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...] y = @view out[ntuple(i -> i == 4 ? wc : Colon(), 5)...] - Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...) + $(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...) + end + + if should_use_spawn() && length(x_cs) > 1 + Threads.@sync for (xc, wc) in zip(x_cs, w_cs) + Threads.@spawn conv_group(xc, wc) + end + else + for (xc, wc) in zip(x_cs, w_cs) + conv_group(xc, wc) + end end return out @@ -246,11 +256,21 @@ for (front_name, backend, signature) in ( C_in = channels_in(cdims) ÷ groupcount(cdims), C_out = channels_out(cdims) ÷ groupcount(cdims)) - Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs) + function ∇conv_data_group(xc, yc, wc) dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...] dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...] wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...] 
- Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(dxv, dyv, wv, cdims2; kwargs...) + $(Symbol("$(front_name)_$(backend)!"))(dxv, dyv, wv, cdims2; kwargs...) + end + + if should_use_spawn() && length(dx_cs) > 1 + Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs) + Threads.@spawn ∇conv_data_group(xc, yc, wc) + end + else + for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs) + ∇conv_data_group(xc, yc, wc) + end end return out @@ -288,11 +308,21 @@ for (front_name, backend, signature) in ( C_in = channels_in(cdims) ÷ groupcount(cdims), C_out = channels_out(cdims) ÷ groupcount(cdims)) - Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs) + function ∇conv_filter_group(wc, xc, yc) x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...] dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...] - dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...] - Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(dw, x, dy, cdims2; kwargs...) + dw = @view out[ntuple(i -> i == 5 ? wc : Colon(), 5)...] + $(Symbol("$(front_name)_$(backend)!"))(dw, x, dy, cdims2; kwargs...) + end + + if should_use_spawn() && length(dw_cs) > 1 + Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs) + Threads.@spawn ∇conv_filter_group(wc, xc, yc) + end + else + for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs) + ∇conv_filter_group(wc, xc, yc) + end end return out @@ -306,10 +336,10 @@ for (front_name, backend, signature) in ( # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types))) (:depthwiseconv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))), (:depthwiseconv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))), - + (:∇depthwiseconv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))), (:∇depthwiseconv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))), - + (:∇depthwiseconv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))), (:∇depthwiseconv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))), ) @@ -343,12 +373,12 @@ for conv in [:conv, :depthwiseconv] conv_pullback, ∇conv_data_pullback = Symbol.([conv, ∇conv_data], :_pullback) @eval function rrule(::typeof($conv), x, w, cdims; kw...) - function $conv_pullback(Δ) - Δ = colmajor(Δ) + function $conv_pullback(Δraw) + Δ = colmajor(unthunk(Δraw)) return ( NoTangent(), - @thunk($∇conv_data(unthunk(Δ), w, cdims, kw...)), - @thunk($∇conv_filter(x, unthunk(Δ), cdims, kw...)), + @thunk($∇conv_data(Δ, w, cdims, kw...)), + @thunk($∇conv_filter(x, Δ, cdims, kw...)), NoTangent(), ) end @@ -356,12 +386,12 @@ for conv in [:conv, :depthwiseconv] end @eval function rrule(::typeof($∇conv_data), x, w, cdims; kw...) - function $∇conv_data_pullback(Δ) - Δ = colmajor(Δ) + function $∇conv_data_pullback(Δraw) + Δ = colmajor(unthunk(Δraw)) return ( NoTangent(), - @thunk($conv(unthunk(Δ), w, cdims, kw...)), - @thunk($∇conv_filter(unthunk(Δ), x, cdims, kw...)), + @thunk($conv(Δ, w, cdims, kw...)), + @thunk($∇conv_filter(Δ, x, cdims, kw...)), NoTangent(), ) end @@ -381,26 +411,3 @@ function rrule(::typeof(∇conv_filter), x, dy, cdims; kw...) 
end return ∇conv_filter(x, dy, cdims; kw...), ∇conv_filter_pullback end - -# Use NNPACK if it is available and the operation is supported -# commented out 'till proper benchmarking and more correctness test are performed -# if is_nnpack_available() -# function conv(x::Array{Float32, 4}, w::Array{Float32, 4}, -# cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F}; -# kwargs...) where {K, C_in, C_out, P, F} -# return conv_nnpack(x, w, cdims; kwargs...) -# end - -# function ∇conv_data(dy::Array{Float32, 4}, w::Array{Float32, 4}, -# cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F}; -# kwargs...) where {K, C_in, C_out, P, F} -# return ∇conv_data_nnpack(dy, w, cdims; kwargs...) -# end - -# function ∇conv_filter(x::Array{Float32, 4}, dy::Array{Float32, 4}, -# cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F}; -# kwargs...) where {K, C_in, C_out, P, F} -# return ∇conv_filter_nnpack(x, dy, cdims; kwargs...) -# end -# end -######################################################## diff --git a/src/ctc.jl b/src/ctc.jl index 6202622c3..c5188768f 100644 --- a/src/ctc.jl +++ b/src/ctc.jl @@ -23,7 +23,8 @@ function logaddexp(a, b) end """ - add_blanks(z) + add_blanks(z) + Adds blanks to the start and end of `z`, and between items in `z` """ function add_blanks(z, blank) diff --git a/src/dim_helpers/ConvDims.jl b/src/dim_helpers/ConvDims.jl index 9358a41a8..e8bcc08f4 100644 --- a/src/dim_helpers/ConvDims.jl +++ b/src/dim_helpers/ConvDims.jl @@ -73,7 +73,7 @@ function im2col_dims(c::ConvDims) # Size of single dotproduct within convolution prod(kernel_size(c))*channels_in(c), # One workspace per thread - VERSION > v"1.9.0-0" ? Threads.maxthreadid() : Threads.nthreads(), + VERSION > v"1.9.0-0" ? Threads.nthreads(:default) : Threads.nthreads(), ) end diff --git a/src/dim_helpers/DepthwiseConvDims.jl b/src/dim_helpers/DepthwiseConvDims.jl index 8163a3def..fbfbcd718 100644 --- a/src/dim_helpers/DepthwiseConvDims.jl +++ b/src/dim_helpers/DepthwiseConvDims.jl @@ -2,7 +2,7 @@ DepthwiseConvDims Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily due to -characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from +characterization by `C_in`, `C_mult`, rather than `C_in`, `C_out`. Useful to be separate from DenseConvDims primarily for channel calculation differences. """ struct DepthwiseConvDims{N, K, S, P, D} <: ConvDims{N} diff --git a/src/dim_helpers/PoolDims.jl b/src/dim_helpers/PoolDims.jl index 75d56b8cd..bfb39e1bc 100644 --- a/src/dim_helpers/PoolDims.jl +++ b/src/dim_helpers/PoolDims.jl @@ -1,6 +1,6 @@ """ PoolDims(x_size::NTuple{M}, k::Union{NTuple{L, Int}, Int}; - stride=k, padding=0, dilation=1) where {M, L} + stride=k, padding=0, dilation=1) where {M, L} Dimensions for a "pooling" operation that can have an arbitrary input size, kernel size, stride, dilation, and channel count. Used to dispatch onto efficient implementations at diff --git a/src/dropout.jl b/src/dropout.jl index 02673cf03..44fc59c14 100644 --- a/src/dropout.jl +++ b/src/dropout.jl @@ -12,7 +12,7 @@ i.e. each row of a matrix is either zero or not. Optional first argument is the random number generator used. 
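The examples below do not exercise that RNG argument, so here is a hedged sketch of it (assuming only what the sentence above states; `MersenneTwister` comes from the standard library `Random`):

```julia
using Random, NNlib

x = ones(Float32, 2, 10)

# Passing an explicit generator makes the dropout mask reproducible, which is handy in tests:
dropout(MersenneTwister(0), x, 0.5) == dropout(MersenneTwister(0), x, 0.5)   # true
```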
# Examples -``` +```julia-repl julia> dropout(ones(2, 10), 0.2) 2×10 Matrix{Float64}: 1.25 1.25 0.0 1.25 1.25 1.25 1.25 1.25 1.25 1.25 diff --git a/src/fold.jl b/src/fold.jl index dcde60e4f..f3c205e15 100644 --- a/src/fold.jl +++ b/src/fold.jl @@ -1,4 +1,3 @@ - """ unfold(x, kernel_size; stride = 1, pad = 0, dilation = 0, flipped = true) @@ -7,10 +6,10 @@ window_size, batchsize)`. The window size is determined by the `prod(spatial dim of kernel)*input_channels`. The number of sliding windows will match those of convolution (`conv`) with the same kernel_size and arguments. Note that by default `conv` flips the spatial dimensions of its kernel (default -`flipped=false`), whereas `unfold` does not (default `flipped=true`). -Uses `NNlib.im2col!` as backend. +`flipped=false`), whereas `unfold` does not (default `flipped=true`). +Uses `NNlib.im2col!` as backend. -See also [`fold`](@ref), the adjoint/transpose operator +See also [`fold`](@ref), the adjoint/transpose operator and a potential inverse of `unfold`. # Example @@ -23,7 +22,7 @@ julia> w = reshape([1 0 -1], 3, 1, 1); # 1D conv kernel of length 3 julia> kws = (pad=1, stride=2, flipped=true); # use same args for conv and unfold -julia> z = NNlib.unfold(x, size(w); kws...) +julia> z = NNlib.unfold(x, size(w); kws...) 4×3×1 Array{Int64, 3}: [:, :, 1] = 0 100 2 @@ -61,8 +60,8 @@ end The adjoint/transpose operator of `unfold`. It accumulates sliding windows from the output of `unfold` into a container tensor of size `output_size`. An inverse -to `unfold` may be obtained (in some cases) by using `fold` and accounting for scaling issues -with a divisor (see example). Uses `NNlib.col2im!` as backend. +to `unfold` may be obtained (in some cases) by using `fold` and accounting for scaling issues +with a divisor (see example). Uses `NNlib.col2im!` as backend. See also [`unfold`](@ref). 
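To make the adjoint relationship concrete before the next hunk's divisor example, here is a compact sketch of the round trip both docstrings describe (same 1D setup as the examples in this file, default `stride = 1` and `pad = 0`):

```julia
using NNlib

x = reshape(Float64.(1:7), 7, 1, 1)        # (length, channels, batch)
y = NNlib.unfold(x, (3, 1, 1))             # five sliding windows of length 3
z = NNlib.fold(y, size(x), (3, 1, 1))      # overlapping windows are summed, so z ≠ x

# Dividing by fold∘unfold of ones rescales each entry by how many windows covered it:
divisor = NNlib.fold(NNlib.unfold(ones(size(x)...), (3, 1, 1)), size(x), (3, 1, 1))
z ./ divisor ≈ x                           # true
```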
@@ -101,7 +100,7 @@ julia> divisor = NNlib.fold(NNlib.unfold(ones(size(x)...), (3,1,1)), size(x), (3 2.0 1.0 -julia> z ./ divisor +julia> z ./ divisor 7×1×1 Array{Float64, 3}: [:, :, 1] = 100.0 @@ -133,30 +132,30 @@ function unfold(x::AbstractArray{T, N}, cdims::DenseConvDims) where {T, N} end function fold(y::AbstractArray{T, 3}, output_size::NTuple, cdims::DenseConvDims) where {T} - x = similar(y, output_size) + x = similar(y, output_size) return fold!(x, y, cdims) end -# N < 5 -dimension in-place versions +# N < 5 -dimension in-place versions function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, N}, cdims::DenseConvDims) where {yT, xT, N} unfold!( - y, - insert_singleton_spatial_dimension(x, 5-N), - insert_singleton_spatial_dimension(cdims, 5-N), + y, + insert_singleton_spatial_dimension(x, 5-N), + insert_singleton_spatial_dimension(cdims, 5-N), ) return y end function fold!(x::AbstractArray{xT, N}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {yT, xT, N} fold!( - insert_singleton_spatial_dimension(x, 5-N), + insert_singleton_spatial_dimension(x, 5-N), y, - insert_singleton_spatial_dimension(cdims, 5-N), + insert_singleton_spatial_dimension(cdims, 5-N), ) return x end -# 5-dimension in-place versions +# 5-dimension in-place versions function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 5}, cdims::DenseConvDims) where {yT, xT} @threads for batch_idx in 1:size(x, 5) y_slice = view(y, :, :, batch_idx) @@ -173,6 +172,110 @@ function fold!(x::AbstractArray{xT, 5}, y::AbstractArray{yT, 3}, cdims::DenseCon return x end +@kernel function unfold_kernel!( + col::AbstractArray{T}, x, col_size, + input_size, output_size, kernel_size, + flipkernel, stride, pad_lo, dilation, max_idx, +) where T + index = @index(Global) + + @inbounds if index ≤ max_idx + i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices + w, h, d = CartesianIndices(output_size)[i].I # x indices + + # project + w, h, d = @. ((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation + + if !flipkernel + kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1 + end + + # check out of bounds + if !all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d))) + col[i, kw, kh, kd, c, b] = T(0) + else + xval::T = x[w, h, d, c, b] + col[i, kw, kh, kd, c, b] = xval + end + end +end + +@kernel function fold_kernel!( + x::AbstractArray{T}, col, col_size, + input_size, output_size, kernel_size, + flipkernel, stride, pad_lo, dilation, max_idx, +) where T + index = @index(Global) + + @inbounds if index ≤ max_idx + i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices + w, h, d = CartesianIndices(output_size)[i].I # x indices + + # project + w, h, d = @. 
((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation + + # check out of bounds + if all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d))) + if !flipkernel + kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1 + end + + cval::T = col[i, kw, kh, kd, c, b] + @atomic x[w, h, d, c, b] += cval + end + end +end + +function unfold!( + col::AnyGPUArray{cT,3}, x::AnyGPUArray{xT,5}, cdims::DenseConvDims, +) where {cT, xT} + spatial_dims(cdims) != 3 && throw(DimensionMismatch( + "unfold!() only accepts 3d convoluitional inputs")) + + C_in = channels_in(cdims) + ker_size = kernel_size(cdims) + pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims) + pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo) + + out_size = output_size(cdims) + col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :)) + + max_idx = prod(size(col)) + unfold_kernel!(get_backend(x))( + col_reshaped, x, size(col_reshaped), + input_size(cdims), out_size, ker_size, + flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx; + ndrange=max_idx) + return col +end + +function fold!( + x::AnyGPUArray{xT,5}, col::AnyGPUArray{cT,3}, cdims::DenseConvDims, +) where {xT, cT} + spatial_dims(cdims) != 3 && throw(DimensionMismatch( + "fold!() only accepts 3d convoluitional inputs")) + + # going to accumulate into x + fill!(x, xT(0)) + + C_in = channels_in(cdims) + ker_size = kernel_size(cdims) + pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims) + pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo) + out_size = output_size(cdims) + + col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :)) + + max_idx = prod(size(col)) + fold_kernel!(get_backend(x))( + x, col_reshaped, size(col_reshaped), + input_size(cdims), out_size, ker_size, + flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx; + ndrange=max_idx) + + return x +end + # reverse diff rules function rrule(::typeof(unfold), x, cdims::DenseConvDims; kw...) function unfold_pullback(Δ) diff --git a/src/gather.jl b/src/gather.jl index 1ad69df24..d75f89a2c 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -109,7 +109,7 @@ function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) return dst end -function gather!(dst::AbstractGPUArray, src::AbstractGPUArray, idx::AbstractGPUArray) +function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) n_dims = scatter_dims(src, dst, idx) dims = size(src)[1:n_dims] max_dims_idx = prod(dims) diff --git a/src/gemm.jl b/src/gemm.jl index 051508750..9a3c6cd57 100644 --- a/src/gemm.jl +++ b/src/gemm.jl @@ -3,6 +3,12 @@ using LinearAlgebra.BLAS: get_num_threads, set_num_threads +if isdefined(LinearAlgebra.BLAS, :libblastrampoline) + const libblas = LinearAlgebra.BLAS.libblastrampoline +else + const libblas = Base.libblas_name +end + """ gemm!() @@ -80,16 +86,16 @@ for (gemm, elt) in gemm_datatype_mappings LinearAlgebra.BLAS.chkstride1(B) LinearAlgebra.BLAS.chkstride1(C) - ptrA = Base.unsafe_convert(Ptr{$elt}, A) - ptrB = Base.unsafe_convert(Ptr{$elt}, B) - ptrC = Base.unsafe_convert(Ptr{$elt}, C) + ptrA = pointer(A) + ptrB = pointer(B) + ptrC = pointer(C) strA = size(A, 3) == 1 ? 0 : Base.stride(A, 3) strB = size(B, 3) == 1 ? 0 : Base.stride(B, 3) strC = Base.stride(C, 3) n_threads = min( - VERSION > v"1.9.0-0" ? Threads.maxthreadid() : Threads.nthreads(), + VERSION > v"1.9.0-0" ? 
Threads.nthreads(:default) : Threads.nthreads(), 1 + max(length(A), length(B)) ÷ 8000) # In some tests, size (20,20,20) is worth splitting between two threads, # as is size (32,32,8). @@ -98,22 +104,34 @@ for (gemm, elt) in gemm_datatype_mappings old_threads = get_num_threads() set_num_threads(1) - Threads.@sync for ks in Iterators.partition(1:size(C, 3), cld(size(C, 3), n_threads)) - Threads.@spawn for k in ks + + parts = Iterators.partition(1:size(C, 3), cld(size(C, 3), n_threads)) + + function gemm!_part(ks) + for k in ks ptrAk = ptrA + (k-1) * strA * sizeof($elt) ptrBk = ptrB + (k-1) * strB * sizeof($elt) ptrCk = ptrC + (k-1) * strC * sizeof($elt) ccall((@blasfunc($(gemm)), libblas), Nothing, - (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, - Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}, - Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, - Ref{BlasInt}), - transA, transB, m, n, - ka, alpha, ptrAk, max(1,Base.stride(A,2)), - ptrBk, max(1,Base.stride(B,2)), beta, ptrCk, - max(1,Base.stride(C,2))) + (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, + Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}, + Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, + Ref{BlasInt}), + transA, transB, m, n, + ka, alpha, ptrAk, max(1,Base.stride(A,2)), + ptrBk, max(1,Base.stride(B,2)), beta, ptrCk, + max(1,Base.stride(C,2))) + end + end + if should_use_spawn() && length(parts) > 1 + Threads.@sync for ks in parts + Threads.@spawn gemm!_part(ks) + end + else + for ks in parts + gemm!_part(ks) end end set_num_threads(old_threads) diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index 9f12f1dc9..497f2e929 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -81,6 +81,11 @@ function conv_direct!( # Use `calc_padding_regions` to determine where we do or don't need to worry about padding padded_regions, central_region = calc_padding_regions(cdims) + # Set outputs to zero to support custom datatypes (https://github.com/FluxML/NNlib.jl/issues/490) + if iszero(beta) + y = fill!(y, zero(yT)) + end + # Start with the central region w_region, h_region, d_region = central_region @inbounds for batch in 1:size(x, 5), diff --git a/src/impl/conv_im2col.jl b/src/impl/conv_im2col.jl index cde297c08..1893822de 100644 --- a/src/impl/conv_im2col.jl +++ b/src/impl/conv_im2col.jl @@ -47,20 +47,28 @@ function conv_im2col!( parts = Iterators.partition(axes(x, 5), ceil(Int, size(x, 5) / ntasks)) - @sync for (task_n, part) in enumerate(parts) - Threads.@spawn begin - col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace - for batch_idx in part - im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims) - GC.@preserve col_slice w y begin - col_ptr = pointer(col_slice) - w_ptr = pointer(w) - y_ptr = pointer(y, (batch_idx - 1)*M*N + 1) - gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr) - end + function conv_part(task_n, part) + col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace + for batch_idx in part + im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims) + GC.@preserve col_slice w y begin + col_ptr = pointer(col_slice) + w_ptr = pointer(w) + y_ptr = pointer(y, (batch_idx - 1)*M*N + 1) + gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr) end end end + + if should_use_spawn() && length(parts) > 1 + @sync for (task_n, part) in enumerate(parts) + Threads.@spawn conv_part(task_n, part) + end + else + for (task_n, part) in enumerate(parts) + conv_part(task_n, part) + end + end return 
y end @@ -152,18 +160,25 @@ function ∇conv_data_im2col!( parts = Iterators.partition(axes(dx, 5), ceil(Int, size(dx, 5) / ntasks)) - @sync for (task_n, part) in enumerate(parts) - Threads.@spawn begin - col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace - for batch_idx in part - GC.@preserve col_slice w dy begin - dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1) - w_ptr = pointer(w) - col_ptr = pointer(col_slice) - gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) - end - col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims) + function ∇conv_data_part(task_n, part) + col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace + for batch_idx in part + GC.@preserve col_slice w dy begin + dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1) + w_ptr = pointer(w) + col_ptr = pointer(col_slice) + gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) end + col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta) + end + end + if should_use_spawn() && length(parts) > 1 + @sync for (task_n, part) in enumerate(parts) + Threads.@spawn ∇conv_data_part(task_n, part) + end + else + for (task_n, part) in enumerate(parts) + ∇conv_data_part(task_n, part) end end return dx @@ -276,7 +291,7 @@ end """ - col2im!(x, col, cdims) + col2im!(x, col, cdims, beta=0) Does the inverse of `im2col!()`, converting `col` back into a 3d image, used for backward passes, transposed convolutions, etc... @@ -287,7 +302,7 @@ desperate enough yet. """ col2im! -function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, cdims::ConvDims) where T +function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, cdims::ConvDims, beta::T=T(0)) where T if spatial_dims(cdims) != 3 throw(DimensionMismatch("col2im!() only accepts 3d convoluitional inputs")) end @@ -303,7 +318,13 @@ function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, cdims::ConvDims # TODO: Rewrite this method so we don't have this fill!() at the beginning! # Calculate each output pixel once rather than accumulating into it? - fill!(x, T(0)) + if beta == T(0) + fill!(x, T(0)) + elseif beta == T(1) + # nothing + else + x .*= beta + end # Reshape col for easy access. 
col_reshaped = reshape(col, ( diff --git a/src/impl/depthwiseconv_im2col.jl b/src/impl/depthwiseconv_im2col.jl index 30edc16e1..60e40a9dc 100644 --- a/src/impl/depthwiseconv_im2col.jl +++ b/src/impl/depthwiseconv_im2col.jl @@ -30,25 +30,32 @@ function depthwiseconv_im2col!( dcdims = DenseConvDims(cdims) - @sync for (task_n, part) in enumerate(parts) - Threads.@spawn begin - col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace - for batch_idx in part - im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims) - - # We do a separate convolution for each channel in x, as we must - for c_in in 1:channels_in(cdims) - # Walk each pointer forward as we process each input channel - GC.@preserve col_slice w y begin - col_ptr = pointer(col_slice, (c_in-1)*M*K+1) - w_ptr = pointer(w, (c_in-1)*K*N+1) - y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1) - gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr) - end + function depthwiseconv_part(task_n, part) + col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace + for batch_idx in part + im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims) + + # We do a separate convolution for each channel in x, as we must + for c_in in 1:channels_in(cdims) + # Walk each pointer forward as we process each input channel + GC.@preserve col_slice w y begin + col_ptr = pointer(col_slice, (c_in-1)*M*K+1) + w_ptr = pointer(w, (c_in-1)*K*N+1) + y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1) + gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr) end end end end + if should_use_spawn() && length(parts) > 1 + @sync for (task_n, part) in enumerate(parts) + Threads.@spawn depthwiseconv_part(task_n, part) + end + else + for (task_n, part) in enumerate(parts) + depthwiseconv_part(task_n, part) + end + end return y end @@ -117,22 +124,29 @@ function ∇depthwiseconv_data_im2col!( parts = Iterators.partition(axes(dx)[end], ceil(Int, size(dx, 5) / ntasks)) - @sync for (task_n, part) in enumerate(parts) - Threads.@spawn begin - col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace - for batch_idx in part - # We do a separate convolution for each channel in x, as we must - for cidx in 1:channels_in(cdims) - GC.@preserve col_slice w dy begin - # Walk each pointer forward as we process each input channel - dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1) - w_ptr = pointer(w, (cidx - 1)*K*N + 1) - col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1) - gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) - end + function ∇depthwiseconv_data_part(task_n, part) + col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace + for batch_idx in part + # We do a separate convolution for each channel in x, as we must + for cidx in 1:channels_in(cdims) + GC.@preserve col_slice w dy begin + # Walk each pointer forward as we process each input channel + dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1) + w_ptr = pointer(w, (cidx - 1)*K*N + 1) + col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1) + gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) end - col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims) end + col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta) + end + end + if should_use_spawn() && length(parts) > 1 + @sync for (task_n, part) in enumerate(parts) 
+ Threads.@spawn ∇depthwiseconv_data_part(task_n, part) + end + else + for (task_n, part) in enumerate(parts) + ∇depthwiseconv_data_part(task_n, part) end end return dx diff --git a/src/nnpack/NNPACK.jl b/src/nnpack/NNPACK.jl deleted file mode 100644 index 685415a7e..000000000 --- a/src/nnpack/NNPACK.jl +++ /dev/null @@ -1,55 +0,0 @@ -using NNPACK_jll - -include("libnnpack_types.jl") -include("error.jl") -include("libnnpack.jl") -include("performance.jl") -include("interface.jl") - - -const shared_threadpool_dict = Dict{UInt64, Base.RefValue}() - -""" - is_nnpack_available() - -Checks if the current hardware is supported by NNPACK. -""" -function is_nnpack_available() - status = nnp_initialize() - if status == nnp_status_unsupported_hardware - return false - else - return true - end -end - -""" - allocate_threadpool() - -Allocates several threadpool based on the upper limit on the number of threads for the machine. -Allows NNPACK to intelligently choose which threadpool to use for getting the best -performance. -""" -function allocate_threadpool() - global NNPACK_CPU_THREADS = NNPACK_CPU_THREADS > 8 ? UInt64(8) : UInt64(exp2(floor(log2(NNPACK_CPU_THREADS)))) - for i in 0:Int(log2(NNPACK_CPU_THREADS)) - threads = UInt64(2^i) - push!(shared_threadpool_dict, threads => Ref(pthreadpool_create(threads))) - end -end - -@init begin - status = nnp_initialize() - if status == nnp_status_unsupported_hardware - @warn "Hardware is unsupported by NNPACK so falling back to default NNlib" - end - try - global NNPACK_CPU_THREADS = parse(UInt64, ENV["NNPACK_CPU_THREADS"]) - catch - # Sys.CPU_THREADS should be a better default if we are tuning the benchmark suite on - # a particular machine. However, we fix the runtime threadpool here to have a max of - # 4 threads so anything above will be ignored anyways - global NNPACK_CPU_THREADS = UInt64(4) - end - allocate_threadpool() -end diff --git a/src/nnpack/error.jl b/src/nnpack/error.jl deleted file mode 100644 index 83522c37d..000000000 --- a/src/nnpack/error.jl +++ /dev/null @@ -1,83 +0,0 @@ -struct NNPACKError <: Exception - code::nnp_status - msg::AbstractString -end - -Base.show(io::IO, err::NNPACKError) = print(io, "NNPACKError(code $(err.code), $(err.msg))") - -function NNPACKError(status::nnp_status) - msg = "NNPACK STATUS SUCCESS" - if status == nnp_status_invalid_batch_size - msg = "NNPACK STATUS INVALID BATCH SIZE" - elseif status == nnp_status_invalid_channels - msg = "NNPACK STATUS INVALID CHANNELS" - elseif status == nnp_status_invalid_input_channels - msg = "NNPACK STATUS INVALID INPUT CHANNELS" - elseif status == nnp_status_invalid_output_channels - msg = "NNPACK STATUS INVALID OUTPUT CHANNELS" - elseif status == nnp_status_invalid_input_size - msg = "NNPACK STATUS INVALID INPUT SIZE" - elseif status == nnp_status_invalid_input_stride - msg = "NNPACK STATUS INVALID INPUT STRIDE" - elseif status == nnp_status_invalid_input_padding - msg = "NNPACK STATUS INVALID INPUT PADDING" - elseif status == nnp_status_invalid_kernel_size - msg = "NNPACK STATUS INVALID KERNEL SIZE" - elseif status == nnp_status_invalid_pooling_size - msg = "NNPACK STATUS INVALID POOLING SIZE" - elseif status == nnp_status_invalid_pooling_stride - msg = "NNPACK STATUS INVALID POOLING STRIDE" - elseif status == nnp_status_invalid_algorithm - msg = "NNPACK STATUS INVALID ALGORITHM" - elseif status == nnp_status_invalid_transform_strategy - msg = "NNPACK STATUS INVALID TRANSFORM STRATEGY" - elseif status == nnp_status_invalid_output_subsampling - msg = "NNPACK STATUS INVALID 
OUTPUT SUBSAMPLING" - elseif status == nnp_status_invalid_activation - msg = "NNPACK STATUS INVALID ACTIVATION" - elseif status == nnp_status_invalid_activation_parameters - msg = "NNPACK STATUS INVALID ACTIVATION PARAMETERS" - elseif status == nnp_status_unsupported_input_size - msg = "NNPACK STATUS UNSUPPORTED INPUT SIZE" - elseif status == nnp_status_unsupported_input_stride - msg = "NNPACK STATUS UNSUPPORTED INPUT STRIDE" - elseif status == nnp_status_unsupported_input_padding - msg = "NNPACK STATUS UNSUPPORTED INPUT PADDING" - elseif status == nnp_status_unsupported_kernel_size - msg = "NNPACK STATUS UNSUPPORTED KERNEL SIZE" - elseif status == nnp_status_unsupported_pooling_size - msg = "NNPACK STATUS UNSUPPORTED POOLING SIZE" - elseif status == nnp_status_unsupported_pooling_stride - msg = "NNPACK STATUS UNSUPPORTED POOLING STRIDE" - elseif status == nnp_status_unsupported_algorithm - msg = "NNPACK STATUS UNSUPPORTED ALGORITHM" - elseif status == nnp_status_unsupported_transform_strategy - msg = "NNPACK STATUS UNSUPPORTED TRANSFORM STRATEGY" - elseif status == nnp_status_unsupported_activation - msg = "NNPACK STATUS UNSUPPORTED ACTIVATION" - elseif status == nnp_status_unsupported_activation_parameters - msg = "NNPACK STATUS UNSUPPORTED ACTIVATION PARAMETERS" - elseif status == nnp_status_uninitialized - msg = "NNPACK STATUS UNINITIALIZED" - elseif status == nnp_status_unsupported_hardware - msg = "NNPACK STATUS UNSUPPORTED HARDWARE" - elseif status == nnp_status_out_of_memory - msg = "NNPACK STATUS OUT OF MEMORY" - elseif status == nnp_status_insufficient_buffer - msg = "NNPACK STATUS INSUFFICIENT BUFFER" - elseif status == nnp_status_misaligned_buffer - msg = "NNPACK STATUS MISALIGNED BUFFER" - end - NNPACKError(status, msg) -end - -macro nnpack_check(nnp_func) - quote - local err::nnp_status - err = $(esc(nnp_func)) - if err != nnp_status_success - throw(NNPACKError(err)) - end - err - end -end diff --git a/src/nnpack/impl.jl b/src/nnpack/impl.jl deleted file mode 100644 index 3309404e1..000000000 --- a/src/nnpack/impl.jl +++ /dev/null @@ -1,50 +0,0 @@ -function maxpool_nnpack!(y::A, x::A, pdims::PoolDims) where {A<:Array{Float32, 4}} - check_dims(size(x), size(y), pdims) - threadpool = select_threadpool(pdims, size(y, 4)) - nnp_max_pooling_output(y, x, kernel_size(pdims), padding = padding(pdims), - stride = stride(pdims), threadpool = threadpool) -end - -function conv_nnpack!(y::A1, x::A1, w::A1, cdims::ConvDims; - b::A2 = zeros(Float32, size(x, 3)), - algo = UInt32(0)) where {A1<:Array{Float32, 4}, - A2<:Array{Float32, 1}} - check_dims(size(x), size(w), size(y), cdims) - threadpool = select_threadpool(cdims, size(y, 4)) - - if flipkernel(cdims) == 0 - w = flipweight(w) - end - - nnp_convolution_output(y, x, w, b, algo = algo, padding = padding(cdims), - stride = stride(cdims), threadpool = threadpool) -end - -function ∇conv_data_nnpack!(dx::A, dy::A, w::A, cdims::ConvDims; - algo = UInt32(0)) where{A<:Array{Float32, 4}} - check_dims(size(dx), size(w), size(dy), cdims) - threadpool = select_threadpool(cdims, size(dy, 4)) - - if flipkernel(cdims) == 0 - w = flipweight(w) - end - - nnp_convolution_input_gradient(dx, dy, w, algo = algo, padding = padding(cdims), - stride = stride(cdims), threadpool = threadpool) -end - -function ∇conv_filter_nnpack!(dw::A, x::A, dy::A, cdims::ConvDims; - algo = UInt32(0)) where{A<:Array{Float32, 4}} - check_dims(size(x), size(dw), size(dy), cdims) - threadpool = select_threadpool(cdims, size(dy, 4)) - - nnp_convolution_kernel_gradient(dw, x, dy, 
algo = algo, padding = padding(cdims), - stride = stride(cdims), threadpool = threadpool) - - if flipkernel(cdims) == 0 - dw .= flipweight(dw) - end - - dw -end - diff --git a/src/nnpack/interface.jl b/src/nnpack/interface.jl deleted file mode 100644 index 5cdaccb4d..000000000 --- a/src/nnpack/interface.jl +++ /dev/null @@ -1,44 +0,0 @@ -include("impl.jl") - -## NNPACK supports only Float32 -for (front_name, backend) in ( - :conv => :_nnpack, - :∇conv_data => :_nnpack, - :∇conv_filter => :_nnpack, - ) - @eval begin - function $(Symbol("$(front_name)$(backend)!"))( - out::Array{T1,4}, in1::Array{T2,4}, in2::Array{T3,4}, - cdims::ConvDims; kwargs...) where {T1, T2, T3} - @warn "Automatically converting input tensor to Float32. This will have performance implications" maxlog=1 - # Output must of the same type as in the function signature - T1.($(Symbol("$(front_name)$(backend)!"))(Float32.(out), Float32.(in1), - Float32.(in2), cdims; kwargs...)) - end - end -end - -function maxpool_nnpack!(y::Array{T1, 4}, x::Array{T2, 4}, pdims::PoolDims; - kwargs...) where {T1, T2} - @warn "Automatically converting input tensor to Float32. This will have performance implications" maxlog=1 - # We want the output to be of the same type as desired - T1.(maxpool_nnpack!(Float32.(y), Float32.(x), pdims; kwargs...)) -end - -""" - nnpack_supported_operation(cdims::ConvDims) - nnpack_supported_operation(pdims::PoolDims) - -Returns `true` if nnpack supports the convolution/pooling operation for the given parameters. -""" -function nnpack_supported_operation(pdims::PoolDims{2, K, S, P, (1, 1)}) where {K, S, P} - val = input_size(pdims)[1:2] .+ (P[1] + P[2], P[3] + P[4]) .- K - return val .% S == (0, 0) ? true : false -end - -function nnpack_supported_operation(cdims::ConvDims{2, K, (1, 1), P, (1, 1)}) where {K, S, P} - return true -end - -# Return false for everything else -nnpack_supported_operation(dims) = false diff --git a/src/nnpack/libnnpack.jl b/src/nnpack/libnnpack.jl deleted file mode 100644 index 2f3996c32..000000000 --- a/src/nnpack/libnnpack.jl +++ /dev/null @@ -1,135 +0,0 @@ -#NOTE: We do the error handling of nnp_initialize while loading NNPACK -function nnp_initialize() - ccall((:nnp_initialize, libnnpack), nnp_status, (),) -end - -function nnp_deinitialize() - @nnpack_check ccall((:nnp_deinitialize, libnnpack), nnp_status, (),) -end - -function pthreadpool_create(n = 0) - ccall((:pthreadpool_create, libnnpack), Ptr{Cvoid}, (Csize_t,), n) -end - -function nnp_relu_output(batch_size, channels, input, output, negative_slope, threadpool) - @nnpack_check ccall((:nnp_relu_output, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, Cfloat, pthreadpool_t), batch_size, channels, input, output, negative_slope, threadpool) -end - -function nnp_relu_output(x::Array{Float32,N}, y::Array{Float32,N}; negative_slope::AbstractFloat = 0.0, threadpool = C_NULL) where {N} - # Investigate why the channel and batch dims need to specified like this - nnp_relu_output(prod(size(x)[N-1:N]), prod(size(x)[1:N-2]), x, y, negative_slope, threadpool) - y -end - -function nnp_relu_input_gradient(batch_size, channels, grad_output, input, grad_input, negative_slope, threadpool) - @nnpack_check ccall((:nnp_relu_input_gradient, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, Cfloat, pthreadpool_t), batch_size, channels, grad_output, input, grad_input, negative_slope, threadpool) -end - -function nnp_relu_input_gradient(x::Array{Float32,N}, dy::Array{Float32,N}, dx::Array{Float32,N}; 
negative_slope::AbstractFloat = 0.0, threadpool = C_NULL) where {N} - # Investigate why the channel and batch dims need to specified like this - nnp_relu_input_gradient(Csize_t(prod(size(x)[N-1:N])), prod(size(x)[1:N-2]), dy, x, dx, negative_slope, threadpool) - dx -end - -function nnp_softmax_output(batch_size, channels, input, output, threadpool) - @nnpack_check ccall((:nnp_softmax_output, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, pthreadpool_t), batch_size, channels, input, output, threadpool) -end - -function nnp_softmax_output(x::VecOrMat{Float32}, y::VecOrMat{Float32}; threadpool = C_NULL) - nnp_softmax_output(ndims(x) == 2 ? size(x, 2) : 1, size(x, 1), x, y, threadpool) - y -end - -#FIXME: Output of fully connected not consistent with `kernel * input` -#NOTE: This most likely due to nnpack being row major. Investigate this. - -function nnp_fully_connected_output(batch_size, input_channels, output_channels, input, kernel, output, threadpool, profile) - @nnpack_check ccall((:nnp_fully_connected_output, libnnpack), nnp_status, (Csize_t, Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, pthreadpool_t, Ptr{Cvoid}), batch_size, input_channels, output_channels, input, kernel, output, threadpool, C_NULL) -end - -function nnp_fully_connected_output(x::Array{Float32,2}, w::Array{Float32,2}, y::Array{Float32,2}; profile = nothing, threadpool = C_NULL) - profile = profile == nothing ? nnp_profile() : profile - nnp_fully_connected_output(size(x, 2), size(x, 1), size(w, 1), x, w, y, threadpool, profile) - y -end - -function nnp_fully_connected_inference_f16f32(input_channels, output_channels, input, kernel, output, threadpool) - @nnpack_check ccall((:nnp_fully_connected_inference_f16f32, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cvoid}, Ptr{Cfloat}, pthreadpool_t), input_channels, output_channels, input, kernel, output, threadpool) -end - -nnp_fully_connected_inference_f16f32(x::Array{Float32, 1}, w::Array{Float16,2}, y::Array{Float32, 1}; threadpool = C_NULL) = - nnp_fully_connected_inference(reshape(x, size(x), 1), w, reshape(y, size(y), 1), threadpool = threadpool) - -function nnp_fully_connected_inference_f16f32(x::Array{Float32, 2}, w::Array{Float16,2}, y::Array{Float32, 2}; threadpool = C_NULL) - nnp_fully_connected_inference(size(x, 1), size(y, 1), x, w, y, threadpool) - y -end - -function nnp_fully_connected_inference(input_channels, output_channels, input, kernel, output, threadpool) - @nnpack_check ccall((:nnp_fully_connected_inference, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, pthreadpool_t), input_channels, output_channels, input, kernel, output, threadpool) -end - -nnp_fully_connected_inference(x::Array{Float32, 1}, w::Array{Float32,2}; threadpool = C_NULL) = - nnp_fully_connected_inference(reshape(x, size(x), 1), w, threadpool = threadpool) - -function nnp_fully_connected_inference(x::Array{Float32, 2}, w::Array{Float32, 2}, y::Array{Float32, 2}; threadpool = C_NULL) - nnp_fully_connected_inference(size(x, 1), size(y, 1), x, w, y, threadpool) - y -end - -function nnp_max_pooling_output(batch_size, channels, input_size, input_padding, pooling_size, pooling_stride, input, output, threadpool) - @nnpack_check ccall((:nnp_max_pooling_output, libnnpack), nnp_status, (Csize_t, Csize_t, nnp_size, nnp_padding, nnp_size, nnp_size, Ptr{Cfloat}, Ptr{Cfloat}, pthreadpool_t), batch_size, channels, input_size, input_padding, pooling_size, pooling_stride, input, output, threadpool) -end - -function 
nnp_max_pooling_output(y::Array{Float32,4}, x::Array{Float32,4}, kernel::Tuple; padding = 0, stride = 1, threadpool = C_NULL) - input_size = nnp_size(Csize_t.((size(x, 1), size(x, 2)))...) - pooling_size = nnp_size(Csize_t.(kernel)...) - input_padding = nnp_padding(Csize_t(padding[2]), Csize_t(padding[1]), Csize_t(padding[2]), Csize_t(padding[1])) - pooling_stride = nnp_size(Csize_t.(stride)...) - nnp_max_pooling_output(size(x, 4), size(x, 3), input_size, input_padding, pooling_size, pooling_stride, x, y, threadpool) - y -end - -#TODO: Add wrapper for convolution inference - -function nnp_convolution_input_gradient(algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, grad_output, kernel, grad_input, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, profile) - @nnpack_check ccall((:nnp_convolution_input_gradient, libnnpack), nnp_status, (nnp_convolution_algorithm, Csize_t, Csize_t, Csize_t, nnp_size, nnp_padding, nnp_size, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cvoid}, Csize_t, nnp_activation, Ptr{Cvoid}, pthreadpool_t, Ptr{Cvoid}), algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, grad_output, kernel, grad_input, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, C_NULL) -end - -function nnp_convolution_input_gradient(dx::Array{Float32,4}, dy::Array{Float32,4}, w::Array{Float32,4}; algo::nnp_convolution_algorithm = UInt32(0), workspace_buffer = nothing, workspace_size = 0, padding = 0, stride = 1, threadpool = C_NULL, profile = nothing) - input_size = nnp_size(Csize_t.((size(dx,1), size(dx,2)))...) - kernel_size = nnp_size(Csize_t.((size(w,1),size(w,2)))...) - input_padding = nnp_padding(Csize_t(padding[2]), Csize_t(padding[1]), Csize_t(padding[2]), Csize_t(padding[1])) - profile = profile == nothing ? nnp_profile() : profile - workspace_buffer = workspace_buffer === nothing ? C_NULL : workspace_buffer - nnp_convolution_input_gradient(UInt32(algo), size(dx,4), size(dx,3), size(w,4), input_size, input_padding, kernel_size, dy, w, dx, workspace_buffer, workspace_size, UInt32(0), C_NULL, threadpool, profile) - dx -end - -function nnp_convolution_kernel_gradient(algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, input, grad_output, grad_kernel, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, profile) - @nnpack_check ccall((:nnp_convolution_kernel_gradient, libnnpack), nnp_status, (nnp_convolution_algorithm, Csize_t, Csize_t, Csize_t, nnp_size, nnp_padding, nnp_size, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cvoid}, Csize_t, nnp_activation, Ptr{Cvoid}, pthreadpool_t, Ptr{Cvoid}), algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, input, grad_output, grad_kernel, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, C_NULL) -end - -function nnp_convolution_kernel_gradient(dw::Array{Float32,4}, x::Array{Float32,4}, dy::Array{Float32,4}; algo::nnp_convolution_algorithm = UInt32(0), workspace_buffer = nothing, workspace_size = 0, padding = 0, stride = 1, threadpool = C_NULL, profile = nothing) - input_size = nnp_size(Csize_t.((size(x,1), size(x,2)))...) - kernel_size = nnp_size(Csize_t.((size(dw,1),size(dw,2)))...) - input_padding = nnp_padding(Csize_t(padding[2]), Csize_t(padding[1]), Csize_t(padding[2]), Csize_t(padding[1])) - profile = profile == nothing ? 
nnp_profile() : profile - workspace_buffer = workspace_buffer === nothing ? C_NULL : workspace_buffer - nnp_convolution_kernel_gradient(UInt32(algo), size(x,4), size(x,3), size(dw,4), input_size, input_padding, kernel_size, x, dy, dw, workspace_buffer, workspace_size, UInt32(0), C_NULL, threadpool, profile) - dw -end - -function nnp_convolution_output(algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, input, kernel, bias, output, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, profile) - @nnpack_check ccall((:nnp_convolution_output, libnnpack), nnp_status, (nnp_convolution_algorithm, Csize_t, Csize_t, Csize_t, nnp_size, nnp_padding, nnp_size, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cvoid}, Csize_t, nnp_activation, Ptr{Cvoid}, pthreadpool_t, Ptr{Cvoid}), algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, input, kernel, bias, output, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, C_NULL) -end - -function nnp_convolution_output(y::Array{Float32,4}, x::Array{Float32,4}, w::Array{Float32,4}, b::Array{Float32,1}; algo::nnp_convolution_algorithm = UInt32(0), workspace_buffer = nothing, workspace_size = 0, padding = 0, stride = 1, threadpool = C_NULL, profile = nothing) - input_size = nnp_size(Csize_t.((size(x,1), size(x,2)))...) - kernel_size = nnp_size(Csize_t.((size(w,1),size(w,2)))...) - input_padding = nnp_padding(Csize_t(padding[3]), Csize_t(padding[2]), Csize_t(padding[4]), Csize_t(padding[1])) - profile = profile == nothing ? nnp_profile() : profile - workspace_buffer = workspace_buffer === nothing ? C_NULL : workspace_buffer - nnp_convolution_output(UInt32(algo), size(x,4), size(x,3), size(w,4), input_size, input_padding, kernel_size, x, w, b, y, workspace_buffer, workspace_size, UInt32(0), C_NULL, threadpool, profile) - y -end diff --git a/src/nnpack/libnnpack_types.jl b/src/nnpack/libnnpack_types.jl deleted file mode 100644 index 6e7b23c16..000000000 --- a/src/nnpack/libnnpack_types.jl +++ /dev/null @@ -1,85 +0,0 @@ -const nnp_status = UInt32 -const nnp_status_success = (UInt32)(0) -const nnp_status_invalid_batch_size = (UInt32)(2) -const nnp_status_invalid_channels = (UInt32)(3) -const nnp_status_invalid_input_channels = (UInt32)(4) -const nnp_status_invalid_output_channels = (UInt32)(5) -const nnp_status_invalid_input_size = (UInt32)(10) -const nnp_status_invalid_input_stride = (UInt32)(11) -const nnp_status_invalid_input_padding = (UInt32)(12) -const nnp_status_invalid_kernel_size = (UInt32)(13) -const nnp_status_invalid_pooling_size = (UInt32)(14) -const nnp_status_invalid_pooling_stride = (UInt32)(15) -const nnp_status_invalid_algorithm = (UInt32)(16) -const nnp_status_invalid_transform_strategy = (UInt32)(17) -const nnp_status_invalid_output_subsampling = (UInt32)(13) -const nnp_status_invalid_activation = (UInt32)(14) -const nnp_status_invalid_activation_parameters = (UInt32)(15) -const nnp_status_unsupported_input_size = (UInt32)(20) -const nnp_status_unsupported_input_stride = (UInt32)(21) -const nnp_status_unsupported_input_padding = (UInt32)(22) -const nnp_status_unsupported_kernel_size = (UInt32)(23) -const nnp_status_unsupported_pooling_size = (UInt32)(24) -const nnp_status_unsupported_pooling_stride = (UInt32)(25) -const nnp_status_unsupported_algorithm = (UInt32)(26) -const nnp_status_unsupported_transform_strategy = (UInt32)(57) -const nnp_status_unsupported_activation = (UInt32)(28) -const 
nnp_status_unsupported_activation_parameters = (UInt32)(29) -const nnp_status_uninitialized = (UInt32)(50) -const nnp_status_unsupported_hardware = (UInt32)(51) -const nnp_status_out_of_memory = (UInt32)(52) -const nnp_status_insufficient_buffer = (UInt32)(53) -const nnp_status_misaligned_buffer = (UInt32)(54) - -const nnp_activation = UInt32 -const nnp_activation_identity = (UInt32)(0) -const nnp_activation_relu = (UInt32)(1) - -const nnp_convolution_algorithm = UInt32 -const nnp_convolution_algorithm_auto = (UInt32)(0) -const nnp_convolution_algorithm_ft8x8 = (UInt32)(1) -const nnp_convolution_algorithm_ft16x16 = (UInt32)(2) -const nnp_convolution_algorithm_wt8x8 = (UInt32)(3) -const nnp_convolution_algorithm_implicit_gemm = (UInt32)(4) -const nnp_convolution_algorithm_direct = (UInt32)(5) -const nnp_convolution_algorithm_wt8x8_fp16 = (UInt32)(6) - -const nnp_convolution_transform_strategy = UInt32 -const nnp_convolution_transform_strategy_compute = (UInt32)(1) -const nnp_convolution_transform_strategy_precompute = (UInt32)(2) -const nnp_convolution_transform_strategy_reuse = (UInt32)(3) - -const pthreadpool_t = Ptr{Nothing} - -mutable struct nnp_size - width::Csize_t - height::Csize_t - nnp_size() = new(Csize_t(0), Csize_t(0)) - nnp_size(w, h) = new(Csize_t(w), Csize_t(h)) -end - -Base.unsafe_convert(::Type{Ptr{nnp_size}}, a::nnp_size) = Ptr{a} - -mutable struct nnp_padding - top::Csize_t - right::Csize_t - bottom::Csize_t - left::Csize_t - nnp_padding() = new(Csize_t(0), Csize_t(0), Csize_t(0), Csize_t(0)) - nnp_padding(val) = new(Csize_t(val), Csize_t(val), Csize_t(val), Csize_t(val)) - nnp_padding(t, r, b, l) = new(Csize_t(t), Csize_t(r), Csize_t(b), Csize_t(l)) -end - -Base.unsafe_convert(::Type{Ptr{nnp_padding}}, a::nnp_padding) = Ptr{a} - -mutable struct nnp_profile - total::Cdouble - input_transform::Cdouble - kernel_transform::Cdouble - output_transform::Cdouble - block_multiplication::Cdouble - nnp_profile() = new(Cdouble(0.0), Cdouble(0.0), Cdouble(0.0), Cdouble(0.0), Cdouble(0.0)) - nnp_profile(t, it, kt, ot, bm) = new(Cdouble(t), Cdouble(it), Cdouble(kt), Cdouble(ot), Cdouble(bm)) -end - -Base.unsafe_convert(::Type{Ptr{nnp_profile}}, a::nnp_profile) = Ptr{a} diff --git a/src/nnpack/performance.jl b/src/nnpack/performance.jl deleted file mode 100644 index 24abdb411..000000000 --- a/src/nnpack/performance.jl +++ /dev/null @@ -1,31 +0,0 @@ -function select_threadpool(cdims::DenseConvDims, batch_size::Int) - inp_size = input_size(cdims)[1] - if batch_size >= 32 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif batch_size >= 16 && inp_size >= 64 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif inp_size <= 32 - return C_NULL - elseif inp_size >= 128 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif inp_size * batch_size >= 256 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - end - return C_NULL -end - -function select_threadpool(pdims::PoolDims, batch_size::Int) - inp_size = input_size(pdims)[1] - if batch_size >= 32 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif batch_size >= 16 && inp_size >= 64 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif inp_size <= 32 - return C_NULL - elseif inp_size >= 128 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif inp_size * batch_size >= 256 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - end - return C_NULL -end diff --git a/src/padding.jl b/src/padding.jl index 1e1ec0b37..76aa1ddd8 100644 --- 
a/src/padding.jl +++ b/src/padding.jl @@ -111,6 +111,7 @@ end gen_pad(pad::Int, dims, N) = gen_pad(ntuple(_ -> pad, length(dims)), dims, N) gen_pad(pad::Int, dims::Colon, N) = ntuple(_ -> (pad, pad), N) gen_pad(pad, dims::Colon, N) = gen_pad(pad, ntuple(identity, N), N) +gen_pad(pad, dims::Int, N) = gen_pad(pad, (dims,), N) gen_pad(pad::Int, dims::Int, N) = gen_pad((pad,pad), (dims,), N) function gen_pad(pad::NTuple{L,Int}, dims::NTuple{D,Int}, N) where {L,D} ntuple(N) do d @@ -254,27 +255,28 @@ julia> pad_reflect(r, (1,2,1,2)) 4 1 4 7 4 1 ``` """ -function pad_reflect(x::AbstractArray, pad::NTuple{M,Int}; +function pad_reflect(x::AbstractArray, pad::NTuple{M,Int}; dims=1:M÷2) where M length(dims) == M ÷ 2 || throw(ArgumentError("The number of dims should be equal to the number of padding dimensions")) for (i, d) in enumerate(dims) x = pad_reflect(x, (pad[2i-1], pad[2i]); dims = d) - end + end return x end -function pad_reflect(x::AbstractArray{F,N}, pad::NTuple{2,Int}; - dims::Int = 1) where {F,N} +function pad_reflect( + x::AbstractArray{F,N}, pad::NTuple{2,Int}; dims::Int = 1, +) where {F,N} lpad, rpad = pad - n = size(x, dims) - xl = selectdim(x, dims, lpad+1:-1:2) - xr = selectdim(x, dims, n-1:-1:n-rpad) - # Alternative selection, not sure which is faster... - # xl = reverse(selectdim(x, dims, 2:lpad+1), dims) - # xr = reverse(selectdim(x, dims, n-rpad:n-1), dims) - return cat(xl, x, xr, dims = dims) + xl = lpad == 0 ? + similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) : + reverse(selectdim(x, dims, 2:lpad+1); dims) + xr = rpad == 0 ? + similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) : + reverse(selectdim(x, dims, n-rpad:n-1); dims) + return cat(xl, x, xr; dims) end """ @@ -312,24 +314,29 @@ julia> pad_symmetric(r, (1,2,1,2)) 2 2 5 8 8 5 ``` """ -function pad_symmetric(x::AbstractArray, pad::NTuple{M,Int}; +function pad_symmetric(x::AbstractArray, pad::NTuple{M,Int}; dims=1:M÷2) where M length(dims) == M ÷ 2 || throw(ArgumentError("The number of dims should be equal to the number of padding dimensions")) for (i, d) in enumerate(dims) x = pad_symmetric(x, (pad[2i-1], pad[2i]); dims = d) - end + end return x end -function pad_symmetric(x::AbstractArray{F,N}, pad::NTuple{2,Int}; - dims::Int = 1) where {F,N} +function pad_symmetric( + x::AbstractArray{F,N}, pad::NTuple{2,Int}; dims::Int = 1, +) where {F,N} lpad, rpad = pad - n = size(x, dims) - xl = selectdim(x, dims, lpad:-1:1) - xr = selectdim(x, dims, n:-1:n-rpad+1) - return cat(xl, x, xr, dims = dims) + + xl = lpad == 0 ? + similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) : + reverse(selectdim(x, dims, 1:lpad); dims) + xr = rpad == 0 ? + similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) : + reverse(selectdim(x, dims, n-rpad+1:n); dims) + return cat(xl, x, xr; dims) end """ diff --git a/src/pooling.jl b/src/pooling.jl index 1cf666f54..59db9b465 100644 --- a/src/pooling.jl +++ b/src/pooling.jl @@ -107,7 +107,7 @@ end # Finally, let's generate auto-allocating versions of all our functions, for all backends: -for backend in (Symbol(), :_direct, :_nnpack) +for backend in (Symbol(), :_direct) # First make auto-allocating versions of the basic pooling calls: for name in (:maxpool, :meanpool, :lpnormpool) @eval begin @@ -132,16 +132,6 @@ for backend in (Symbol(), :_direct, :_nnpack) end end -## Use NNPACK if it is available and operation is supported. 
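The padding changes above add a `gen_pad` method for an integer `dims` with a tuple `pad`, and rebuild the reflect/symmetric flanks with `reverse(selectdim(...))`, falling back to an empty slice when one side has zero padding. A minimal usage sketch (array shapes chosen here purely for illustration):

```julia
using NNlib

x = reshape(collect(1.0:9.0), 3, 3)

# Tuple padding with a single integer `dims` now dispatches through the new
# `gen_pad(pad, dims::Int, N)` method: pad only dim 1, by 1 before and 2 after.
pad_constant(x, (1, 2); dims = 1)

# One-sided reflect/symmetric padding goes through the zero-width branch above,
# which concatenates an empty slice instead of reversing an empty view.
pad_reflect(x, (0, 2); dims = 1)
pad_symmetric(x, (2, 0); dims = 2)
```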
-## The corresponding gradient is not available in NNPACK -## Commented out due to #210 -# if is_nnpack_available() -# function maxpool(x::Array{Float32, 4}, pdims::PoolDims{2, K, S, P, (1, 1)}; kwargs...) where {T, K, S, P} -# func = nnpack_supported_operation(pdims) ? maxpool_nnpack : maxpool_direct -# return func(x, pdims; kwargs...) -# end -# end - expand(N, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) @@ -172,7 +162,7 @@ Perform mean pool operation with window size `k` on input tensor `x`. Arguments: -* `x` and `k`: Expects `ndim(x) ∈ 3:5``, and always `length(k) == ndim(x) - 2` +* `x` and `k`: Expects `ndim(x) ∈ 3:5`, and always `length(k) == ndim(x) - 2` * `pad`: See [`pad_zeros`](@ref) for details. * `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`. """ @@ -192,7 +182,7 @@ This pooling operator from [Learned-Norm Pooling for Deep Feedforward and Recurr Arguments: -* `x` and `k`: Expects `ndim(x) ∈ 3:5``, and always `length(k) == ndim(x) - 2` +* `x` and `k`: Expects `ndim(x) ∈ 3:5`, and always `length(k) == ndim(x) - 2` * `p` is restricted to `0 < p < Inf`. * `pad`: See [`pad_zeros`](@ref) for details. * `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`. diff --git a/src/rotation.jl b/src/rotation.jl new file mode 100644 index 000000000..0452d62c0 --- /dev/null +++ b/src/rotation.jl @@ -0,0 +1,291 @@ +""" + _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, round_or_floor) + +This rotates the coordinates and either applies round(nearest neighbour) +or floor for :bilinear interpolation) +""" +@inline function _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, round_or_floor) + y = i - rotation_center[1] + x = j - rotation_center[2] + yrot = cosθ * y - sinθ * x + rotation_center[1] + xrot = sinθ * y + cosθ * x + rotation_center[2] + yrot_f = round_or_floor(yrot) + xrot_f = round_or_floor(xrot) + yrot_int = round_or_floor(Int, yrot) + xrot_int = round_or_floor(Int, xrot) + return yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int +end + + +""" + _bilinear_helper(yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int) + +Some helper variables +""" +@inline function _bilinear_helper(yrot, xrot, yrot_f, xrot_f) + xdiff = (xrot - xrot_f) + xdiff_1minus = 1 - xdiff + ydiff = (yrot - yrot_f) + ydiff_1minus = 1 - ydiff + + return ydiff, ydiff_1minus, xdiff, xdiff_1minus +end + + +""" + _prepare_imrotate(arr, θ, rotation_center) + +Prepate `sin` and `cos`, creates the output array and converts type +of `rotation_center` if required. +""" +function _prepare_imrotate(arr::AbstractArray{T}, θ, rotation_center) where T + # needed for rotation matrix + θ = mod(real(T)(θ), real(T)(2π)) + rotation_center = real(T).(rotation_center) + sinθ, cosθ = sincos(real(T)(θ)) + out = similar(arr) + fill!(out, 0) + return sinθ, cosθ, rotation_center, out +end + + +""" + _check_trivial_rotations!(out, arr, θ, rotation_center) + +When `θ = 0 || π /2 || π || 3/2 || π` and if `rotation_center` +is in the middle of the array. +For an even array of size 4, the rotation_center would need to be 2.5. +For an odd array of size 5, the rotation_center would need to be 3. + +In those cases, rotations are trivial just by reversing or swapping some axes. 
+""" +function _check_trivial_rotations!(out, arr, θ, rotation_center; adjoint=false) + if iszero(θ) + out .= arr + return true + end + # check for special cases where rotations are trivial + if (iseven(size(arr, 1)) && iseven(size(arr, 2)) && + rotation_center[1] ≈ size(arr, 1) ÷ 2 + 0.5 && rotation_center[2] ≈ size(arr, 2) ÷ 2 + 0.5) || + (isodd(size(arr, 1)) && isodd(size(arr, 2)) && + (rotation_center[1] == size(arr, 1) ÷ 2 + 1 && rotation_center[1] == size(arr, 2) ÷ 2 + 1)) + if θ ≈ π / 2 + if adjoint == false + out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(2,)) + else + out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(1,)) + end + return true + elseif θ ≈ π + out .= reverse(arr, dims=(1,2)) + return true + elseif θ ≈ 3 / 2 * π + if adjoint == false + out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(1,)) + else + out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(2,)) + end + return true + end + end + + return false +end + + +""" + imrotate(arr::AbstractArray{T, 4}, θ; method=:bilinear, rotation_center=size(arr) .÷ 2 .+ 1) + +Rotates an array in the first two dimensions around the center pixel `rotation_center`. +The default value of `rotation_center` is defined such that there is a integer center pixel for even and odd sized arrays which it is rotated around. +For an even sized array of size `(4,4)` this would be `(3,3)`, for an odd array of size `(3,3)` this would be `(2,2)` +However, `rotation_center` can be also non-integer numbers if specified. + +The angle `θ` is interpreted in radians. + +The adjoint is defined with ChainRulesCore.jl. This method also runs with CUDA (and in principle all KernelAbstractions.jl supported backends). + +# Keywords +* `method=:bilinear` for bilinear interpolation or `method=:nearest` for nearest neighbour +* `rotation_center=size(arr) .÷ 2 .+ 1` means there is a real center pixel around it is rotated. 
+ +# Examples +```julia-repl +julia> arr = zeros((4,4,1,1)); arr[2,2,1,1] = 1; + +julia> arr +4×4×1×1 Array{Float64, 4}: +[:, :, 1, 1] = + 0.0 0.0 0.0 0.0 + 0.0 1.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + +julia> NNlib.imrotate(arr, deg2rad(90)) # rotation around (3,3) +4×4×1×1 Array{Float64, 4}: +[:, :, 1, 1] = + 0.0 0.0 0.0 0.0 + 0.0 0.0 0.0 1.0 + 0.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + +julia> NNlib.imrotate(arr, deg2rad(90), rotation_center=(2,2)) +4×4×1×1 Array{Float64, 4}: +[:, :, 1, 1] = + 0.0 0.0 0.0 0.0 + 0.0 1.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + +julia> arr = zeros((3,3,1,1)); arr[1,2,1,1] = 1 +1 + +julia> arr +3×3×1×1 Array{Float64, 4}: +[:, :, 1, 1] = + 0.0 1.0 0.0 + 0.0 0.0 0.0 + 0.0 0.0 0.0 + +julia> NNlib.imrotate(arr, deg2rad(45)) +3×3×1×1 Array{Float64, 4}: +[:, :, 1, 1] = + 0.0 0.207107 0.0 + 0.0 0.0 0.207107 + 0.0 0.0 0.0 + +julia> NNlib.imrotate(arr, deg2rad(45), method=:nearest) +3×3×1×1 Array{Float64, 4}: +[:, :, 1, 1] = + 0.0 0.0 1.0 + 0.0 0.0 0.0 + 0.0 0.0 0.0 +``` +""" +function imrotate(arr::AbstractArray{T, 4}, θ; method=:bilinear, rotation_center::Tuple=size(arr) .÷ 2 .+ 1) where T + if (T <: Integer && method==:nearest || !(T <: Integer)) == false + throw(ArgumentError("If the array has an Int eltype, only method=:nearest is supported")) + end + # prepare out, the sin and cos and type of rotation_center + sinθ, cosθ, rotation_center, out = _prepare_imrotate(arr, θ, rotation_center) + # such as 0°, 90°, 180°, 270° and only if the rotation_center is suitable + _check_trivial_rotations!(out, arr, θ, rotation_center) && return out + + # KernelAbstractions specific + backend = KernelAbstractions.get_backend(arr) + if method == :bilinear + kernel! = imrotate_kernel_bilinear!(backend) + elseif method == :nearest + kernel! = imrotate_kernel_nearest!(backend) + else + throw(ArgumentError("No interpolation method such as $method")) + end + kernel!(out, arr, sinθ, cosθ, rotation_center, size(arr, 1), size(arr, 2), + ndrange=size(arr)) + return out +end + + +""" + ∇imrotate(dy, arr::AbstractArray{T, 4}, θ; method=:bilinear, + rotation_center=size(arr) .÷ 2 .+ 1) + +Adjoint for `imrotate`. Gradient only with respect to `arr` and not `θ`. + +# Arguments +* `dy`: input gradient +* `arr`: Input from primal computation +* `θ`: rotation angle in radians +* `method=:bilinear` or `method=:nearest` +* `rotation_center=size(arr) .÷ 2 .+ 1` rotates around a real center pixel for even and odd sized arrays +""" +function ∇imrotate(dy, arr::AbstractArray{T, 4}, θ; method=:bilinear, + rotation_center::Tuple=size(arr) .÷ 2 .+ 1) where T + + sinθ, cosθ, rotation_center, out = _prepare_imrotate(arr, θ, rotation_center) + # for the adjoint, the trivial rotations go in the other direction! + # pass dy and not arr + _check_trivial_rotations!(out, dy, θ, rotation_center, adjoint=true) && return out + + backend = KernelAbstractions.get_backend(arr) + if method == :bilinear + kernel! = ∇imrotate_kernel_bilinear!(backend) + elseif method == :nearest + kernel! = ∇imrotate_kernel_nearest!(backend) + else + throw(ArgumentError("No interpolation method such as $method")) + end + # don't pass arr but dy! + kernel!(out, dy, sinθ, cosθ, rotation_center, size(arr, 1), size(arr, 2), + ndrange=size(arr)) + return out +end + + +@kernel function imrotate_kernel_nearest!(out, arr, sinθ, cosθ, rotation_center, imax, jmax) + i, j, c, b = @index(Global, NTuple) + + r(x...) 
= round(x..., RoundNearestTiesAway) + _, _, _, _, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, r) + if 1 ≤ yrot_int ≤ imax && 1 ≤ xrot_int ≤ jmax + @inbounds out[i, j, c, b] = arr[yrot_int, xrot_int, c, b] + end +end + + +@kernel function imrotate_kernel_bilinear!(out, arr, sinθ, cosθ, rotation_center, imax, jmax) + i, j, c, b = @index(Global, NTuple) + + yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, floor) + if 1 ≤ yrot_int ≤ imax - 1 && 1 ≤ xrot_int ≤ jmax - 1 + + ydiff, ydiff_1minus, xdiff, xdiff_1minus = + _bilinear_helper(yrot, xrot, yrot_f, xrot_f) + @inbounds out[i, j, c, b] = + ( xdiff_1minus * ydiff_1minus * arr[yrot_int , xrot_int , c, b] + + xdiff_1minus * ydiff * arr[yrot_int + 1 , xrot_int , c, b] + + xdiff * ydiff_1minus * arr[yrot_int , xrot_int + 1 , c, b] + + xdiff * ydiff * arr[yrot_int + 1 , xrot_int + 1 , c, b]) + end +end + + +@kernel function ∇imrotate_kernel_nearest!(out, arr, sinθ, cosθ, rotation_center, imax, jmax) + i, j, c, b = @index(Global, NTuple) + + r(x...) = round(x..., RoundNearestTiesAway) + _, _, _, _, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, r) + if 1 ≤ yrot_int ≤ imax && 1 ≤ xrot_int ≤ jmax + Atomix.@atomic out[yrot_int, xrot_int, c, b] += arr[i, j, c, b] + end +end + + +@kernel function ∇imrotate_kernel_bilinear!(out, arr, sinθ, cosθ, rotation_center, imax, jmax) + i, j, c, b = @index(Global, NTuple) + + yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, floor) + if 1 ≤ yrot_int ≤ imax - 1 && 1 ≤ xrot_int ≤ jmax - 1 + o = arr[i, j, c, b] + ydiff, ydiff_1minus, xdiff, xdiff_1minus = + _bilinear_helper(yrot, xrot, yrot_f, xrot_f) + Atomix.@atomic out[yrot_int , xrot_int , c, b] += xdiff_1minus * ydiff_1minus * o + Atomix.@atomic out[yrot_int + 1 , xrot_int , c, b] += xdiff_1minus * ydiff * o + Atomix.@atomic out[yrot_int , xrot_int + 1, c, b] += xdiff * ydiff_1minus * o + Atomix.@atomic out[yrot_int + 1 , xrot_int + 1, c, b] += xdiff * ydiff * o + end +end + + +# is this rrule good? +# no @thunk and @unthunk +function ChainRulesCore.rrule(::typeof(imrotate), arr::AbstractArray{T}, θ; + method=:bilinear, rotation_center=size(arr) .÷ 2 .+ 1) where T + res = imrotate(arr, θ; method, rotation_center) + function pb_rotate(dy) + ad = ∇imrotate(unthunk(dy), arr, θ; method, rotation_center) + return NoTangent(), ad, NoTangent() + end + + return res, pb_rotate +end diff --git a/src/sampling.jl b/src/sampling.jl index f3de51660..07b326812 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -1,7 +1,8 @@ @inline in_bounds(h, w, H, W) = 1 ≤ h ≤ H && 1 ≤ w ≤ W +@inline in_bounds(h, w, d, H, W, D) = 1 ≤ h ≤ H && 1 ≤ w ≤ W && 1 ≤ d ≤ D # Borders are considered out-of-bounds for gradient. @inline clip_coordinate(coordinate, dim_size) = min(dim_size, max(1, coordinate)) -@inline function ∇clip_coordinate(coordinate::C, dim_size) where C +@inline function ∇clip_coordinate(coordinate::C, dim_size) where {C} if coordinate ≤ 1 return C(1), C(0) elseif coordinate ≥ dim_size @@ -25,83 +26,88 @@ end """ grid_sample(input::AbstractArray{T, 4}, grid::AbstractArray{T, 4}; padding_mode = :zeros) + grid_sample(input::AbstractArray{T, 5}, grid::AbstractArray{T, 4}; padding_mode = :zeros) -Given `input`, compute output by sampling `input` values at pixel -locations from `grid`. Uses bilinear interpolation to calculate output values. 
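The rotation `rrule` added above in src/rotation.jl only differentiates with respect to the array; the angle gets `NoTangent()`. A short sketch of how that plugs into reverse-mode AD — Zygote is an assumption here (it is a test dependency, not something NNlib itself requires):

```julia
using NNlib, Zygote   # Zygote assumed available, e.g. from the test environment

arr = zeros(Float32, 8, 8, 1, 1)
arr[3, 4, 1, 1] = 1f0
θ = deg2rad(30)

y, back = Zygote.pullback(a -> NNlib.imrotate(a, θ; method = :bilinear), arr)
(darr,) = back(ones(Float32, size(y)))   # gradient w.r.t. the array; θ is non-differentiable here
size(darr) == size(arr)                  # true
```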
+ Given `input`, compute output by sampling `input` values at pixel + locations from `grid`. Uses bilinear interpolation to calculate output values. -This implementation assumes the extrema (`-1` and `1`) are considered -as referring to the center points of the input’s corner pixels -(i.e. align corners is `true`). + This implementation assumes the extrema (`-1` and `1`) are considered + as referring to the center points of the input’s corner pixels + (i.e. align corners is `true`). -# Arguments + # Arguments -- `input`: Input array in `(W_in, H_in, C, N)` shape. -- `grid`: Input grid in `(2, W_out, H_out, N)` shape. - Where for each `(W_out, H_out, N)` grid contains `(x, y)` - coordinates that specify sampling locations normalized by the `input` shape. + - `input`: Input array in `(W_in, H_in, [D_in,] C, N)` shape. + - `grid`: Input grid in `(2, W_out, H_out, [D_out,] N)` shape. + Where for each `(W_out, H_out, [D_out,] N)` grid contains `(x, y [,z])` + coordinates that specify sampling locations normalized by the `input` shape. - Therefore, `x` and `y` should have values in `[-1, 1]` range. - For example, `(x = -1, y = -1)` is the left-top pixel of `input`, - and `(x = 1, y = 1)` is the right-bottom pixel of `input`. + Therefore, `x`, `y` and [`z`] should have values in `[-1, 1]` range. + For example, `(x = -1, y = -1, [z = -1])` is the left-top[-front] pixel of `input`, + and `(x = 1, y = 1, [z = 1])` is the right-bottom-back pixel of `input`. - Out-of-bound values are handled according to the `padding_mode`. -- `padding_mode`: Out-of-bound padding. - `:zeros` to use `0` for out-of-bound grid locations. - `:border` to use border values for out-of-bound grid locations. - Default is `:zeros`. + Out-of-bound values are handled according to the `padding_mode`. + - `padding_mode`: Out-of-bound padding. + `:zeros` to use `0` for out-of-bound grid locations. + `:border` to use border values for out-of-bound grid locations. + Default is `:zeros`. -# Returns + # Returns -`(W_out, H_out, C, N)` sampled grid from `input`. + `(W_out, H_out, [D_out,] C, N)` sampled grid from `input`. -# Examples + # Examples -In the example below, grid contains two out-of-bound sampling locations, -which are handled differently, depending on the `padding_mode`. + In the example below, grid contains two out-of-bound sampling locations, + which are handled differently, depending on the `padding_mode`. 
-```jldoctest -julia> x = reshape(collect(1.0:4.0), (2, 2, 1, 1)) -2×2×1×1 Array{Float64, 4}: -[:, :, 1, 1] = - 1.0 3.0 - 2.0 4.0 + ```jldoctest + julia> x = reshape(collect(1.0:4.0), (2, 2, 1, 1)) + 2×2×1×1 Array{Float64, 4}: + [:, :, 1, 1] = + 1.0 3.0 + 2.0 4.0 -julia> grid = Array{Float64}(undef, 2, 3, 2, 1); + julia> grid = Array{Float64}(undef, 2, 3, 2, 1); -julia> grid[:, 1, 1, 1] .= (-3, -1); + julia> grid[:, 1, 1, 1] .= (-3, -1); -julia> grid[:, 2, 1, 1] .= (0, -1); + julia> grid[:, 2, 1, 1] .= (0, -1); -julia> grid[:, 3, 1, 1] .= (1, -1); + julia> grid[:, 3, 1, 1] .= (1, -1); -julia> grid[:, 1, 2, 1] .= (-1, 1); + julia> grid[:, 1, 2, 1] .= (-1, 1); -julia> grid[:, 2, 2, 1] .= (0, 1); + julia> grid[:, 2, 2, 1] .= (0, 1); -julia> grid[:, 3, 2, 1] .= (3, 1); + julia> grid[:, 3, 2, 1] .= (3, 1); -julia> grid_sample(x, grid; padding_mode=:zeros) -3×2×1×1 Array{Float64, 4}: -[:, :, 1, 1] = - 0.0 3.0 - 1.5 3.5 - 2.0 0.0 + julia> grid_sample(x, grid; padding_mode=:zeros) + 3×2×1×1 Array{Float64, 4}: + [:, :, 1, 1] = + 0.0 3.0 + 1.5 3.5 + 2.0 0.0 -julia> grid_sample(x, grid; padding_mode=:border) -3×2×1×1 Array{Float64, 4}: -[:, :, 1, 1] = - 1.0 3.0 - 1.5 3.5 - 2.0 4.0 -``` + julia> grid_sample(x, grid; padding_mode=:border) + 3×2×1×1 Array{Float64, 4}: + [:, :, 1, 1] = + 1.0 3.0 + 1.5 3.5 + 2.0 4.0 + ``` """ -function grid_sample(input::AbstractArray{T, 4}, grid; padding_mode = :zeros) where T - _, _, iC, iN = size(input) - _, gW, gH, _ = size(grid) - output = similar(input, T, (gW, gH, iC, iN)) +function grid_sample(input::AbstractArray{T,N}, grid; padding_mode = :zeros) where {T,N} + if N ∉ (4,5) + error("grid_sample is only supported for 4D and 5D arrays.") + end + iC, iN = size(input)[end-1:end] + output_size = size(grid)[2:end-1] # W_out, H_out, [D_out] + output = similar(input, T, (output_size..., iC, iN)) grid_sample!(output, input, grid, padding_mode) end -function grid_sample!(output, input, grid, padding_mode) + +function grid_sample!(output::AbstractArray{T,4}, input::AbstractArray{T,4}, grid, padding_mode=:zeros) where {T} pad = Val(padding_mode) iW, iH, iC, iN = size(input) _, gW, gH, _ = size(grid) @@ -113,15 +119,29 @@ function grid_sample!(output, input, grid, padding_mode) end output end + +function grid_sample!(output::AbstractArray{T,5}, input::AbstractArray{T,5}, grid, padding_mode=:zeros) where {T} + pad = Val(padding_mode) + iW, iH, iD, iC, iN = size(input) + _, gW, gH, gD, _ = size(grid) + # Loop over each output pixel. + Threads.@threads for n in 1:iN + for w in 1:gW, h in 1:gH, d in 1:gD + _grid_sample_kernel!(output, input, grid, pad, w, h, d, n, iW, iH, iD, iC) + end + end + output +end + @inline function _grid_sample_kernel!( - output, input, grid, padding_mode, w, h, n, iW, iH, iC, -) + output::AbstractArray{T,4}, input::AbstractArray{T,4}, grid, padding_mode, w, h, n, iW, iH, iC, +) where {T} # Get the corresponding (x, y) coordinates from the grid. @inbounds x, y = grid[1, w, h, n], grid[2, w, h, n] ix = compute_source_index(x, iW, padding_mode) iy = compute_source_index(y, iH, padding_mode) # Get corner pixel values from (ix, iy) in north-east-south-west directions. - ix_nw, iy_nw = floor(Int, ix), floor(Int, iy) + ix_nw, iy_nw = unsafe_trunc(Int, floor(ix)), unsafe_trunc(Int, floor(iy)) ix_ne, iy_ne = ix_nw + 1, iy_nw ix_sw, iy_sw = ix_nw, iy_nw + 1 ix_se, iy_se = ix_ne, iy_sw @@ -132,7 +152,7 @@ end se = (ix - ix_nw) * (iy - iy_nw) # ∀ channel: Calculate bilinear weighted pixel value. 
@inbounds for c in 1:iC - r = 0.0 + r = zero(T) if in_bounds(iy_nw, ix_nw, iH, iW) r += input[ix_nw, iy_nw, c, n] * nw end @@ -149,6 +169,67 @@ end end end +@inline function _grid_sample_kernel!( + output::AbstractArray{T,5}, input::AbstractArray{T,5}, grid, padding_mode, w, h, d, n, iW, iH, iD, iC, +) where {T} + # Get the corresponding (x, y, z) coordinates from the grid. + @inbounds x, y, z = grid[1, w, h, d, n], grid[2, w, h, d, n], grid[3, w, h, d, n] + ix = compute_source_index(x, iW, padding_mode) + iy = compute_source_index(y, iH, padding_mode) + iz = compute_source_index(z, iD, padding_mode) + + # Get corner voxel values from (ix, iy, iz) in 8 directions (north-east-south-west-bottom-up). + ix_nw, iy_nw, iz_nw = unsafe_trunc(Int, floor(ix)), unsafe_trunc(Int, floor(iy)), unsafe_trunc(Int, floor(iz)) + ix_ne, iy_ne, iz_ne = ix_nw + 1, iy_nw, iz_nw + ix_sw, iy_sw, iz_sw = ix_nw, iy_nw + 1, iz_nw + ix_se, iy_se, iz_se = ix_ne, iy_sw, iz_nw + ix_nw_u, iy_nw_u, iz_nw_u = ix_nw, iy_nw, iz_nw + 1 + ix_ne_u, iy_ne_u, iz_ne_u = ix_ne, iy_ne, iz_ne + 1 + ix_sw_u, iy_sw_u, iz_sw_u = ix_sw, iy_sw, iz_sw + 1 + ix_se_u, iy_se_u, iz_se_u = ix_se, iy_se, iz_se + 1 + + # Get volumes to each neighbor (a.k.a. interpolation weights). + nw = (ix_se - ix) * (iy_se - iy) * (iz_se_u - iz) + ne = (ix - ix_sw) * (iy_sw - iy) * (iz_sw_u - iz) + sw = (ix_ne - ix) * (iy - iy_ne) * (iz_ne_u - iz) + se = (ix - ix_nw) * (iy - iy_nw) * (iz_nw_u - iz) + nw_u = (ix_se - ix) * (iy_se - iy) * (iz - iz_nw) + ne_u = (ix - ix_sw) * (iy_sw - iy) * (iz - iz_sw) + sw_u = (ix_ne - ix) * (iy - iy_ne) * (iz - iz_ne) + se_u = (ix - ix_nw) * (iy - iy_nw) * (iz - iz_nw) + + # ∀ channel: Calculate trilinear weighted voxel value. + @inbounds for c in 1:iC + r = zero(T) + if in_bounds(iy_nw, ix_nw, iz_nw, iH, iW, iD) + r += input[ix_nw, iy_nw, iz_nw, c, n] * nw + end + if in_bounds(iy_ne, ix_ne, iz_ne, iH, iW, iD) + r += input[ix_ne, iy_ne, iz_ne, c, n] * ne + end + if in_bounds(iy_sw, ix_sw, iz_sw, iH, iW, iD) + r += input[ix_sw, iy_sw, iz_sw, c, n] * sw + end + if in_bounds(iy_se, ix_se, iz_se, iH, iW, iD) + r += input[ix_se, iy_se, iz_se, c, n] * se + end + if in_bounds(iy_nw_u, ix_nw_u, iz_nw_u, iH, iW, iD) + r += input[ix_nw_u, iy_nw_u, iz_nw_u, c, n] * nw_u + end + if in_bounds(iy_ne_u, ix_ne_u, iz_ne_u, iH, iW, iD) + r += input[ix_ne_u, iy_ne_u, iz_ne_u, c, n] * ne_u + end + if in_bounds(iy_sw_u, ix_sw_u, iz_sw_u, iH, iW, iD) + r += input[ix_sw_u, iy_sw_u, iz_sw_u, c, n] * sw_u + end + if in_bounds(iy_se_u, ix_se_u, iz_se_u, iH, iW, iD) + r += input[ix_se_u, iy_se_u, iz_se_u, c, n] * se_u + end + output[w, h, d, c, n] = r + end +end + + """ ∇grid_sample(Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{T, 4}; padding_mode = :zeros) where T @@ -168,12 +249,16 @@ end `dinput` (same shape as `input`) and `dgrid` (same shape as `grid`) gradients. 
""" -function ∇grid_sample(Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid; padding_mode = :zeros) where T +function ∇grid_sample(Δ::AbstractArray{T,N}, input::AbstractArray{T,N}, grid; padding_mode=:zeros) where {T, N} + if N ∉ (4,5) + error("∇grid_sample is only supported for 4D and 5D arrays.") + end dx = zeros(T, size(input)) dgrid = similar(grid) ∇grid_sample!(dx, dgrid, Δ, input, grid, padding_mode) end -function ∇grid_sample!(dx, dgrid, Δ, input, grid, padding_mode) + +function ∇grid_sample!(dx::AbstractArray{T,4}, dgrid::AbstractArray{T,4}, Δ::AbstractArray{T,4}, input::AbstractArray{T,4}, grid::AbstractArray{T,4}, padding_mode) where {T} pad = Val(padding_mode) iW, iH, iC, iN = size(input) gW, gH = size(grid, 2), size(grid, 3) @@ -185,16 +270,30 @@ function ∇grid_sample!(dx, dgrid, Δ, input, grid, padding_mode) end dx, dgrid end + +function ∇grid_sample!(dx::AbstractArray{T,5}, dgrid::AbstractArray{T,5}, Δ::AbstractArray{T,5}, input::AbstractArray{T,5}, grid::AbstractArray{T,5}, padding_mode) where {T} + pad = Val(padding_mode) + iW, iH, iD, iC, iN = size(input) + gW, gH, gD = size(grid, 2), size(grid, 3), size(grid, 4) + # Loop over each output voxel. + Threads.@threads for n in 1:iN + for w in 1:gW, h in 1:gH, d in 1:gD + _∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, pad, w, h, d, n, iW, iH, iD, iC) + end + end + dx, dgrid +end + @inline function _∇grid_sample_kernel!( - dx, dgrid, Δ, input, grid, padding_mode, w, h, n, iW, iH, iC, -) + dx::AbstractArray{T,4}, dgrid::AbstractArray{V,4}, Δ::AbstractArray{T,4}, input::AbstractArray{T,4}, grid::AbstractArray{V,4}, padding_mode, w, h, n, iW, iH, iC, +) where {T,V} # Get corresponding (x, y) from grid. @inbounds x, y = grid[1, w, h, n], grid[2, w, h, n] - # Compute multipliers for gradinets on ix, iy. + # Compute multipliers for gradients on ix, iy. ix, gix_mult = ∇compute_source_index(x, iW, padding_mode) iy, giy_mult = ∇compute_source_index(y, iH, padding_mode) # Get corner pixel values from (ix, iy) in north-east-south-west directions. - ix_nw, iy_nw = floor(Int, ix), floor(Int, iy) + ix_nw, iy_nw = unsafe_trunc(Int, floor(ix)), unsafe_trunc(Int, floor(iy)) ix_ne, iy_ne = ix_nw + 1, iy_nw ix_sw, iy_sw = ix_nw, iy_nw + 1 ix_se, iy_se = ix_ne, iy_sw @@ -204,7 +303,7 @@ end sw = (ix_ne - ix) * (iy - iy_ne) se = (ix - ix_nw) * (iy - iy_nw) # ∀ channel: Calculate billinear weighted pixel value. - gix, giy = 0.0, 0.0 + gix, giy = zero(V), zero(V) @inbounds for c in 1:iC g_out = Δ[w, h, c, n] # Calculate dx and dgrid partials. @@ -237,10 +336,131 @@ end @inbounds dgrid[2, w, h, n] = giy_mult * giy end +@inline function _∇grid_sample_kernel!( + dx::AbstractArray{T,5}, dgrid::AbstractArray{V,5}, Δ::AbstractArray{T,5}, input::AbstractArray{T,5}, grid::AbstractArray{V,5}, padding_mode, w, h, d, n, iW, iH, iD, iC, +) where {T,V} + # Get corresponding (x, y, z) from grid. + @inbounds x, y, z = grid[1, w, h, d, n], grid[2, w, h, d, n], grid[3, w, h, d, n] + # Compute multipliers for gradients on ix, iy, iz. 
+ ix, gix_mult = ∇compute_source_index(x, iW, padding_mode) + iy, giy_mult = ∇compute_source_index(y, iH, padding_mode) + iz, giz_mult = ∇compute_source_index(z, iD, padding_mode) + + # Get corner pixel values from (ix, iy, iz) + ix_0 = unsafe_trunc(Int, floor(ix)) + iy_0 = unsafe_trunc(Int, floor(iy)) + iz_0 = unsafe_trunc(Int, floor(iz)) + ix_1 = ix_0 + 1 + iy_1 = iy_0 + 1 + iz_1 = iz_0 + 1 + + # Get difference of coordinate + wx_0 = ix - ix_0 + wy_0 = iy - iy_0 + wz_0 = iz - iz_0 + wx_1 = ix_1 - ix + wy_1 = iy_1 - iy + wz_1 = iz_1 - iz + + # Calculate weights (volume of diagnal vertex cube) + # w_{abc} = wx_{¬a}*wy_{¬b}*wz_{¬c} + weight_000 = wx_1 * wy_1 * wz_1 + weight_001 = wx_1 * wy_1 * wz_0 + weight_010 = wx_1 * wy_0 * wz_1 + weight_011 = wx_1 * wy_0 * wz_0 + weight_100 = wx_0 * wy_1 * wz_1 + weight_101 = wx_0 * wy_1 * wz_0 + weight_110 = wx_0 * wy_0 * wz_1 + weight_111 = wx_0 * wy_0 * wz_0 + + # ∂w_{abc}/∂x=(-1)^{¬a} wy_{¬b}*wz_{¬c}, ∂w/∂y = (-1)^{¬b} wx_{¬a}*wz_{¬c}, ∂w/∂z=(-1)^{¬c} wx_{¬a}*wy_{¬b} + # abc are the index of the vertex of the cube (001,010...) + + # Initialize gradient accumulators + gix, giy, giz = zero(V), zero(V), zero(V) + + @inbounds for c in 1:iC + g_out = Δ[w, h, d, c, n] + + # Calculate dx and dgrid partials for all 8 corners + if in_bounds(iy_0, ix_0, iz_0, iH, iW, iD) + _safe_add!(dx, g_out * weight_000, ix_0, iy_0, iz_0, c, n) + val = input[ix_0, iy_0, iz_0, c, n] + gix -= val * wy_1 * wz_1 * g_out + giy -= val * wx_1 * wz_1 * g_out + giz -= val * wx_1 * wy_1 * g_out + end + + if in_bounds(iy_0, ix_0, iz_1, iH, iW, iD) + _safe_add!(dx, g_out * weight_001, ix_0, iy_0, iz_1, c, n) + val = input[ix_0, iy_0, iz_1, c, n] + gix -= val * wy_1 * wz_0 * g_out + giy -= val * wx_1 * wz_0 * g_out + giz += val * wx_1 * wy_1 * g_out + end + + if in_bounds(iy_1, ix_0, iz_0, iH, iW, iD) + _safe_add!(dx, g_out * weight_010, ix_0, iy_1, iz_0, c, n) + val = input[ix_0, iy_1, iz_0, c, n] + gix -= val * wy_0 * wz_1 * g_out + giy += val * wx_1 * wz_1 * g_out + giz -= val * wx_1 * wy_0 * g_out + end + + if in_bounds(iy_1, ix_0, iz_1, iH, iW, iD) + _safe_add!(dx, g_out * weight_011, ix_0, iy_1, iz_1, c, n) + val = input[ix_0, iy_1, iz_1, c, n] + gix -= val * wy_0 * wz_0 * g_out + giy += val * wx_1 * wz_0 * g_out + giz += val * wx_1 * wy_0 * g_out + end + + if in_bounds(iy_0, ix_1, iz_0, iH, iW, iD) + _safe_add!(dx, g_out * weight_100, ix_1, iy_0, iz_0, c, n) + val = input[ix_1, iy_0, iz_0, c, n] + gix += val * wy_1 * wz_1 * g_out + giy -= val * wx_0 * wz_1 * g_out + giz -= val * wx_0 * wy_1 * g_out + end + + if in_bounds(iy_0, ix_1, iz_1, iH, iW, iD) + _safe_add!(dx, g_out * weight_101, ix_1, iy_0, iz_1, c, n) + val = input[ix_1, iy_0, iz_1, c, n] + gix += val * wy_1 * wz_0 * g_out + giy -= val * wx_0 * wz_0 * g_out + giz += val * wx_0 * wy_1 * g_out + end + + if in_bounds(iy_1, ix_1, iz_0, iH, iW, iD) + _safe_add!(dx, g_out * weight_110, ix_1, iy_1, iz_0, c, n) + val = input[ix_1, iy_1, iz_0, c, n] + gix += val * wy_0 * wz_1 * g_out + giy += val * wx_0 * wz_1 * g_out + giz -= val * wx_0 * wy_0 * g_out + end + + if in_bounds(iy_1, ix_1, iz_1, iH, iW, iD) + _safe_add!(dx, g_out * weight_111, ix_1, iy_1, iz_1, c, n) + val = input[ix_1, iy_1, iz_1, c, n] + gix += val * wy_0 * wz_0 * g_out + giy += val * wx_0 * wz_0 * g_out + giz += val * wx_0 * wy_0 * g_out + end + end + + @inbounds dgrid[1, w, h, d, n] = gix_mult * gix + @inbounds dgrid[2, w, h, d, n] = giy_mult * giy + @inbounds dgrid[3, w, h, d, n] = giz_mult * giz +end + @inline function _safe_add!(dx, value, ix, iy, c, n) 
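The new 5-D methods in src/sampling.jl extend `grid_sample` and its gradient to volumetric data. A minimal sketch (sizes are illustrative only), following the `(W, H, D, C, N)` input and `(3, W_out, H_out, D_out, N)` grid layout used by the kernels above:

```julia
using NNlib

input = rand(Float32, 4, 4, 4, 2, 1)     # (W, H, D, C, N)
grid  = zeros(Float32, 3, 2, 2, 2, 1)    # (x, y, z) per output voxel; all zeros samples the volume center

y = grid_sample(input, grid; padding_mode = :zeros)
size(y)                                  # (2, 2, 2, 2, 1)

Δ = ones(Float32, size(y))
dinput, dgrid = NNlib.∇grid_sample(Δ, input, grid; padding_mode = :zeros)
size(dinput) == size(input) && size(dgrid) == size(grid)   # true
```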
@inbounds dx[ix, iy, c, n] += value end +@inline function _safe_add!(dx, value, ix, iy, iz, c, n) + @inbounds dx[ix, iy, iz, c, n] += value +end + function rrule(::typeof(grid_sample), x, grid; padding_mode) y = grid_sample(x, grid; padding_mode=padding_mode) function grid_sample_pullback(Δ) diff --git a/src/scatter.jl b/src/scatter.jl index 6057e4528..3507b906d 100644 --- a/src/scatter.jl +++ b/src/scatter.jl @@ -81,7 +81,7 @@ function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractA dst end -for AT in (AbstractArray, AbstractGPUArray) +for AT in (AbstractArray, AnyGPUArray) @eval function scatter!(op::typeof(mean), dst::$AT, src::$AT, idx::$AT) Ns = scatter!(+, zero(dst), one.(src), idx) dst_ = scatter!(+, zero(dst), src, idx) @@ -90,7 +90,7 @@ for AT in (AbstractArray, AbstractGPUArray) end end -function scatter!(op::OP, dst::AbstractGPUArray, src::AbstractGPUArray, idx::AbstractGPUArray) where OP +function scatter!(op::OP, dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) where OP n_dims = scatter_dims(dst, src, idx) args = if n_dims == 0 ndrange = length(idx) @@ -108,7 +108,7 @@ end @kernel function _scatter!(op::OP, dst, src, idxs) where OP i = @index(Global) - @inbounds idx = Tuple(idxs[i]) + @inbounds idx = Tuple(_convert_i64(idxs[i])) @inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i]) # FIXME `@atomic` macro silently fails to perform atomic op below # @atomic dst[idx...] = op(dst[idx...], src[i]) @@ -119,14 +119,20 @@ end ) where OP i = @index(Global) j, k = divrem(i - 1, max_dims_idx) - @inbounds idx = (Tuple(dim_ids[k + 1])..., Tuple(idxs[j + 1])...) + @inbounds idx = (Tuple(dim_ids[k + 1])..., Tuple(_convert_i64(idxs[j + 1]))...) @inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i]) - # FIXME + # FIXME `@atomic` macro silently fails to perform atomic op below # dim_i = Tuple(dim_ids[k + 1]) # idx = idxs[j + 1] # @atomic dst[dim_i..., idx...] = op(dst[dim_i..., idx...], src[i]) end +# Allow non-Int64 indices by converting them to Int64 when index eltype <: Integer. +# All other index types (tuples, cartesian indices) must be in Int64 already. +@inline _convert_i64(x::Int) = x +@inline _convert_i64(x::Integer) = Int(x) +@inline _convert_i64(x) = x + """ NNlib.scatter(op, src, idx; [init, dstsize]) @@ -222,7 +228,7 @@ end function ∇scatter_src( op::Union{typeof(*), typeof(/)}, Δ, dst, - src::AbstractGPUArray{Tsrc, Nsrc}, idx::AbstractGPUArray{Tidx, Nidx}, + src::AnyGPUArray{Tsrc, Nsrc}, idx::AnyGPUArray{Tidx, Nidx}, ) where {Tsrc, Nsrc, Tidx, Nidx} n_dims = Nsrc - Nidx Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) diff --git a/src/softmax.jl b/src/softmax.jl index 182f2fb93..709b828be 100644 --- a/src/softmax.jl +++ b/src/softmax.jl @@ -39,7 +39,7 @@ Note that, when used with Flux.jl, `softmax` must not be passed to layers like ` which accept an activation function. The activation is broadcasted over the result, thus applies to individual numbers. But `softmax` always needs to see the whole column. -```julia +```julia-repl julia> using Flux julia> x = randn(Float32, 4, 4, 3, 13); @@ -62,7 +62,8 @@ function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T} if all(isfinite, max_) @fastmath out .= exp.(x .- max_) else - @fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_)) + _zero, _one, _inf = T(0), T(1), T(Inf) + @fastmath @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _one, _zero), exp(x - max_)) end tmp = dims isa Colon ? 
sum(out) : sum!(max_, out) out ./= tmp @@ -112,7 +113,8 @@ function logsoftmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T if all(isfinite, max_) out .= x .- max_ else - @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 0, -Inf), x - max_) + _zero, _minf, _inf = T(0), T(-Inf), T(Inf) + @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _zero, _minf), x - max_) end @fastmath log_ = log.(sum(exp, out; dims)) out .-= log_ diff --git a/src/upsample.jl b/src/upsample.jl index a320ca9e6..b108cab0d 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -173,6 +173,16 @@ function upsample_nearest(x::AbstractArray{T,N}, scales::NTuple{S, <:Integer}) w out end +""" + ∇upsample_nearest(Δ::AbstractArray{T,3}, scales::NTuple{S, <:Integer}) where T + +# Arguments +- `Δ`: Incoming gradient array, backpropagated from downstream layers +- `scales`: scales by which the image was upsampled in the first place + +# Outputs +- `dx`: Downsampled version of `Δ` +""" function ∇upsample_nearest(x::AbstractArray{T,N}, scales::NTuple{S, <:Integer}) where {T,N,S} outsize = ntuple(N) do d d > S && return size(x,d) @@ -385,12 +395,12 @@ end @kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, align::Val{A}) where { T <: AbstractArray{<:Any, 3}, A, } - @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x) - @uniform out_width::UInt32 = size(y, 1) - c::UInt32, n::UInt32 = @index(Global, NTuple) + @uniform in_width, channels, batch = size(x) + @uniform out_width = size(y, 1) + c, n = @index(Global, NTuple) yv, xv = @view(y[:, c, n]), @view(x[:, c, n]) - @inbounds for i in UnitRange{UInt32}(one(UInt32), out_width) - iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) + @inbounds for i in 1:out_width + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width) yv[i] = w0λ * xv[iw0] + w1λ * xv[iw1] end end @@ -398,12 +408,12 @@ end @kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, align::Val{A}) where { T1 <: AbstractArray{<:Any, 3}, T2 <: AbstractArray{<:Any, 3}, A, } - @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ) - @uniform out_width::UInt32 = size(dx, 1) - c::UInt32, n::UInt32 = @index(Global, NTuple) + @uniform in_width, channels, batch = size(Δ) + @uniform out_width = size(dx, 1) + c, n = @index(Global, NTuple) Δv, dxv = @view(Δ[:, c, n]), @view(dx[:, c, n]) - @inbounds for i in UnitRange{UInt32}(one(UInt32), in_width) - ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) + @inbounds for i in 1:in_width + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width) val = Δv[i] dxv[ow0] += w0λ * val dxv[ow1] += w1λ * val @@ -411,15 +421,14 @@ end end # Linear (GPU): parallelization along width dimension. -# TODO replace AbstractArray -> AbstractGPUArray once device arrays subtype it. 
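The `softmax!`/`logsoftmax!` changes in src/softmax.jl above keep the non-finite branch eltype-consistent by writing the constants as `T(0)`, `T(1)`, `T(-Inf)`, `T(Inf)`; the special-casing itself is what keeps inputs containing `Inf` free of `NaN`s. A small sketch:

```julia
using NNlib

x = Float32[1.0, Inf, 2.0]

softmax(x)      # ≈ Float32[0, 1, 0]: the Inf entry takes all the probability mass, no NaNs
logsoftmax(x)   # ≈ Float32[-Inf, 0, -Inf]
```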
@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, align::Val{A}) where { B <: GPU, T <: AbstractArray{<:Any, 3}, A, } - @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x) - i::UInt32 = @index(Global) - iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) - @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) + @uniform in_width, channels, batch = size(x) + i = @index(Global) + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width) + @inbounds for n in 1:batch, c in 1:channels y[i, c, n] = w0λ * x[iw0, c, n] + w1λ * x[iw1, c, n] end end @@ -427,11 +436,11 @@ end @kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, align::Val{A}) where { B <: GPU, T <: AbstractArray{<:Any, 3}, A, } - @uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ) - @uniform out_width::UInt32 = size(dx, 1) - i::UInt32 = @index(Global) - ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) - @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) + @uniform in_width, channels, batch = size(Δ) + @uniform out_width = size(dx, 1) + i = @index(Global) + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width) + @inbounds for n in 1:batch, c in 1:channels val = Δ[i, c, n] @atomic dx[ow0, c, n] += w0λ * val @atomic dx[ow1, c, n] += w1λ * val @@ -443,14 +452,14 @@ end @kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, rheight, align::Val{A}) where { T <: AbstractArray{<:Any, 4}, A, } - @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) - @uniform out_width::UInt32, out_height::UInt32 = size(y)[1:2] - c::UInt32, n::UInt32 = @index(Global, NTuple) + @uniform in_width, in_height, channels, batch = size(x) + @uniform out_width, out_height = size(y)[1:2] + c, n = @index(Global, NTuple) yv, xv = @view(y[:, :, c, n]), @view(x[:, :, c, n]) - for j in UnitRange{UInt32}(one(UInt32), out_height) - ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height) - for i in UnitRange{UInt32}(one(UInt32), out_width) - iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) + for j in 1:out_height + ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height) + for i in 1:out_width + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width) @inbounds yv[i, j] = h0λ * (w0λ * xv[iw0, ih0] + w1λ * xv[iw1, ih0]) + h1λ * (w0λ * xv[iw0, ih1] + w1λ * xv[iw1, ih1]) @@ -461,14 +470,14 @@ end @kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, rheight, align::Val{A}) where { T1 <: AbstractArray{<:Any, 4}, T2 <: AbstractArray{<:Any, 4}, A, } - @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ) - @uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2] - c::UInt32, n::UInt32 = @index(Global, NTuple) + @uniform in_width, in_height, channels, batch = size(Δ) + @uniform out_width, out_height = size(dx)[1:2] + c, n = @index(Global, NTuple) Δv, dxv = @view(Δ[:, :, c, n]), @view(dx[:, :, c, n]) - for j in UnitRange{UInt32}(one(UInt32), in_height) - oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height) - @inbounds for i in UnitRange{UInt32}(one(UInt32), in_width) - ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) + for j in 1:in_height + oh0, oh1, h0λ, h1λ = 
source_idx_and_λ(rheight, j - 1, align, out_height) + @inbounds for i in 1:in_width + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width) val = Δv[i, j] dxv[ow0, oh0] += w0λ * h0λ * val dxv[ow1, oh0] += w1λ * h0λ * val @@ -483,11 +492,11 @@ end @kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, rheight, align::Val{A}) where { B <: GPU, T <: AbstractArray{<:Any, 4}, A, } - @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x) - i::UInt32, j::UInt32 = @index(Global, NTuple) - iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) - ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height) - @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) + @uniform in_width, in_height, channels, batch = size(x) + i, j = @index(Global, NTuple) + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width) + ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height) + @inbounds for n in 1:batch, c in 1:channels y[i, j, c, n] = h0λ * (w0λ * x[iw0, ih0, c, n] + w1λ * x[iw1, ih0, c, n]) + h1λ * (w0λ * x[iw0, ih1, c, n] + w1λ * x[iw1, ih1, c, n]) @@ -497,12 +506,12 @@ end @kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, rheight, align::Val{A}) where { B <: GPU, T <: AbstractArray{<:Any, 4}, A, } - @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ) - @uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2] - i::UInt32, j::UInt32 = @index(Global, NTuple) - ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) - oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height) - @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) + @uniform in_width, in_height, channels, batch = size(Δ) + @uniform out_width, out_height = size(dx)[1:2] + i, j = @index(Global, NTuple) + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width) + oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, out_height) + @inbounds for n in 1:batch, c in 1:channels val = Δ[i, j, c, n] @atomic dx[ow0, oh0, c, n] += w0λ * h0λ * val @atomic dx[ow1, oh0, c, n] += w1λ * h0λ * val @@ -516,17 +525,17 @@ end @kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where { T <: AbstractArray{<:Any, 5}, A, } - @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(x)[1:3] - @uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5) - @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(y)[1:3] - c::UInt32, n::UInt32 = @index(Global, NTuple) + @uniform in_width, in_height, in_depth = size(x)[1:3] + @uniform channels, batch = size(x, 4), size(x, 5) + @uniform out_width, out_height, out_depth = size(y)[1:3] + c, n = @index(Global, NTuple) yv, xv = @view(y[:, :, :, c, n]), @view(x[:, :, :, c, n]) - for k in UnitRange{UInt32}(one(UInt32), out_depth) - id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, in_depth) - for j in UnitRange{UInt32}(one(UInt32), out_height) - ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height) - for i in UnitRange{UInt32}(one(UInt32), out_width) - iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) + for k in 1:out_depth + id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, in_depth) + for j in 
1:out_height + ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height) + for i in 1:out_width + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width) @inbounds yv[i, j, k] = d0λ * ( h0λ * (w0λ * xv[iw0, ih0, id0] + w1λ * xv[iw1, ih0, id0]) + @@ -542,17 +551,17 @@ end @kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, rheight, rdepth, align::Val{A}) where { T1 <: AbstractArray{<:Any, 5}, T2 <: AbstractArray{<:Any, 5}, A, } - @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(Δ)[1:3] - @uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5) - @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3] - c::UInt32, n::UInt32 = @index(Global, NTuple) + @uniform in_width, in_height, in_depth = size(Δ)[1:3] + @uniform channels, batch = size(Δ, 4), size(Δ, 5) + @uniform out_width, out_height, out_depth = size(dx)[1:3] + c, n = @index(Global, NTuple) Δv, dxv = @view(Δ[:, :, :, c, n]), @view(dx[:, :, :, c, n]) - for k in UnitRange{UInt32}(one(UInt32), in_depth) - od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, out_depth) - for j in UnitRange{UInt32}(one(UInt32), in_height) - oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height) - @inbounds for i in UnitRange{UInt32}(one(UInt32), in_width) - ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) + for k in 1:in_depth + od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, out_depth) + for j in 1:in_height + oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, out_height) + @inbounds for i in 1:in_width + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width) val = Δv[i, j, k] dxv[ow0, oh0, od0] += w0λ * h0λ * d0λ * val dxv[ow1, oh0, od0] += w1λ * h0λ * d0λ * val @@ -573,13 +582,13 @@ end @kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where { B <: GPU, T <: AbstractArray{<:Any, 5}, A, } - @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(x)[1:3] - @uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5) - i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple) - iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width) - ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height) - id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, in_depth) - @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) + @uniform in_width, in_height, in_depth = size(x)[1:3] + @uniform channels, batch = size(x, 4), size(x, 5) + i, j, k = @index(Global, NTuple) + iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width) + ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height) + id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, in_depth) + @inbounds for n in 1:batch, c in 1:channels y[i, j, k, c, n] = d0λ * ( h0λ * (w0λ * x[iw0, ih0, id0, c, n] + w1λ * x[iw1, ih0, id0, c, n]) + @@ -593,14 +602,14 @@ end @kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, rheight, rdepth, align::Val{A}) where { B <: GPU, T <: AbstractArray{<:Any, 5}, A, } - @uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(Δ)[1:3] - @uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5) - @uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3] - i::UInt32, j::UInt32, k::UInt32 = @index(Global, 
NTuple) - ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width) - oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height) - od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, out_depth) - @inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels) + @uniform in_width, in_height, in_depth = size(Δ)[1:3] + @uniform channels, batch = size(Δ, 4), size(Δ, 5) + @uniform out_width, out_height, out_depth = size(dx)[1:3] + i, j, k = @index(Global, NTuple) + ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width) + oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, out_height) + od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, out_depth) + @inbounds for n in 1:batch, c in 1:channels val = Δ[i, j, k, c, n] @atomic dx[ow0, oh0, od0, c, n] += w0λ * h0λ * d0λ * val @atomic dx[ow1, oh0, od0, c, n] += w1λ * h0λ * d0λ * val @@ -615,17 +624,21 @@ end end @inline function source_idx_and_λ( - ratio::T, out_idx::UInt32, ::Val{align}, in_width::UInt32, + ratio::T, out_idx::Int, ::Val{align}, in_width::Int, ) where {T, align} real_index = align ? ratio * out_idx : max(zero(T), ratio * (out_idx + T(0.5)) - T(0.5)) - iw0 = floor(UInt32, real_index) - offset::UInt32 = ifelse(iw0 < in_width - one(UInt32), one(UInt32), zero(UInt32)) - iw1 = iw0 + offset + one(UInt32) + iw0 = if T <: Rational + floor(Int, real_index) # Not GPU-friendly, but allows for Rational support. + else + unsafe_trunc(Int, floor(real_index)) + end + offset = ifelse(iw0 < in_width - 1, 1, 0) + iw1 = iw0 + offset + 1 w1lambda = real_index - iw0 w0lambda = one(T) - w1lambda - return iw0 + one(UInt32), iw1, w0lambda, w1lambda + return iw0 + 1, iw1, w0lambda, w1lambda end diff --git a/src/utils.jl b/src/utils.jl index 3d23e7383..baf95c8da 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,7 +10,7 @@ pass it an array whose gradient is of interest. There is also an overload for ForwardDiff.jl's `Dual` types (and arrays of them). 
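For reference, a worked instance of `source_idx_and_λ` as rewritten above (now returning `Int` indices). The `ratio` value assumes the usual align-corners convention `ratio = (in - 1) / (out - 1)` used by the linear-upsampling frontend:

```julia
# Upsampling a length-2 axis to length-4 with align_corners = true:
ratio   = (2 - 1) / (4 - 1)                  # = 1/3
out_idx = 1                                  # second output pixel, 0-based inside the kernel

real_index = ratio * out_idx                 # ≈ 0.333
iw0 = unsafe_trunc(Int, floor(real_index))   # 0 (0-based left source index)
iw1 = iw0 + (iw0 < 2 - 1 ? 1 : 0) + 1        # 2 (1-based right source index)
w1λ = real_index - iw0                       # ≈ 1/3, weight of the right neighbour
w0λ = 1 - w1λ                                # ≈ 2/3, weight of the left neighbour
# source_idx_and_λ returns (iw0 + 1, iw1, w0λ, w1λ) = (1, 2, 2/3, 1/3)
```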
# Examples -``` +```julia-repl julia> using ForwardDiff, Zygote, NNlib julia> f_good(x) = if NNlib.within_gradient(x) diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 000000000..5fccd84c6 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,25 @@ +[deps] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" +ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" diff --git a/test/activations.jl b/test/activations.jl index 3a14bfde8..0bb5047c2 100644 --- a/test/activations.jl +++ b/test/activations.jl @@ -12,6 +12,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test rrelu(0.0) == 0.0 @test elu(0.0) == 0.0 @test gelu(0.0) == 0.0 +@test gelu_tanh(0.0) == 0.0 +@test gelu_erf(0.0) == 0.0 @test swish(0.0) == 0.0 @test hardswish(0.0) == 0.0 @test lisht(0.0) == 0.0 @@ -36,6 +38,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test rrelu(1.0) == 1.0 @test elu(1.0) == 1.0 @test gelu(1.0) == 0.8411919906082768 +@test gelu_tanh(1.0) == 0.8411919906082768 +@test gelu_erf(1.0) == 0.8413447460685429 @test swish(1.0) == sigmoid(1.0) @test hardswish(1.0) == hardsigmoid(1.0) @test lisht(1.0) ≈ 1.0 * tanh(1.0) @@ -58,6 +62,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test -1/3.0 <= rrelu(-1.0) <= -1/8.0 @test elu(-1.0) == exp(-1.0) - 1.0 @test gelu(-1.0) ≈ -0.15880800939172324 +@test gelu_tanh(-1.0) ≈ -0.15880800939172324 +@test gelu_erf(-1.0) == -0.15865525393145707 @test swish(-1.0) == -sigmoid(-1.0) @test hardswish(-1.0) == -hardsigmoid(-1.0) @test lisht(-1.0) ≈ -1.0 * tanh(-1.0) @@ -114,7 +120,7 @@ end a == softsign && continue @test !isnan(a(Inf32)) - a in [gelu, swish, hardswish, logcosh, mish] && continue + a in [gelu, gelu_tanh, gelu_erf, swish, hardswish, logcosh, mish] && continue @test !isnan(a(-Inf32)) end end diff --git a/test/batchedmul.jl b/test/batchedmul.jl index 2396cff9a..1b8b08e18 100644 --- a/test/batchedmul.jl +++ b/test/batchedmul.jl @@ -96,7 +96,7 @@ end @test C1 ≈ C2 # 5-arg mul! 
- @test 10 .* C1 ≈ batched_mul!(C2, A′, B′, 10) + @test 10 .* C1 ≈ batched_mul!(C2, A′, B′, 10) rtol=1e-7 C2 .= 10 @test C1 .+ 100 ≈ batched_mul!(C2, A′, B′, 1, 10) diff --git a/test/conv.jl b/test/conv.jl index 8edc4bf24..cf3232778 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -2,6 +2,7 @@ using NNlib, Test using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multiplier, stride, padding, dilation, flipkernel, output_size, groupcount +using Random: AbstractRNG, SamplerType @testset "ConvDims" begin for T in (DenseConvDims, DepthwiseConvDims) @@ -276,13 +277,7 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x))) w = reshape(Float64[1:prod(size(dw));], size(dw)..., 1, 1) convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,] - NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack) for conv in convs - if NNlib.is_nnpack_available() - if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(DenseConvDims(x, w)) - continue - end - end @testset "$(conv)" begin cdims = DenseConvDims(x, w) # First, your basic convolution with no parameters @@ -310,15 +305,9 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x))) end end - # Test all in-place implementations/interfaces + # Test all in-place implementations/interfaces convs = [NNlib.conv!, NNlib.conv_im2col!, NNlib.conv_direct!,] - NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack!) for conv! in convs - if NNlib.is_nnpack_available() - if conv! == NNlib.conv_nnpack! && !NNlib.nnpack_supported_operation(DenseConvDims(x, w)) - continue - end - end α, β = 2e0, -1e0 @testset "$(conv!)" begin @@ -399,7 +388,17 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x))) end end - # Test all in-place implementations/interfaces + # Test im2col + + for beta in (-2.0, -1.0, 0.0, 0.5, 1.0, 2.0) + cache_dx, cache_dy, cache_w = ([0.17;;; 0.19;;; 0.23], [0.11;;; 0.13;;; 0.15], [1.0;;;]) + dx_old = copy(cache_dx) + cdims = DenseConvDims(cache_dx, cache_w) + NNlib.∇conv_data_im2col!(cache_dx, cache_dy, cache_w, cdims; alpha=1.0, beta) + @test isapprox(cache_dx, dx_old * beta + cache_dy, rtol = 1.0e-7) + end + + # Test all in-place implementations/interfaces for (∇conv_filter!, ∇conv_data!) in ( (NNlib.∇conv_filter!, NNlib.∇conv_data!), (NNlib.∇conv_filter_im2col!, NNlib.∇conv_data_im2col!), @@ -407,47 +406,46 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x))) ) #α, β = 2*rand(rng) - 1, 2*rand(rng) - 1 α, β = 2e0, -1e0 - flag = ∇conv_data! in (NNlib.∇conv_data!, NNlib.∇conv_data_im2col!) 
@testset "$(∇conv_filter!)/$(∇conv_data!)" begin # First, your basic convolution with no parameters cdims = DenseConvDims(x, w) dy = NNlib.conv(x, w, cdims) - @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7) - @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) # Next, test convolution on views and alternate datatypes: - @test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7) - @test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag + @test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) - @test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7) - @test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7) broken=flag + @test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7) # Next, introduce stride: cdims = DenseConvDims(x, w; stride=2) dy = NNlib.conv(x, w, cdims) flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1,3) @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_stride + β*w, rtol = 1.0e-7) broken=flag_ - @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7) broken=flag + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7) # Next, introduce dilation: cdims = DenseConvDims(x, w; dilation=2) dy = NNlib.conv(x, w, cdims) flag_ = ∇conv_data! == NNlib.∇conv_data_direct! 
&& rank == 3 - @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7) - @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag || flag_ + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag_ # Next, introduce padding: cdims = DenseConvDims(x, w; padding=1) dy = NNlib.conv(x, w, cdims) - @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7) - @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7) broken=flag + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7) # Next, test crosscor/conv with a flipped kernel cdims = DenseConvDims(x, w; flipkernel=true) dy = NNlib.conv(x, w, cdims) - @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7) - @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7) broken=flag + @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7) + @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7) end end end @@ -460,13 +458,7 @@ end w = reshape(complex.(Float64[1:4;] .+ 2, Float64[1:4;] .+ 3), 1, 4, 1) cdims = DenseConvDims(x, w) convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,] - NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack) for conv in convs - if NNlib.is_nnpack_available() - if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(cdims) - continue - end - end @testset "$(conv)" begin @test isapprox(ddims(conv(x, w, cdims)), [transpose(vec(w)) * vec(x)], rtol = 1.0e-7) end @@ -856,6 +848,44 @@ end @test size(NNlib.∇conv_filter_direct!(w, x, y, cdims)) == w_size end +# https://github.com/FluxML/NNlib.jl/issues/490 +# https://github.com/FluxML/NNlib.jl/issues/405 +@testset "conv_direct! - Unusual input types" begin + # Create test type that can't be indexed when undefined. + # This simulates the worst-case scenario for custom types. + struct MyFloat <: Real + set::Set{Float32} + end + + # Test that direct indexing fails when undefined. + v = Array{MyFloat}(undef, 3) + @test_throws UndefRefError v[1] + + # Define minimal set of functions required for conv_direct! + MyFloat(x::MyFloat) = x + MyFloat(x::Real) = MyFloat(Set(Float32(x))) + + Base.:+(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) + only(y.set)) + Base.:*(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) * only(y.set)) + Base.promote_rule(::Type{MyFloat}, ::Type{Float32}) = MyFloat + Base.rand(::AbstractRNG, ::SamplerType{MyFloat}) = MyFloat(rand(Float32)) + Base.zero(::MyFloat) = MyFloat(zero(Float32)) + Base.zero(::Type{MyFloat}) = MyFloat(zero(Float32)) + + # Test conv_direct! 
+ x_size = (6, 7, 8, 5, 3) + y_size = (5, 6, 7, 4, 3) + w_size = (2, 2, 2, 5, 4) + x = rand(MyFloat, x_size); + w = randn(Float32, w_size); + y = Array{MyFloat}(undef, y_size...); + cdims = DenseConvDims(x_size, w_size) + y_out = NNlib.conv_direct!(y, x, w, cdims) + + @test eltype(y_out) == MyFloat + @test size(y_out) == y_size +end + @testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) x = rand(rng, repeat([5], spatial_rank)..., 3, 2) w = rand(rng, repeat([3], spatial_rank)..., 3, 3) @@ -877,3 +907,93 @@ end gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w) gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) end + +@static if Test_Enzyme + +@testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([5], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + + cdims = DenseConvDims(x, w) + + curconv = conv + curconv! = conv! + dst = curconv(x, w, cdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue + + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const)) + end +end + +@testset "EnzymeRules: ∇conv_data! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([5], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + cdims = DenseConvDims(x, w) + + y = conv(x, w, cdims) + + curconv = ∇conv_data + curconv! = ∇conv_data! + dst = curconv(y, w, cdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Ty, Tw) || continue + + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (y, Ty), (w, Tw), (cdims, EnzymeCore.Const)) + end +end + +@testset "EnzymeRules: ∇conv_filter! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([5], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + cdims = DenseConvDims(x, w) + + y = conv(x, w, cdims) + + curconv = ∇conv_filter + curconv! = ∇conv_filter! + dst = curconv(x, y, cdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Ty) || continue + + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (y, Ty), (cdims, EnzymeCore.Const)) + end +end + +@testset "EnzymeRules: depthwiseconv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([5], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + + cdims = DepthwiseConvDims(x, w) + + curconv = depthwiseconv + curconv! = depthwiseconv! 
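+ # EnzymeTestUtils.test_reverse compares Enzyme's reverse-mode gradients of the in-place kernel + # against finite differences, for every compatible Const/Duplicated/BatchDuplicated activity combination below.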
+ dst = curconv(x, w, cdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue + + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const)) + end +end + +end diff --git a/test/dropout.jl b/test/dropout.jl index 07a48edc5..0da70111e 100644 --- a/test/dropout.jl +++ b/test/dropout.jl @@ -1,5 +1,5 @@ using NNlib, Test, Statistics, Random, LinearAlgebra -using Zygote, StableRNGs, ChainRulesCore +using Zygote, StableRNGs, ChainRulesCore, Enzyme @testset "dropout" begin # Basics @@ -75,3 +75,32 @@ using Zygote, StableRNGs, ChainRulesCore @test_throws ArgumentError dropout(x1, 2) @test_throws ArgumentError dropout!(y1, x1, 3) end + +@static if Test_Enzyme + +@testset "EnzymeRules: dropout " begin + rng = Random.default_rng() + + x1 = randn(Float32, 3000, 4000) + dx1 = zeros(Float32, 3000, 4000) + + dout = randn(Float32, 3000, 4000) + + p = 0.2f0 + + forward, reverse = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, typeof(Const(dropout)), Duplicated, typeof(Const(rng)), typeof(Duplicated(x1, dx1)), typeof(Const(0.2f0))) + + tape, primal, shadow = forward(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p)) + + shadow .= dout + + reverse(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p), tape) + + @test dx1[.!tape[1]] ≈ zero(x1)[.!tape[1]] + + val = convert(Float32, 1/(1-p)) + + @test dx1[tape[1]] ≈ (val * dout)[tape[1]] +end + +end \ No newline at end of file diff --git a/test/ext_amdgpu/activations.jl b/test/ext_amdgpu/activations.jl index 2abb0c272..afc59a45c 100644 --- a/test/ext_amdgpu/activations.jl +++ b/test/ext_amdgpu/activations.jl @@ -1,10 +1,11 @@ @testset "Compare CPU & GPU" begin - for (T, atol) in ((Float16, 1f-2), (Float32, 1f-5)) - x = randn(T, 16) - gputest(x -> NNlib.relu.(x), x; atol) - gputest(x -> NNlib.relu6.(x), x; atol) - gputest(x -> NNlib.softplus.(x), x; atol) - gputest(x -> tanh.(x), x; atol) - gputest(x -> identity.(x), x; atol) + for (T, atol) in ((Float16, 1.0f-2), (Float32, 1.0f-5)) + @testset "ndims: $(ndims(x))" for x in (randn(T, 16), randn(T, ntuple(_ -> 2, 5)...), randn(T, ntuple(_ -> 2, 6)...)) + gputest(x -> NNlib.relu.(x), x; atol) + gputest(x -> NNlib.relu6.(x), x; atol) + gputest(x -> NNlib.softplus.(x), x; atol) + gputest(x -> tanh.(x), x; atol) + gputest(x -> identity.(x), x; atol) + end end end diff --git a/test/ext_cuda/activations.jl b/test/ext_cuda/activations.jl index fb9d2ebfc..9d15b1fc5 100644 --- a/test/ext_cuda/activations.jl +++ b/test/ext_cuda/activations.jl @@ -41,3 +41,13 @@ end @test Array(y) == [tanh(1f0)] @test Array(x) == [tanh(tanh(1f0))] end + +@testset "fused act addition broadcast" begin + x = CUDA.rand(Float32, 10, 10) + b = CUDA.rand(Float32, 10) + + for act in getfield.((NNlib,), NNlib.ACTIVATIONS) + fused_act_add = act ∘ + + @test fused_act_add.(x, b) ≈ act.(x .+ b) + end +end diff --git a/test/ext_cuda/batchnorm.jl b/test/ext_cuda/batchnorm.jl index 0adea7024..17bce0f36 100644 --- a/test/ext_cuda/batchnorm.jl +++ b/test/ext_cuda/batchnorm.jl @@ -1,3 +1,5 @@ +using Statistics + @testset "Batchnorm" begin v = CUDA.rand(Float32, 2) m = CUDA.rand(Float32, 2, 5) @@ -24,4 +26,13 @@ @test_throws ArgumentError batchnorm(v, v, m, α, β, 1.0; kws...) 
end end + @testset "test mode" begin + y_no_track_stats = batchnorm(v, v, m, nothing, nothing, 1.0; training=false, track_stats=false) + running_mean = mean(m, dims=[2]) + running_var = var(m, mean=running_mean, dims=[2], corrected=false) + y_track_stats = batchnorm(v, v, m, running_mean, running_var, 1.0; training=false, track_stats=true) + # batchnorm without tracked stats should equal batchnorm with tracked stats where the + # stats are calculated only on the input. + @test y_no_track_stats ≈ y_track_stats + end end diff --git a/test/ext_cuda/conv.jl b/test/ext_cuda/conv.jl index 00ae228ba..cb34d6ab7 100644 --- a/test/ext_cuda/conv.jl +++ b/test/ext_cuda/conv.jl @@ -101,7 +101,7 @@ using NNlib: DenseConvDims @testset "scale-beta" begin gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, x, w, checkgrad=false, broken=false) gputest((w, x, y) -> act.(NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=T(2.0))), w, x, y, checkgrad=false, broken=false) - gputest((x, y, w) -> act.(NNlib.∇conv_data!(copy(x), y, w, cdims; beta=T(2.0))), x, y, w, checkgrad=false, broken=true) + gputest((x, y, w) -> act.(NNlib.∇conv_data!(copy(x), y, w, cdims; beta=T(2.0))), x, y, w, checkgrad=false, broken=false) if T <: Complex gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, real(x), w, checkgrad=false) diff --git a/test/ext_cuda/gather.jl b/test/ext_cuda/gather.jl index 36d42dbcc..9fa30efa8 100644 --- a/test/ext_cuda/gather.jl +++ b/test/ext_cuda/gather.jl @@ -89,4 +89,18 @@ @test y isa CuArray{Float32,3} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) gputest(src -> NNlib.gather(src, index), src, checkgrad=true) + + @testset "views" begin + x = cu(rand(2, 5)) + v = view(x, axes(x)...) + i = cu([1, 2]) + outx = NNlib.gather(x, i) + outv = NNlib.gather(v, i) + @test outx == outv + + # discontinuous view + v2 = view(x, :, [1,3,5]) + outv2 = NNlib.gather(v2, i) + @test collect(outv2) == NNlib.gather(collect(v2), collect(i)) + end end diff --git a/test/ext_cuda/sampling.jl b/test/ext_cuda/sampling.jl index 78844d120..8da22cfa7 100644 --- a/test/ext_cuda/sampling.jl +++ b/test/ext_cuda/sampling.jl @@ -51,3 +51,69 @@ end gputest(grid_sample, input, grid; atol=1e-6, padding_mode=padding_mode) end end + +@testset "Grid Sampling 3D" begin + for T in (Float32, Float64) + x = ones(T, (2, 2, 2, 1, 1)) # 3D input with depth=2 + grid = Array{T}(undef, 3, 2, 2, 2, 1) # 3D grid with depth=2 + grid[:, 1, 1, 1, 1] .= (-1, -1, -1) + grid[:, 2, 1, 1, 1] .= (1, -1, -1) + grid[:, 1, 2, 1, 1] .= (-1, 1, -1) + grid[:, 2, 2, 1, 1] .= (1, 1, -1) + grid[:, 1, 1, 2, 1] .= (-1, -1, 1) + grid[:, 2, 1, 2, 1] .= (1, -1, 1) + grid[:, 1, 2, 2, 1] .= (-1, 1, 1) + grid[:, 2, 2, 2, 1] .= (1, 1, 1) + + ∇grid_true = Array{T}(undef, size(grid)) + ∇grid_true[:, 1, 1, 1, 1] .= (0.0, 0.0, 0.0) + ∇grid_true[:, 2, 1, 1, 1] .= (-0.5, 0.0, 0.0) + ∇grid_true[:, 1, 2, 1, 1] .= (0.0, -0.5, 0.0) + ∇grid_true[:, 2, 2, 1, 1] .= (-0.5, -0.5, 0.0) + ∇grid_true[:, 1, 1, 2, 1] .= (0.0, 0.0, -0.5) + ∇grid_true[:, 2, 1, 2, 1] .= (-0.5, 0.0, -0.5) + ∇grid_true[:, 1, 2, 2, 1] .= (0.0, -0.5, -0.5) + ∇grid_true[:, 2, 2, 2, 1] .= (-0.5, -0.5, -0.5) + + + x_gpu, grid_gpu = CuArray(x), CuArray(grid) + + padding_mode = :zeros + y_gpu = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode) + @test x == collect(y_gpu) + @test eltype(y_gpu) == T + + external_grad = CUDA.ones(T, size(y_gpu)) + ∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode) + @test x == collect(∇input) + @test 
∇grid_true == collect(∇grid) + @test eltype(∇input) == T + @test eltype(∇grid) == T + + padding_mode = :border + fill!(∇grid_true, 0.0) + sampled = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode) + @test x == collect(sampled) + @test eltype(sampled) == T + + ∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode) + @test x == collect(∇input) + @test ∇grid_true == collect(∇grid) + @test eltype(∇input) == T + @test eltype(∇grid) == T + end +end + +@testset "Compare grid sampling with NNlib 3D" begin + w, h, d, c, n = 16, 16, 16, 2, 4 # Added depth dimension `d` + input = rand(Float64, w, h, d, c, n) + grid = zeros(Float64, 3, w, h, d, n) # 3D grid with depth `d` + @inbounds for xi in 1:w, yi in 1:h, zi in 1:d, ni in 1:n + grid[1, xi, yi, zi, ni] = (xi / w) * 2.0 - 1.0 + 0.01 + grid[2, xi, yi, zi, ni] = (yi / h) * 2.0 - 1.0 + grid[3, xi, yi, zi, ni] = (zi / d) * 2.0 - 1.0 + end + for padding_mode in (:zeros, :border) + gputest(grid_sample, input, grid; atol=1e-6, padding_mode=padding_mode) + end +end diff --git a/test/fold.jl b/test/fold.jl deleted file mode 100644 index 35296fb47..000000000 --- a/test/fold.jl +++ /dev/null @@ -1,40 +0,0 @@ -using NNlib, Test - -@testset "unfold wrapper" begin - x = rand(rng, 16, 16, 3, 10) - w = rand(rng, 5, 5, 3, 2) - @test size(NNlib.unfold(x, size(w))) == (144, 75, 10) - @test size(NNlib.unfold(x, size(w); pad=2)) == (256, 75, 10) - @test size(NNlib.unfold(x, size(w); stride=2)) == (36, 75, 10) - @test size(NNlib.unfold(x, size(w); dilation=2)) == (64, 75, 10) -end - -@testset "Inverses: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) - x = rand(rng, repeat([8], spatial_rank)..., 3, 2) - w = rand(rng, repeat([3], spatial_rank)..., 3, 3) - cdims = DenseConvDims(x, w; padding=1) - y = NNlib.unfold(x, cdims) - z = NNlib.fold(y, size(x), cdims) - divisor = NNlib.fold(NNlib.unfold(ones(eltype(x), size(x)...), cdims), size(x), cdims) - @test isapprox(z ./ divisor, x, rtol=1.0e-7) - - # introduce stride - cdims = DenseConvDims(x, w; padding=1, stride=2) - y = NNlib.unfold(x, cdims) - z = NNlib.fold(y, size(x), cdims) - divisor = NNlib.fold(NNlib.unfold(ones(eltype(x), size(x)...), cdims), size(x), cdims) - @test isapprox(z ./ divisor, x, rtol=1.0e-7) -end - -@testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) - x = rand(rng, repeat([5], spatial_rank)..., 3, 2) - w = rand(rng, repeat([3], spatial_rank)..., 3, 3) - cdims = DenseConvDims(x, w) - gradtest(x -> NNlib.unfold(x, cdims), x) - test_rrule(NNlib.unfold, x, cdims) - - y = NNlib.unfold(x, cdims) - gradtest(y -> NNlib.fold(y, size(x), cdims), y) - test_rrule(NNlib.fold, y, size(x), cdims) -end - diff --git a/test/inference.jl b/test/inference.jl index 31785597c..9b3e74db8 100644 --- a/test/inference.jl +++ b/test/inference.jl @@ -3,9 +3,6 @@ import NNlib: conv_direct, conv_im2col, channels_in, channels_out @testset "Conv Inference" begin for T in (Float32, Float64) impl = [conv, conv_direct, conv_im2col] - if NNlib.is_nnpack_available() && T == Float32 - push!(impl, NNlib.conv_nnpack) - end x = rand(T, 10, 10, 3, 2) w = rand(T, 3, 3, 3, 1) diff --git a/test/padding.jl b/test/padding.jl index 6550d62d8..a066d0547 100644 --- a/test/padding.jl +++ b/test/padding.jl @@ -1,65 +1,68 @@ using NNlib: pad_constant, pad_repeat, pad_zeros, pad_reflect, pad_symmetric, pad_circular @testset "padding constant" begin - x = rand(2, 2, 2) - + x = rand(2, 2, 2) + p = NNlib.gen_pad((1,2,3,4,5,6), (1,2,3), 4) @test p == ((1, 2), (3, 4), (5, 6), (0, 0)) - 
+ @test_throws ArgumentError NNlib.gen_pad((1,2,3,4,5,), (1,2,3), 4) - + p = NNlib.gen_pad((1,3), (1,3), 4) @test p == ((1, 1), (0, 0), (3, 3), (0, 0)) - + p = NNlib.gen_pad(1, (1,2,3), 4) @test p == ((1, 1), (1, 1), (1, 1), (0, 0)) - + p = NNlib.gen_pad(3, :, 2) @test p == ((3, 3), (3, 3)) - + + p = NNlib.gen_pad((1,0), 1, 2) + @test p == ((1,0), (0,0)) + y = pad_constant(x, (3, 2, 4)) @test size(y) == (8, 6, 10) @test y[4:5, 3:4, 5:6] ≈ x y[4:5, 3:4, 5:6] .= 0 @test all(y .== 0) - + @test pad_constant(x, (3, 2, 4)) ≈ pad_zeros(x, (3, 2, 4)) - @test pad_zeros(x, 2) ≈ pad_zeros(x, (2,2,2)) - + @test pad_zeros(x, 2) ≈ pad_zeros(x, (2,2,2)) + y = pad_constant(x, (3, 2, 4, 5), 1.2, dims = (1,3)) @test size(y) == (7, 2, 11) @test y[4:5, 1:2, 5:6] ≈ x y[4:5, 1:2, 5:6] .= 1.2 @test all(y .== 1.2) - + @test pad_constant(x, (2,2,2,2), 1.2, dims = (1,3)) ≈ pad_constant(x, 2, 1.2, dims = (1,3)) - + @test pad_constant(x, 1, dims = 1:2) == - pad_constant(x, 1, dims = (1,2)) - + pad_constant(x, 1, dims = (1,2)) + @test size(pad_constant(x, 1, dims = 1)) == (4,2,2) - + @test all(pad_zeros(randn(2), (1, 2))[[1, 4, 5]] .== 0) - + gradtest(x -> pad_constant(x, 2), rand(2,2,2)) gradtest(x -> pad_constant(x, (2, 1, 1, 2)), rand(2,2)) gradtest(x -> pad_constant(x, (2, 1,)), rand(2)) end @testset "padding repeat" begin - x = rand(2, 2, 2) - + x = rand(2, 2, 2) + # y = @inferred pad_repeat(x, (3, 2, 4, 5)) y = pad_repeat(x, (3, 2, 4, 5)) @test size(y) == (7, 11, 2) @test y[4:5, 5:6, :] ≈ x - + # y = @inferred pad_repeat(x, (3, 2, 4, 5), dims=(1,3)) y = pad_repeat(x, (3, 2, 4, 5), dims=(1,3)) @test size(y) == (7, 2, 11) @test y[4:5, :, 5:6] ≈ x - + @test pad_repeat(reshape(1:9, 3, 3), (1,2)) == [1 4 7 1 4 7 @@ -67,15 +70,15 @@ end 3 6 9 3 6 9 3 6 9] - + @test pad_repeat(reshape(1:9, 3, 3), (2,2), dims=2) == [1 1 1 4 7 7 7 2 2 2 5 8 8 8 3 3 3 6 9 9 9] - + @test pad_repeat(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_repeat(x, 2, dims=(1,3)) - + gradtest(x -> pad_repeat(x, (2,2,2,2)), rand(2,2,2)) end @@ -84,7 +87,7 @@ end @test y == [7 4 1 4 7 4 1 8 5 2 5 8 5 2 9 6 3 6 9 6 3] - + y = pad_reflect(reshape(1:9, 3, 3), (2,2,2,2)) @test y == [9 6 3 6 9 6 3 8 5 2 5 8 5 2 @@ -93,14 +96,26 @@ end 9 6 3 6 9 6 3 8 5 2 5 8 5 2 7 4 1 4 7 4 1] - - x = rand(4, 4, 4) + + x = rand(4, 4, 4) @test pad_reflect(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_reflect(x, 2, dims=(1,3)) - - # pad_reflect needs larger test input as padding must + + # pad_reflect needs larger test input as padding must # be strictly less than array size in that dimension gradtest(x -> pad_reflect(x, (2,2,2,2)), rand(3,3,3)) + + x = reshape(1:9, 3, 3, 1, 1) + @test NNlib.pad_reflect(x, (1, 0, 1, 0); dims=1:2) == [ + 5 2 5 8; + 4 1 4 7; + 5 2 5 8; + 6 3 6 9;;;;] + @test NNlib.pad_reflect(x, (0, 1, 0, 1); dims=1:2) == [ + 1 4 7 4; + 2 5 8 5; + 3 6 9 6; + 2 5 8 5;;;;] end @testset "padding symmetric" begin @@ -108,7 +123,7 @@ end @test y == [4 1 1 4 7 7 4 5 2 2 5 8 8 5 6 3 3 6 9 9 6] - + y = pad_symmetric(reshape(1:9, 3, 3), (2,2,2,2)) @test y == [5 2 2 5 8 8 5 4 1 1 4 7 7 4 @@ -117,12 +132,24 @@ end 6 3 3 6 9 9 6 6 3 3 6 9 9 6 5 2 2 5 8 8 5] - - x = rand(4, 4, 4) + + x = rand(4, 4, 4) @test pad_symmetric(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_symmetric(x, 2, dims=(1,3)) - + gradtest(x -> pad_symmetric(x, (2,2,2,2)), rand(2,2,2)) + + x = reshape(1:9, 3, 3, 1, 1) + @test NNlib.pad_symmetric(x, (1, 0, 1, 0); dims=1:2) == [ + 1 1 4 7; + 1 1 4 7; + 2 2 5 8; + 3 3 6 9;;;;] + @test NNlib.pad_symmetric(x, (0, 1, 0, 1); dims=1:2) == [ + 1 4 7 7; + 2 5 8 8; + 3 6 9 9; + 3 6 9 9;;;;] end 
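+# As the asymmetric-padding tests above show, pad_reflect mirrors about the boundary sample without repeating it, +# while pad_symmetric repeats the boundary sample itself.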
@testset "padding circular" begin @@ -130,7 +157,7 @@ end @test y == [4 7 1 4 7 1 4 5 8 2 5 8 2 5 6 9 3 6 9 3 6] - + y = pad_circular(reshape(1:9, 3, 3), (2,2,2,2)) @test y == [5 8 2 5 8 2 5 6 9 3 6 9 3 6 @@ -139,10 +166,10 @@ end 6 9 3 6 9 3 6 4 7 1 4 7 1 4 5 8 2 5 8 2 5] - - x = rand(4, 4, 4) + + x = rand(4, 4, 4) @test pad_circular(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_circular(x, 2, dims=(1,3)) - + gradtest(x -> pad_circular(x, (2,2,2,2)), rand(2,2,2)) end diff --git a/test/pooling.jl b/test/pooling.jl index b4b4f40b7..f9d57ade7 100644 --- a/test/pooling.jl +++ b/test/pooling.jl @@ -869,16 +869,6 @@ maxpool_answer_nature = Dict( @test y_maxpool_dir ≈ y_maxpool atol = 1e-6 @test isapprox(config.dx_maxpool, NNlib.∇maxpool_direct(dy, y_maxpool_dir, x, pdims), rtol=1e-5) @test isapprox(config.dx_meanpool, NNlib.∇meanpool_direct(dy, y_meanpool_dir, x, pdims), rtol=1e-5) - - # CHECK NNPACK - if NNlib.is_nnpack_available() && T == Float32 - if NNlib.nnpack_supported_operation(pdims) - y_maxpool_nnp = NNlib.maxpool_nnpack(x, pdims) - @test y_maxpool_nnp ≈ y_maxpool atol = 1e-6 - # NNPACK maxpool gradient still missing - # @test isapprox(config.dx_maxpool, NNlib.∇maxpool_nnpack(dy, y_maxpool_nnp, config.x, pdims), rtol=1e-5) - end - end end for (rank_name, config_dict) in maxpool_answer_nature @@ -940,16 +930,6 @@ maxpool_answer_nature = Dict( @test RD.gradient(_x -> only(maxpool(_x,(2,2))), x)[:,:,1,1] == [0 0; 0 1] @test only(meanpool(x, (2,2))) == 2.5 @test all(==(0.25), RD.gradient(_x -> only(meanpool(_x,(2,2))), x)) - - # if NNlib.is_nnpack_available() - # if NNlib.nnpack_supported_operation(pdims1) - # @test NNlib.maxpool_nnpack(x, pdims1) isa Array{Float32, 4} - # end - # if NNlib.nnpack_supported_operation(pdims2) - # print("you should not see this") - # @test NNlib.maxpool_nnpack(x, pdims2) isa Array{Float32, 4} - # end - # end end @testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2) @@ -967,3 +947,25 @@ end gradtest(x -> sum(maxpool(x, k)), x, skip = spatial_rank==2) gradtest(x -> sum(meanpool(x, k)), x) end + +@static if Test_Enzyme + +@testset "EnzymeRules: pooling! $pool spatial_rank=$spatial_rank " for spatial_rank in (1, 2), + (pool, pool!) 
in ((maxpool, maxpool!), (meanpool, meanpool!)) + + x = rand(rng, repeat([10], spatial_rank)..., 3, 2) + pdims = PoolDims(x, 2) + y = pool(x, pdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(pool!, Tret, (y, Tdst), (x, Tsrc), (pdims, EnzymeCore.Const)) + end + +end + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 03602a40d..b8080b6ba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,8 @@ using NNlib, Test, Statistics, Random using ChainRulesCore, ChainRulesTestUtils using Base.Broadcast: broadcasted +import EnzymeTestUtils +using EnzymeCore import FiniteDifferences import ForwardDiff import Zygote @@ -8,8 +10,15 @@ using Zygote: gradient using StableRNGs using Documenter using Adapt +using ImageTransformations +using Interpolations: Constant using KernelAbstractions +using FFTW import ReverseDiff as RD # used in `pooling.jl` +import Pkg +using SpecialFunctions + +const Test_Enzyme = VERSION <= v"1.10-" DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true) @@ -34,32 +43,53 @@ end cpu(x) = adapt(CPU(), x) -include("gather.jl") -include("scatter.jl") -include("upsample.jl") +include("testsuite/gather.jl") +include("testsuite/scatter.jl") +include("testsuite/upsample.jl") +include("testsuite/rotation.jl") +include("testsuite/spectral.jl") +include("testsuite/fold.jl") function nnlib_testsuite(Backend; skip_tests = Set{String}()) @conditional_testset "Upsample" skip_tests begin upsample_testsuite(Backend) end + @conditional_testset "rotation" skip_tests begin + rotation_testsuite(Backend) + end @conditional_testset "Gather" skip_tests begin gather_testsuite(Backend) end @conditional_testset "Scatter" skip_tests begin scatter_testsuite(Backend) end + @conditional_testset "Spectral" skip_tests begin + spectral_testsuite(Backend) + end + @conditional_testset "Fold" skip_tests begin + fold_testsuite(Backend) + end end @testset verbose=true "NNlib.jl" begin if get(ENV, "NNLIB_TEST_CPU", "true") == "true" - @testset "CPU" begin + @testset "CPU" begin @testset "Doctests" begin doctest(NNlib, manual=false) end nnlib_testsuite(CPU) + if Threads.nthreads(:default) > 1 + @test NNlib.should_use_spawn() + NNlib.@disallow_spawns begin + @test NNlib.should_use_spawn() == false + end + else + @test NNlib.should_use_spawn() == false + end + @testset "Activation Functions" begin include("activations.jl") include("bias_act.jl") @@ -86,10 +116,6 @@ end include("dropout.jl") end - @testset "Fold/Unfold" begin - include("fold.jl") - end - @testset "Inference" begin include("inference.jl") end @@ -123,6 +149,8 @@ end end if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" + Pkg.add(["CUDA", "cuDNN"]) + using CUDA if CUDA.functional() @testset "CUDA" begin @@ -135,19 +163,20 @@ end end else @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them" - end + end if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" + Pkg.add("AMDGPU") + using AMDGPU AMDGPU.versioninfo() if AMDGPU.functional() && AMDGPU.functional(:MIOpen) - @show AMDGPU.MIOpen.version() @testset "AMDGPU" begin nnlib_testsuite(ROCBackend) - AMDGPU.synchronize(; blocking=false) + AMDGPU.synchronize(; blocking=false, stop_hostcalls=true) include("ext_amdgpu/runtests.jl") - 
AMDGPU.synchronize(; blocking=false) + AMDGPU.synchronize(; blocking=false, stop_hostcalls=true) end else @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." diff --git a/test/sampling.jl b/test/sampling.jl index f560be568..bfd36bfaf 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -61,3 +61,88 @@ end y_true = ones(Float64, size(y)) @test y_true == y end + +@testset "Known gradients 3D" begin + x = ones(Float64, (2, 2, 2, 1, 1)) # 3D input with depth=2 + grid = Array{Float64}(undef, 3, 2, 2, 2, 1) # 3D grid with depth=2 + grid[:, 1, 1, 1, 1] .= (-1, -1, -1) + grid[:, 2, 1, 1, 1] .= (1, -1, -1) + grid[:, 1, 2, 1, 1] .= (-1, 1, -1) + grid[:, 2, 2, 1, 1] .= (1, 1, -1) + grid[:, 1, 1, 2, 1] .= (-1, -1, 1) + grid[:, 2, 1, 2, 1] .= (1, -1, 1) + grid[:, 1, 2, 2, 1] .= (-1, 1, 1) + grid[:, 2, 2, 2, 1] .= (1, 1, 1) + + ∇grid_true = Array{Float64}(undef, size(grid)) + ∇grid_true[:, 1, 1, 1, 1] .= (0.0, 0.0, 0.0) + ∇grid_true[:, 2, 1, 1, 1] .= (-0.5, 0.0, 0.0) + ∇grid_true[:, 1, 2, 1, 1] .= (0.0, -0.5, 0.0) + ∇grid_true[:, 2, 2, 1, 1] .= (-0.5, -0.5, 0.0) + ∇grid_true[:, 1, 1, 2, 1] .= (0.0, 0.0, -0.5) + ∇grid_true[:, 2, 1, 2, 1] .= (-0.5, 0.0, -0.5) + ∇grid_true[:, 1, 2, 2, 1] .= (0.0, -0.5, -0.5) + ∇grid_true[:, 2, 2, 2, 1] .= (-0.5, -0.5, -0.5) + + # ∇grid_true[:, :, :, 1, 1] = [ + # [[0.0, 0.0, 0.0], [-0.5, 0.0, 0.0]], + # [[0.0, -0.5, 0.0], [-0.5, -0.5, 0.0]] + # ] + # ∇grid_true[:, :, :, 2, 1] = [ + # [[0.0, 0.0, -0.5], [-0.5, 0.0, -0.5]] + # [[0.0, -0.5, -0.5], [-0.5, -0.5, -0.5]] + # ] + + padding_mode = :zeros + sampled = grid_sample(x, grid; padding_mode=padding_mode) + @test x == sampled + @test eltype(sampled) == Float64 + external_grad = ones(size(sampled)) + ∇input, ∇grid = ∇grid_sample(external_grad, x, grid; padding_mode=padding_mode) + @test ∇input == x + @test ∇grid == ∇grid_true + @test eltype(∇input) == Float64 + @test eltype(∇grid) == Float64 + + # ∇grid from FiniteDifferences is incorrect in case when 0-padding. + # gradtest(grid_sample, x, grid; fkwargs=(padding_mode=padding_mode,)) + + padding_mode = :border + fill!(∇grid_true, 0.0) + sampled = grid_sample(x, grid; padding_mode=padding_mode) + @test x == sampled + @test eltype(sampled) == Float64 + external_grad = ones(size(sampled)) + ∇input, ∇grid = ∇grid_sample(external_grad, x, grid; padding_mode=padding_mode) + @test ∇input == x + @test ∇grid == ∇grid_true + @test eltype(∇input) == Float64 + @test eltype(∇grid) == Float64 + + gradtest(grid_sample, x, grid; fkwargs=(padding_mode=padding_mode,)) +end + +@testset "Test out-of-bounds for different paddings 3D" begin + x = ones(Float64, (2, 2, 2, 1, 1)) # 3D input with depth=2 + grid = Array{Float64}(undef, 3, 2, 2, 2, 1) # 3D grid with depth=2 + grid[:, 1, 1, 1, 1] .= (-3, -1, -1) + grid[:, 2, 1, 1, 1] .= (0, -1, -1) + grid[:, 1, 2, 1, 1] .= (-1, 3, -1) + grid[:, 2, 2, 1, 1] .= (0, 1, -1) + grid[:, 1, 1, 2, 1] .= (-1, -1, 3) + grid[:, 2, 1, 2, 1] .= (0, -1, 3) + grid[:, 1, 2, 2, 1] .= (-1, 1, 3) + grid[:, 2, 2, 2, 1] .= (0, 1, 3) + + # With 0-padding, out-of-bound values will contribute nothing to + # the output values, because they are too far from any bound. + y = grid_sample(x, grid; padding_mode=:zeros) + y_true = reshape(Float64[[0, 1] [0, 1] [0, 0] [0, 0]], size(y)) + @test y_true == y + + # With border-padding, out-of-bound values simply become border values + # and the result should be all ones. 
+ y = grid_sample(x, grid; padding_mode=:border) + y_true = ones(Float64, size(y)) + @test y_true == y +end diff --git a/test/testsuite/fold.jl b/test/testsuite/fold.jl new file mode 100644 index 000000000..294768bbe --- /dev/null +++ b/test/testsuite/fold.jl @@ -0,0 +1,48 @@ +import NNlib + +function fold_testsuite(Backend) + device(x) = adapt(Backend(), x) + gradtest_fn = Backend == CPU ? gradtest : gputest + + @testset "unfold wrapper" begin + x = device(rand(rng, 16, 16, 3, 10)) + w = device(rand(rng, 5, 5, 3, 2)) + @test size(NNlib.unfold(x, size(w))) == (144, 75, 10) + @test size(NNlib.unfold(x, size(w); pad=2)) == (256, 75, 10) + @test size(NNlib.unfold(x, size(w); stride=2)) == (36, 75, 10) + @test size(NNlib.unfold(x, size(w); dilation=2)) == (64, 75, 10) + end + + @testset "Inverses: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = device(rand(rng, repeat([8], spatial_rank)..., 3, 2)) + w = device(rand(rng, repeat([3], spatial_rank)..., 3, 3)) + + cdims = DenseConvDims(x, w; padding=1) + y = NNlib.unfold(x, cdims) + z = NNlib.fold(y, size(x), cdims) + + o = device(ones(eltype(x), size(x)...)) + divisor = NNlib.fold(NNlib.unfold(o, cdims), size(x), cdims) + @test isapprox(z ./ divisor, x, rtol=1.0e-7) + + # introduce stride + cdims = DenseConvDims(x, w; padding=1, stride=2) + y = NNlib.unfold(x, cdims) + z = NNlib.fold(y, size(x), cdims) + divisor = NNlib.fold(NNlib.unfold(o, cdims), size(x), cdims) + @test isapprox(z ./ divisor, x, rtol=1.0e-7) + end + + @testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = device(rand(rng, repeat([5], spatial_rank)..., 3, 2)) + w = device(rand(rng, repeat([3], spatial_rank)..., 3, 3)) + cdims = DenseConvDims(x, w) + + gradtest_fn(x -> NNlib.unfold(x, cdims), x) + Backend == CPU && test_rrule(NNlib.unfold, x, cdims) + + y = NNlib.unfold(x, cdims) + gradtest_fn(y -> NNlib.fold(y, size(x), cdims), y) + Backend == CPU && test_rrule(NNlib.fold, y, size(x), cdims) + end +end diff --git a/test/gather.jl b/test/testsuite/gather.jl similarity index 87% rename from test/gather.jl rename to test/testsuite/gather.jl index e3221145b..92e3bfb7d 100644 --- a/test/gather.jl +++ b/test/testsuite/gather.jl @@ -1,4 +1,6 @@ using NNlib: gather, gather! +import EnzymeTestUtils +using EnzymeCore function gather_testsuite(Backend) device(x) = adapt(Backend(), x) @@ -152,6 +154,27 @@ function gather_testsuite(Backend) gradtest_fn((s, i) -> gather(s, i), src, idx) end + @static if Test_Enzyme + + @testset "EnzymeRules: gather! gradient for scalar index" begin + src = device(Float64[3, 4, 5, 6, 7]) + idx = device([ + 1 2 3 4; + 4 2 1 3; + 3 5 5 3]) + dst = gather(src, idx) + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(gather!, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) + end + end + + end + @testset "gather gradient for tuple index" begin src = device(Float64[ 3 5 7 diff --git a/test/testsuite/rotation.jl b/test/testsuite/rotation.jl new file mode 100644 index 000000000..ae079791b --- /dev/null +++ b/test/testsuite/rotation.jl @@ -0,0 +1,80 @@ +function rotation_testsuite(Backend) + device(x) = adapt(Backend(), x) + gradtest_fn = Backend == CPU ? gradtest : gputest + T = Float64 + atol = T == Float32 ? 1e-3 : 1e-6 + rtol = T == Float32 ? 
1f-3 : 1f-6 + angles = deg2rad.([0, 0.0001, 35, 90, -90, -90.0123, 170, 180, 270, 360, 450, 1234.1234]) + + @testset "imrotate" begin + @testset "Simple test" begin + arr = device(zeros((6, 6, 1, 1))); + arr[3:4, 4, 1, 1] .= 1; + @test all(cpu(NNlib.imrotate(arr, deg2rad(45))) .≈ [0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.29289321881345254 0.585786437626905 0.0; 0.0 0.0 0.08578643762690495 1.0 0.2928932188134524 0.0; 0.0 0.0 0.0 0.08578643762690495 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0]) + end + + + @testset "Compare with ImageTransformations" begin + for sz in [(51,51,1,1), (52,52,1,1), (51,52,1,1), (52,51,1,1)] + rotation_center = (sz[1:2] .+ 1) ./ 2 + arr1 = device(zeros(T, sz)) + arr1[15:40, 15:40, :, :] .= device(1 .+ randn((26, 26))) + arr2 = device(zeros(T, (sz[1], sz[2], sz[3], 3))) + arr2[15:40, 15:40, :, :] .= device(arr1[15:40, 15:40, :, :]) + + for method in [:nearest, :bilinear] + @testset "$method" begin + for angle in angles + res1 = cpu(NNlib.imrotate(arr1, angle; method, rotation_center=rotation_center)) + res2 = cpu(NNlib.imrotate(arr2, angle; method, rotation_center=rotation_center)) + if method == :nearest + res_IT = ImageTransformations.imrotate(cpu(arr1)[:, :, 1, 1], angle, axes(arr1)[1:2], method=Constant(), fillvalue=0) + elseif method == :bilinear + res_IT = ImageTransformations.imrotate(cpu(arr1)[:, :, 1, 1], angle, axes(arr1)[1:2], fillvalue=0) + end + if method == :nearest + @test ≈(1 .+ res1[:, :, :, :], 1 .+ res_IT[:, :], rtol=0.5) + @test ≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 1], rtol=0.5) + @test ≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 2], rtol=0.5) + else + @test all(.≈(1 .+ res1[:, :, :, :], 1 .+ res_IT[:, :], rtol=rtol)) + @test all(.≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 1], rtol=rtol)) + @test all(.≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 2], rtol=rtol)) + end + end + end + end + end + end + + @testset "Compare for plausibilty" begin + @testset "Special cases of rotation" begin + arr = device(zeros(T, (10, 10, 1, 3))) + arr[6, 6, :, 1] .= 1 + arr[6, 6, :, 2] .= 2 + arr[6, 6, :, 3] .= 3 + + for method in [:bilinear, :nearest] + @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(0); method))) + @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(90); method))) + @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(180); method))) + @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(270); method))) + @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(360); method))) + end + end + end + + @testset "Test gradients" begin + for method in [:nearest, :bilinear] + for angle in angles + gradtest_fn( + x -> NNlib.imrotate(x, angle; method), + device(rand(T, 11,11,1,1)); atol) + gradtest_fn( + x -> NNlib.imrotate(x, angle; method), + device(rand(T, 10,10,1,1)); atol) + end + end + end + end +end diff --git a/test/scatter.jl b/test/testsuite/scatter.jl similarity index 73% rename from test/scatter.jl rename to test/testsuite/scatter.jl index 26fc06cde..aa0b1c41e 100644 --- a/test/scatter.jl +++ b/test/testsuite/scatter.jl @@ -69,37 +69,36 @@ res = Dict( ) function test_scatter(device, types, ops; pt, ops_skip_types) - for T in types + for T in types, IT in (Int8, Int64) PT = promote_type(T, pt) - @testset "$T" begin - for op in ops - skip_types = get(ops_skip_types, op, []) - @testset "$op" begin - for idx = values(idxs), dims = [0, 1] - idx = device(idx) - dst = device(dsts[dims]) - - mutated = true - target_y = res[(op, dims, mutated)] - src = device(srcs[(dims, mutated)]) - if op == / - src = src .* T(2) - end - - @test cpu(scatter!(op, T.(dst), T.(src), 
idx)) == T.(target_y) - @test cpu(scatter!(op, T.(dst), src, idx)) == PT.(target_y) - if op == / - @test cpu(scatter!(op, T.(dst), T.(src), idx)) == PT.(target_y) - else - @test cpu(scatter!(op, copy(dst), T.(src), idx)) == PT.(target_y) - end - - if T ∉ skip_types - mutated = false - src = device(srcs[(dims, mutated)]) - @test cpu(scatter(op, T.(src), idx)) == T.(res[(op, dims, mutated)]) - end - end + @testset "eltype $T - idx eltype $IT - $op" for op in ops + skip_types = get(ops_skip_types, op, []) + for idx = values(idxs), dims = [0, 1] + # Tests with indices of different types. + eltype(idx) == Int && (idx = IT.(idx);) + + idx = device(idx) + dst = device(dsts[dims]) + + mutated = true + target_y = res[(op, dims, mutated)] + src = device(srcs[(dims, mutated)]) + if op == / + src = src .* T(2) + end + + @test cpu(scatter!(op, T.(dst), T.(src), idx)) == T.(target_y) + @test cpu(scatter!(op, T.(dst), src, idx)) == PT.(target_y) + if op == / + @test cpu(scatter!(op, T.(dst), T.(src), idx)) == PT.(target_y) + else + @test cpu(scatter!(op, copy(dst), T.(src), idx)) == PT.(target_y) + end + + if T ∉ skip_types + mutated = false + src = device(srcs[(dims, mutated)]) + @test cpu(scatter(op, T.(src), idx)) == T.(res[(op, dims, mutated)]) end end end @@ -174,14 +173,14 @@ function scatter_testsuite(Backend) else (+, -, mean, max, min) end - for op in ops, i in (0, 1) + for op in ops, i in (0, 1), IT in (Int8, Int64) PT = ( # If not CPU and CUDA -> use Int64 for min/max. Backend != CPU && Symbol(Backend) != :CUDABackend && (op == max || op == min)) ? Int64 : T src = device(srcs[(i, true)]) - idx = device(idxs[:int]) + idx = device(IT.(idxs[:int])) dst = device(PT.(dsts[i])) Backend == CPU ? gradtest_fn(x -> scatter!(op, copy(x), src, idx), dst; fdm=fdm(op)) : @@ -195,17 +194,41 @@ function scatter_testsuite(Backend) else (+, -, mean, max, min) end - for op in ops, i in (0, 1) + for op in ops, i in (0, 1), IT in (Int8, Int64) PT = ( # If not CPU and CUDA -> use Int64 for min/max. Backend != CPU && Symbol(Backend) != :CUDABackend && (op == max || op == min)) ? Int64 : T src = PT.(device(srcs[(i, false)])) - idx = device(idxs[:int]) + idx = device(IT.(idxs[:int])) Backend == CPU ? gradtest_fn(xs -> scatter(op, xs, idx), src; fdm=fdm(op)) : gradtest_fn((xs, i) -> scatter(op, xs, i), src, idx) end end + + + @static if Test_Enzyme + + @testset "EnzymeRules" begin + idx = device([2, 2, 3, 4, 4]) + src = device(ones(T, 3, 5)) + + for op in (+, -) + + dst = scatter(op, src, idx) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(scatter!, Tret, (op, EnzymeCore.Const), (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) + end + end + end + + end end end diff --git a/test/testsuite/spectral.jl b/test/testsuite/spectral.jl new file mode 100644 index 000000000..a78a79159 --- /dev/null +++ b/test/testsuite/spectral.jl @@ -0,0 +1,161 @@ +function spectral_testsuite(Backend) + cpu(x) = adapt(CPU(), x) + device(x) = adapt(Backend(), x) + gradtest_fn = Backend == CPU ? 
gradtest : gputest + + @testset "Window functions" begin + for window_fn in (hann_window, hamming_window) + @inferred window_fn(10, Float32) + @inferred window_fn(10, Float64) + + w = window_fn(10) + @test length(w) == 10 + @test eltype(w) == Float32 + + wp = window_fn(10; periodic=false) + @test wp[1:5] ≈ reverse(wp[6:10]) + + @test window_fn(10; periodic=true) ≈ window_fn(10 + 1; periodic=false)[1:10] + end + end + + @testset "STFT" for batch in ((), (3,)) + @testset "Grads" begin + if Backend != CPU + x = rand(Float32, 16, batch...) + window = hann_window(16) + + gradtest_fn(s -> abs.(stft(s; n_fft=16)), x) + gradtest_fn((s, w) -> abs.(stft(s; n_fft=16, window=w)), x, window) + + x = rand(Float32, 2045, batch...) + n_fft = 256 + window = hann_window(n_fft) + gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w)), x, window) + gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w, center=false)), x, window) + gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w, normalized=true)), x, window) + end + end + + @testset "Batch $batch" begin + x = device(ones(Float32, 16, batch...)) + # TODO fix type stability for pad_reflect + # @inferred stft(x; n_fft=16) + + bd = ntuple(_ -> Colon(), length(batch)) + + y = stft(x; n_fft=16) + @test size(y) == (9, 5, batch...) + @test all(real(cpu(y))[1, :, bd...] .≈ 16) + + xx = istft(y; n_fft=16) + @test size(xx) == (16, batch...) + @test cpu(x) ≈ cpu(xx) + + # Test multiple hops. + x = device(rand(Float32, 2048, batch...)) + y = stft(x; n_fft=1024) + xx = istft(y; n_fft=1024) + @test cpu(x) ≈ cpu(xx) + + # Test odd sizes. + x = device(rand(Float32, 1111, batch...)) + y = stft(x; n_fft=256) + xx = istft(y; n_fft=256, original_length=size(x, 1)) + @test cpu(x) ≈ cpu(xx) + + # Output from inverse is cropped on the right + # without knowing the original size. + xx = istft(y; n_fft=256) + @test length(xx) < length(x) + @test cpu(x)[[1:s for s in size(xx)]...] ≈ cpu(xx) + + # Test different options. + + # Normalized. + x = device(rand(Float32, 1234, batch...)) + y = stft(x; n_fft=512, normalized=true) + xx = istft(y; n_fft=512, normalized=true, original_length=size(x, 1)) + @test cpu(x) ≈ cpu(xx) + + # With window. + window = device(hann_window(512)) + y = stft(x; n_fft=512, window) + xx = istft(y; n_fft=512, window, original_length=size(x, 1)) + @test cpu(x) ≈ cpu(xx) + + # Hop. + for hop_length in (32, 33, 255, 256, 511, 512) + y = stft(x; n_fft=512, hop_length) + xx = istft(y; n_fft=512, hop_length, original_length=size(x, 1)) + @test cpu(x) ≈ cpu(xx) + end + + # N FFT. + for n_fft in (32, 33, 64, 65, 128, 129, 512) + y = stft(x; n_fft) + xx = istft(y; n_fft, original_length=size(x, 1)) + @test cpu(x) ≈ cpu(xx) + end + end + end + + @testset "Spectrogram" begin + x = device(rand(Float32, 1024)) + window = device(hann_window(1024)) + + y = stft(x; + n_fft=1024, hop_length=128, window, + center=true, normalized=false) + spec = spectrogram(x; + n_fft=1024, hop_length=128, window, + center=true, normalized=false) + @test abs.(y).^2 ≈ spec + + # Gradient with `0`s in spectrogram. + # We add small ϵ to spectrogram before computing power + # to prevent `NaN` in gradient due to `abs(0)`. + x = device(ones(Float32, 1024)) + g = Zygote.gradient(x) do x + sum(spectrogram(x; + n_fft=1024, hop_length=128, window, + center=true, normalized=false)) + end + @test !any(isnan.(g[1])) + + # Batched. 
+ x = device(rand(Float32, 1024, 3)) + spec = spectrogram(x; + n_fft=1024, hop_length=128, window, + center=true, normalized=false) + for i in 1:3 + y = stft(x[:, i]; + n_fft=1024, hop_length=128, window, + center=true, normalized=false) + @test abs.(y).^2 ≈ spec[:, :, i] + end + + if Backend != CPU + @testset "Grads" begin + for batch in ((), (3,)) + x = rand(Float32, 2045, batch...) + n_fft = 256 + window = hann_window(n_fft) + gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w), x, window) + gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w, center=false), x, window) + gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w, normalized=true), x, window) + end + end + end + end + + @testset "Power to dB" begin + x = device(rand(Float32, 1024)) + window = device(hann_window(1024)) + spec = spectrogram(x; pad=0, n_fft=1024, hop_length=128, window) + + @test spec ≈ NNlib.db_to_power(NNlib.power_to_db(spec)) + @inferred NNlib.power_to_db(spec) + @inferred NNlib.db_to_power(NNlib.power_to_db(spec)) + end +end diff --git a/test/upsample.jl b/test/testsuite/upsample.jl similarity index 100% rename from test/upsample.jl rename to test/testsuite/upsample.jl