1- using StaticArrays, LinearAlgebra
1+ using StaticArrays, LinearAlgebra, UnrolledUtilities
22
33"""
44 AbstractAxis
@@ -8,6 +8,7 @@ An axis of a [`AxisTensor`](@ref).
88abstract type AbstractAxis{I} end
99
1010Base. Broadcast. broadcastable (a:: AbstractAxis ) = a
11+ axis_indices (:: AbstractAxis{I} ) where {I} = I
1112
1213"""
1314 dual(ax::AbstractAxis)
@@ -521,101 +522,41 @@ function _project(
521522 x
522523end
523524
524- #= Set `assert_exact_transform() = true` for debugging=#
525- assert_exact_transform () = false
526-
527- @generated function _transform (
528- ato:: Ato ,
529- x:: Axis2Tensor{T, Tuple{Afrom, A2}} ,
530- ) where {
531- Ato <: AbstractAxis{Ito} ,
532- Afrom <: AbstractAxis{Ifrom} ,
533- A2 <: AbstractAxis{J} ,
534- } where {Ito, Ifrom, J, T}
535- N = length (Ifrom)
536- M = length (J)
537- if assert_exact_transform ()
538- errcond = false
539- for n in 1 : N
540- i = Ifrom[n]
541- if i ∉ Ito
542- for m in 1 : M
543- errcond = :($ errcond || x[$ n, $ m] != zero (T))
544- end
545- end
546- end
547- end
548- vals = []
549- for m in 1 : M
550- for i in Ito
551- val = :(zero (T))
552- for n in 1 : N
553- if i == Ifrom[n]
554- val = :(x[$ n, $ m])
555- break
556- end
557- end
558- push! (vals, val)
559- end
560- end
561- quote
562- Base. @_propagate_inbounds_meta
563- if assert_exact_transform ()
564- if $ errcond
565- throw (InexactError (:transform , Ato, x))
566- end
567- end
568- @inbounds Axis2Tensor (
569- (ato, axes (x, 2 )),
570- SMatrix {$(length(Ito)), $M} ($ (vals... )),
525+ @inline transform (ato:: CovariantAxis , v:: CovariantTensor ) = project (ato, v)
526+ @inline transform (ato:: ContravariantAxis , v:: ContravariantTensor ) =
527+ project (ato, v)
528+ @inline transform (ato:: CartesianAxis , v:: CartesianTensor ) = project (ato, v)
529+ @inline transform (ato:: LocalAxis , v:: LocalTensor ) = project (ato, v)
530+
531+ @inline function project (
532+ ato_l:: AbstractAxis ,
533+ v:: Axis2Tensor ,
534+ ato_r:: AbstractAxis ,
535+ )
536+ @assert symbols .(axes (v)) == symbols .((ato_l, ato_r)) " Axes do not match"
537+ T = eltype (v)
538+ Ifrom_l, Ifrom_r = axis_indices .(axes (v))
539+ product_to = unrolled_product (axis_indices (ato_l), axis_indices (ato_r))
540+ vals = unrolled_map (product_to) do (m_l, m_r)
541+ n_l = unrolled_findfirst (
542+ ((n_l, ifrom),) -> m_l == ifrom,
543+ enumerate (Ifrom_l),
571544 )
572- end
573- end
574-
575- @generated function _project (
576- ato:: Ato ,
577- x:: Axis2Tensor{T, Tuple{Afrom, A2}} ,
578- ) where {
579- Ato <: AbstractAxis{Ito} ,
580- Afrom <: AbstractAxis{Ifrom} ,
581- A2 <: AbstractAxis{J} ,
582- } where {Ito, Ifrom, J, T}
583- N = length (Ifrom)
584- M = length (J)
585- vals = []
586- for m in 1 : M
587- for i in Ito
588- val = :(zero (T))
589- for n in 1 : N
590- if i == Ifrom[n]
591- val = :(x[$ n, $ m])
592- break
593- end
594- end
595- push! (vals, val)
596- end
597- end
598- quote
599- Base. @_propagate_inbounds_meta
600- @inbounds Axis2Tensor (
601- (ato, axes (x, 2 )),
602- SMatrix {$(length(Ito)), $M} ($ (vals... )),
545+ n_r = unrolled_findfirst (
546+ ((n_r, ifrom),) -> m_r == ifrom,
547+ enumerate (Ifrom_r),
603548 )
549+ isnothing (n_l) || isnothing (n_r) ? zero (T) : v[n_l, n_r]
604550 end
551+ S = SMatrix {length(ato_l), length(ato_r)} (vals... )
552+ @inbounds Axis2Tensor ((ato_l, ato_r), S)
553+ end
554+ @inline project (ato:: AbstractAxis , v:: Axis2Tensor ) = project (ato, v, axes (v, 2 ))
555+ @inline project (v:: Axis2Tensor , ato:: AbstractAxis ) = project (axes (v, 1 ), v, ato)
556+ @inline function project (ato_l, v:: AxisVector )
557+ @assert symbols (axes (v, 1 )) == symbols (ato_l)
558+ _project (ato_l, v)
605559end
606-
607- @inline transform (ato:: CovariantAxis , v:: CovariantTensor ) = _project (ato, v)
608- @inline transform (ato:: ContravariantAxis , v:: ContravariantTensor ) =
609- _project (ato, v)
610- @inline transform (ato:: CartesianAxis , v:: CartesianTensor ) = _project (ato, v)
611- @inline transform (ato:: LocalAxis , v:: LocalTensor ) = _project (ato, v)
612-
613- @inline project (ato:: CovariantAxis , v:: CovariantTensor ) = _project (ato, v)
614- @inline project (ato:: ContravariantAxis , v:: ContravariantTensor ) =
615- _project (ato, v)
616- @inline project (ato:: CartesianAxis , v:: CartesianTensor ) = _project (ato, v)
617- @inline project (ato:: LocalAxis , v:: LocalTensor ) = _project (ato, v)
618-
619560
620561"""
621562 outer(x, y)
0 commit comments