Skip to content

Commit a9a03d8

Browse files
authored
Merge pull request #2378 from CliMA/he/ft-add-left-and-right-projection
ft: add (left and right) projection of 2-tensors
2 parents 8f9ee78 + 6051afa commit a9a03d8

File tree

4 files changed

+59
-94
lines changed

4 files changed

+59
-94
lines changed

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ ClimaCore.jl Release Notes
44
main
55
-------
66

7+
v0.14.41
8+
-------
9+
- Add `project` for both sides of 2-tensors (previously only left side was supported).
10+
Simplify implementation. [2379](https://github.com/Clima/ClimaCore.jl/pull/2379)
11+
712
v0.14.40
813
-------
914
- Store `reverse_mode` in the `IntervalMesh` struct, so that we can access it when writing

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClimaCore"
22
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
33
authors = ["CliMA Contributors <[email protected]>"]
4-
version = "0.14.40"
4+
version = "0.14.41"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/Geometry/axistensors.jl

Lines changed: 33 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using StaticArrays, LinearAlgebra
1+
using StaticArrays, LinearAlgebra, UnrolledUtilities
22

33
"""
44
AbstractAxis
@@ -8,6 +8,7 @@ An axis of a [`AxisTensor`](@ref).
88
abstract type AbstractAxis{I} end
99

1010
Base.Broadcast.broadcastable(a::AbstractAxis) = a
11+
axis_indices(::AbstractAxis{I}) where {I} = I
1112

1213
"""
1314
dual(ax::AbstractAxis)
@@ -521,101 +522,41 @@ function _project(
521522
x
522523
end
523524

524-
#= Set `assert_exact_transform() = true` for debugging=#
525-
assert_exact_transform() = false
526-
527-
@generated function _transform(
528-
ato::Ato,
529-
x::Axis2Tensor{T, Tuple{Afrom, A2}},
530-
) where {
531-
Ato <: AbstractAxis{Ito},
532-
Afrom <: AbstractAxis{Ifrom},
533-
A2 <: AbstractAxis{J},
534-
} where {Ito, Ifrom, J, T}
535-
N = length(Ifrom)
536-
M = length(J)
537-
if assert_exact_transform()
538-
errcond = false
539-
for n in 1:N
540-
i = Ifrom[n]
541-
if i Ito
542-
for m in 1:M
543-
errcond = :($errcond || x[$n, $m] != zero(T))
544-
end
545-
end
546-
end
547-
end
548-
vals = []
549-
for m in 1:M
550-
for i in Ito
551-
val = :(zero(T))
552-
for n in 1:N
553-
if i == Ifrom[n]
554-
val = :(x[$n, $m])
555-
break
556-
end
557-
end
558-
push!(vals, val)
559-
end
560-
end
561-
quote
562-
Base.@_propagate_inbounds_meta
563-
if assert_exact_transform()
564-
if $errcond
565-
throw(InexactError(:transform, Ato, x))
566-
end
567-
end
568-
@inbounds Axis2Tensor(
569-
(ato, axes(x, 2)),
570-
SMatrix{$(length(Ito)), $M}($(vals...)),
525+
@inline transform(ato::CovariantAxis, v::CovariantTensor) = project(ato, v)
526+
@inline transform(ato::ContravariantAxis, v::ContravariantTensor) =
527+
project(ato, v)
528+
@inline transform(ato::CartesianAxis, v::CartesianTensor) = project(ato, v)
529+
@inline transform(ato::LocalAxis, v::LocalTensor) = project(ato, v)
530+
531+
@inline function project(
532+
ato_l::AbstractAxis,
533+
v::Axis2Tensor,
534+
ato_r::AbstractAxis,
535+
)
536+
@assert symbols.(axes(v)) == symbols.((ato_l, ato_r)) "Axes do not match"
537+
T = eltype(v)
538+
Ifrom_l, Ifrom_r = axis_indices.(axes(v))
539+
product_to = unrolled_product(axis_indices(ato_l), axis_indices(ato_r))
540+
vals = unrolled_map(product_to) do (m_l, m_r)
541+
n_l = unrolled_findfirst(
542+
((n_l, ifrom),) -> m_l == ifrom,
543+
enumerate(Ifrom_l),
571544
)
572-
end
573-
end
574-
575-
@generated function _project(
576-
ato::Ato,
577-
x::Axis2Tensor{T, Tuple{Afrom, A2}},
578-
) where {
579-
Ato <: AbstractAxis{Ito},
580-
Afrom <: AbstractAxis{Ifrom},
581-
A2 <: AbstractAxis{J},
582-
} where {Ito, Ifrom, J, T}
583-
N = length(Ifrom)
584-
M = length(J)
585-
vals = []
586-
for m in 1:M
587-
for i in Ito
588-
val = :(zero(T))
589-
for n in 1:N
590-
if i == Ifrom[n]
591-
val = :(x[$n, $m])
592-
break
593-
end
594-
end
595-
push!(vals, val)
596-
end
597-
end
598-
quote
599-
Base.@_propagate_inbounds_meta
600-
@inbounds Axis2Tensor(
601-
(ato, axes(x, 2)),
602-
SMatrix{$(length(Ito)), $M}($(vals...)),
545+
n_r = unrolled_findfirst(
546+
((n_r, ifrom),) -> m_r == ifrom,
547+
enumerate(Ifrom_r),
603548
)
549+
isnothing(n_l) || isnothing(n_r) ? zero(T) : v[n_l, n_r]
604550
end
551+
S = SMatrix{length(ato_l), length(ato_r)}(vals...)
552+
@inbounds Axis2Tensor((ato_l, ato_r), S)
553+
end
554+
@inline project(ato::AbstractAxis, v::Axis2Tensor) = project(ato, v, axes(v, 2))
555+
@inline project(v::Axis2Tensor, ato::AbstractAxis) = project(axes(v, 1), v, ato)
556+
@inline function project(ato_l, v::AxisVector)
557+
@assert symbols(axes(v, 1)) == symbols(ato_l)
558+
_project(ato_l, v)
605559
end
606-
607-
@inline transform(ato::CovariantAxis, v::CovariantTensor) = _project(ato, v)
608-
@inline transform(ato::ContravariantAxis, v::ContravariantTensor) =
609-
_project(ato, v)
610-
@inline transform(ato::CartesianAxis, v::CartesianTensor) = _project(ato, v)
611-
@inline transform(ato::LocalAxis, v::LocalTensor) = _project(ato, v)
612-
613-
@inline project(ato::CovariantAxis, v::CovariantTensor) = _project(ato, v)
614-
@inline project(ato::ContravariantAxis, v::ContravariantTensor) =
615-
_project(ato, v)
616-
@inline project(ato::CartesianAxis, v::CartesianTensor) = _project(ato, v)
617-
@inline project(ato::LocalAxis, v::LocalTensor) = _project(ato, v)
618-
619560

620561
"""
621562
outer(x, y)

test/Geometry/axistensors.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using Test, JET
22
using ClimaCore.Geometry, ClimaCore.DataLayouts
33
using LinearAlgebra, StaticArrays
44
import ClimaCore
5-
ClimaCore.Geometry.assert_exact_transform() = true
65

76
@testset "AxisTensors" begin
87
x = Geometry.Covariant12Vector(1.0, 2.0)
@@ -169,6 +168,26 @@ end
169168
Geometry.Covariant12Axis(),
170169
Geometry.Covariant13Vector(2.0, 2.0) * Geometry.Cartesian1Vector(1.0)',
171170
) == Geometry.Covariant12Vector(2.0, 0.0) * Geometry.Cartesian1Vector(1.0)'
171+
172+
# Test projection over rightmost axis
173+
x_C12 = Geometry.Covariant12Vector(2.0, 2.0)
174+
x_Cart123 = Geometry.Cartesian123Vector(1.0, 1.0, 1.0)
175+
@test Geometry.project(x_C12 * x_Cart123', Geometry.Cartesian3Axis()) ==
176+
x_C12 * Geometry.Cartesian3Vector(1.0)'
177+
@test Geometry.project(x_C12 * x_Cart123', Geometry.Cartesian23Axis()) ==
178+
x_C12 * Geometry.Cartesian23Vector(1.0, 1.0)'
179+
180+
# Test projection over both axes
181+
@test Geometry.project(
182+
Geometry.Covariant12Axis(),
183+
x_C12 * x_Cart123',
184+
Geometry.Cartesian123Axis(),
185+
) == x_C12 * x_Cart123'
186+
@test Geometry.project(
187+
Geometry.Covariant2Axis(),
188+
x_C12 * x_Cart123',
189+
Geometry.Cartesian13Axis(),
190+
) == Geometry.Covariant2Vector(2.0) * Geometry.Cartesian13Vector(1.0, 1.0)'
172191
end
173192

174193

0 commit comments

Comments
 (0)