
Zygote gradients of nested ArrayPartitions lose their shape #480

@lukas-weber

Describe the bug 🐞

Calling Zygote.gradient on a function of a nested ArrayPartition returns an ArrayPartition gradient whose inner ArrayPartition has been flattened to a Vector.

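Possibly this is caused by, or could be fixed by, a ProjectTo method for ArrayPartition? ProjectTo is defined in ChainRulesCore, and Zygote uses it to project gradients back onto the type of the input.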

Expected behavior

The gradient should match the type and shape of the input. If the input is an ArrayPartition nested inside an ArrayPartition, the gradient should have the same nested structure, i.e. typeof(g) == typeof(x).

Minimal Reproducible Example 👇

using Zygote
using RecursiveArrayTools
using LinearAlgebra

x = ArrayPartition(ArrayPartition(rand(3,4), rand(3,4)), rand(2))
g = Zygote.gradient(norm, x)[1]

typeof(x)
# ArrayPartition{Float64, Tuple{ArrayPartition{Float64, Tuple{Matrix{Float64}, Matrix{Float64}}}, Vector{Float64}}}
typeof(g)
# ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}
# The inner ArrayPartition (two 3×4 matrices) has been flattened to a single Vector.
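Until this is fixed upstream, one possible user-side workaround is to rebuild the gradient's nesting from the input's structure. The helpers restructure_like and _rebuild are hypothetical (they are not part of RecursiveArrayTools or Zygote), and the sketch assumes that the flattened gradient lists its elements in the same linear order as the input. The code only relies on the ArrayPartition constructor, linear indexing, and the .x field that holds the partitions.

```julia
using RecursiveArrayTools

# Hypothetical workaround helper, NOT part of RecursiveArrayTools or Zygote.
# Rebuilds `g` with the same (possibly nested) partition structure as `template`.
# ASSUMPTION: `g` stores its elements in the same linear order as `template`.
function restructure_like(template::ArrayPartition, g::AbstractVector)
    length(g) == length(template) ||
        throw(DimensionMismatch("gradient and template have different lengths"))
    flat = [g[i] for i in 1:length(g)]  # plain Vector, read in linear order
    return _rebuild(template, flat)
end

# Split `flat` into consecutive slices, one per partition, and recurse.
function _rebuild(template::ArrayPartition, flat::AbstractVector)
    lens = [length(part) for part in template.x]
    stops = cumsum(lens)
    starts = stops .- lens .+ 1
    parts = ntuple(i -> _rebuild(template.x[i], view(flat, starts[i]:stops[i])),
                   length(template.x))
    return ArrayPartition(parts...)
end

# Leaf case: copy the slice into a fresh array with the template's shape.
_rebuild(template::AbstractArray, flat::AbstractVector) =
    reshape(collect(flat), size(template))
```

Applied to the reproducer, restructure_like(x, g) should then have the same type as x, provided the linear-order assumption holds for the gradient Zygote returns. The helper copies the data, so it is a stopgap rather than a replacement for a proper fix in the projection step.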

Environment:

  • Output of using Pkg; Pkg.status()
Status `/tmp/jl_Ci688N/Project.toml`
  [731186ca] RecursiveArrayTools v3.37.1
  [e88e6eb3] Zygote v0.7.10
  • Output of using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
Status `/tmp/jl_Ci688N/Manifest.toml`
  [621f4979] AbstractFFTs v1.5.0
  [7d9f7c33] Accessors v0.1.42
  [79e6a3ab] Adapt v4.3.0
  [4fba245c] ArrayInterface v7.19.0
  [082447d4] ChainRules v1.72.5
  [d360d2e6] ChainRulesCore v1.26.0
  [bbf7d656] CommonSubexpressions v0.3.1
  [34da2185] Compat v4.18.0
  [a33af91c] CompositionsBase v0.1.2
  [187b0558] ConstructionBase v1.6.0
  [9a962f9c] DataAPI v1.16.0
  [e2d170a0] DataValueInterfaces v1.0.0
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
  [ffbed154] DocStringExtensions v0.9.5
  [e2ba6199] ExprTools v0.1.10
  [1a297f60] FillArrays v1.13.0
  [f6369f11] ForwardDiff v1.0.1
  [46192b85] GPUArraysCore v0.2.0
  [7869d1d1] IRTools v0.4.15
  [3587e190] InverseFunctions v0.1.17
  [92d709cd] IrrationalConstants v0.2.4
  [82899510] IteratorInterfaceExtensions v1.0.0
  [692b3bcd] JLLWrappers v1.7.1
  [2ab3a3ac] LogExpFunctions v0.3.29
  [1914dd2f] MacroTools v0.5.16
  [77ba4419] NaNMath v1.1.3
  [bac558e1] OrderedCollections v1.8.1
⌅ [aea7be01] PrecompileTools v1.2.1
  [21216c6a] Preferences v1.5.0
  [c1ae055f] RealDot v0.1.0
  [3cdcf5f2] RecipesBase v1.3.4
  [731186ca] RecursiveArrayTools v3.37.1
  [ae029012] Requires v1.3.1
  [7e49a35a] RuntimeGeneratedFunctions v0.5.15
  [dc90abb0] SparseInverseSubset v0.1.2
  [276daf66] SpecialFunctions v2.5.1
  [1e83bf80] StaticArraysCore v1.4.3
  [10745b16] Statistics v1.11.1
  [09ab397b] StructArrays v0.7.1
  [2efcf032] SymbolicIndexingInterface v0.3.43
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.12.1
  [e88e6eb3] Zygote v0.7.10
  [700de1a5] ZygoteRules v0.2.7
  [efe28fd5] OpenSpecFun_jll v0.5.6+0
  [56f22d72] Artifacts v1.11.0
  [2a0f44e3] Base64 v1.11.0
  [ade2ca70] Dates v1.11.0
  [8ba89e20] Distributed v1.11.0
  [b77e0a4c] InteractiveUtils v1.11.0
  [8f399da3] Libdl v1.11.0
  [37e2e46d] LinearAlgebra v1.11.0
  [d6f4376e] Markdown v1.11.0
  [de0858da] Printf v1.11.0
  [9a3f8284] Random v1.11.0
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization v1.11.0
  [6462fe0b] Sockets v1.11.0
  [2f01184e] SparseArrays v1.11.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [cf7118a7] UUIDs v1.11.0
  [4ec0a83e] Unicode v1.11.0
  [e66e0078] CompilerSupportLibraries_jll v1.1.1+0
  [4536629a] OpenBLAS_jll v0.3.27+1
  [05823500] OpenLibm_jll v0.8.5+0
  [bea87d4a] SuiteSparse_jll v7.7.0+0
  [8e850b90] libblastrampoline_jll v5.11.0+0
  • Output of versioninfo()
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × Intel(R) Xeon(R) Gold 6244 CPU @ 3.60GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, cascadelake)
Threads: 1 default, 0 interactive, 1 GC (on 32 virtual cores)
Environment:
  LD_LIBRARY_PATH = /mnt/sw/nix/store/vxi6jxb3xfpnaq3mpix2vsa6dhr443fk-openblas-0.3.26/lib:/mnt/home/lweber/app/cpython/prefix/lib:/mnt/home/lweber/app/cpython/prefix/lib:/mnt/home/lweber/app/cpython/prefix/lib:/mnt/home/lweber/app/cpython/prefix/lib:/mnt/home/lweber/app/cpython/prefix/lib:/mnt/home/lweber/app/cpython/prefix/lib:/mnt/home/lweber/app/cpython/prefix/lib
  JULIA_DEPOT_PATH = /home/lweber/.julia:

Additional context

This caused a downstream problem in JuliaManifolds/ManifoldDiff.jl#79.
