@@ -62,8 +62,7 @@ coordinate_axis(::Type{<:LatLongPoint}) = (1, 2)
6262
6363coordinate_axis (coord:: AbstractPoint ) = coordinate_axis (typeof (coord))
6464
65- @inline idxin (I:: NTuple{N, Int} , i:: Int ) where {N} =
66- unrolled_findfirst (isequal (i), I)
65+ @inline idxin (I:: NTuple{N, Int} , i:: Int ) where {N} = unrolled_findfirst (isequal (i), I)
6766
6867struct PropertyError <: Exception
6968 ax:: Any
8584# most of these are required for printing
8685Base. length (ax:: AbstractAxis{I} ) where {I} = length (I)
8786Base. unitrange (ax:: AbstractAxis ) = StaticArrays. SOneTo (length (ax))
88- Base. LinearIndices (axes:: Tuple{Vararg{AbstractAxis}} ) =
89- 1 : prod (map (length, axes))
87+ Base. LinearIndices (axes:: Tuple{Vararg{AbstractAxis}} ) = 1 : prod (map (length, axes))
9088Base. checkindex (:: Type{Bool} , ax:: AbstractAxis , i) =
9189 Base. checkindex (Bool, Base. unitrange (ax), i)
9290Base. lastindex (ax:: AbstractAxis ) = length (ax)
@@ -109,21 +107,16 @@ operations such as multiplication.
109107[`components`](@ref) to obtain the underlying array.
110108"""
111109struct AxisTensor{
112- T,
113- N,
114- A <: NTuple{N, AbstractAxis} ,
115- S <: StaticArray{<:Tuple, T, N} ,
110+ T, N, A <: NTuple{N, AbstractAxis} , S <: StaticArray{<:Tuple, T, N} ,
116111} <: AbstractArray{T, N}
117112 axes:: A
118113 components:: S
119114end
120115
121116AxisTensor (
122- axes:: A ,
123- components:: S ,
117+ axes:: A , components:: S ,
124118) where {
125- A <: Tuple{Vararg{AbstractAxis}} ,
126- S <: StaticArray{<:Tuple, T, N} ,
119+ A <: Tuple{Vararg{AbstractAxis}} , S <: StaticArray{<:Tuple, T, N} ,
127120} where {T, N} = AxisTensor {T, N, A, S} (axes, components)
128121
129122AxisTensor (axes:: Tuple{Vararg{AbstractAxis}} , components) =
@@ -154,8 +147,7 @@ Base.zeros(::Type{AxisTensor{T, N, A, S}}) where {T, N, A, S} =
154147
155148function Base. show (io:: IO , a:: AxisTensor{T, N, A, S} ) where {T, N, A, S}
156149 print (
157- io,
158- " AxisTensor{$T , $N , $A , $S }($(getfield (a, :axes )) , $(getfield (a, :components )) )" ,
150+ io, " AxisTensor{$T , $N , $A , $S }($(getfield (a, :axes )) , $(getfield (a, :components )) )" ,
159151 )
160152end
161153
@@ -180,28 +172,20 @@ Base.@propagate_inbounds Base.getindex(v::AxisTensor, i::Int...) =
180172
181173
182174Base. @propagate_inbounds function Base. getindex (
183- v:: AxisTensor{<:Any, 2, Tuple{A1, A2}} ,
184- :: Colon ,
185- i:: Integer ,
175+ v:: AxisTensor{<:Any, 2, Tuple{A1, A2}} , :: Colon , i:: Integer ,
186176) where {A1, A2}
187177 AxisVector (axes (v, 1 ), getindex (components (v), :, i))
188178end
189179Base. @propagate_inbounds function Base. getindex (
190- v:: AxisTensor{<:Any, 2, Tuple{A1, A2}} ,
191- i:: Integer ,
192- :: Colon ,
180+ v:: AxisTensor{<:Any, 2, Tuple{A1, A2}} , i:: Integer , :: Colon ,
193181) where {A1, A2}
194182 AxisVector (axes (v, 2 ), getindex (components (v), i, :))
195183end
196184
197185
198186Base. map (f:: F , a:: AxisTensor ) where {F} =
199187 AxisTensor (axes (a), map (f, components (a)))
200- Base. map (
201- f:: F ,
202- a:: AxisTensor{Ta, N, A} ,
203- b:: AxisTensor{Tb, N, A} ,
204- ) where {F, Ta, Tb, N, A} =
188+ Base. map (f:: F , a:: AxisTensor{Ta, N, A} , b:: AxisTensor{Tb, N, A} ) where {F, Ta, Tb, N, A} =
205189 AxisTensor (axes (a), map (f, components (a), components (b)))
206190# Base.map(f, a::AxisTensor{Ta,N}, b::AxisTensor{Tb,N}) where {Ta,Tb,N} =
207191# map(f, promote(a,b)...)
@@ -239,11 +223,8 @@ AxisVector(ax::A1, v::SVector{N, T}) where {A1 <: AbstractAxis, N, T} =
239223 AxisVector (A. instance, SVector (arg1))
240224(AxisVector{T, A, SVector{2 , T}} where {T})(arg1:: Real , arg2:: Real ) where {A} =
241225 AxisVector (A. instance, SVector (arg1, arg2))
242- (AxisVector{T, A, SVector{3 , T}} where {T})(
243- arg1:: Real ,
244- arg2:: Real ,
245- arg3:: Real ,
246- ) where {A} = AxisVector (A. instance, SVector (arg1, arg2, arg3))
226+ (AxisVector{T, A, SVector{3 , T}} where {T})(arg1:: Real , arg2:: Real , arg3:: Real ) where {A} =
227+ AxisVector (A. instance, SVector (arg1, arg2, arg3))
247228
248229const CovariantVector{T, I, S} = AxisVector{T, CovariantAxis{I}, S}
249230const ContravariantVector{T, I, S} = AxisVector{T, ContravariantAxis{I}, S}
@@ -284,8 +265,7 @@ Base.zero(::Type{AdjointAxisTensor{T, N, A, S}}) where {T, N, A, S} =
284265
285266const AdjointAxisVector{T, A1, S} = Adjoint{T, AxisVector{T, A1, S}}
286267
287- const AxisVectorOrAdj{T, A, S} =
288- Union{AxisVector{T, A, S}, AdjointAxisVector{T, A, S}}
268+ const AxisVectorOrAdj{T, A, S} = Union{AxisVector{T, A, S}, AdjointAxisVector{T, A, S}}
289269
290270Base. @propagate_inbounds Base. getindex (va:: AdjointAxisVector , i:: Int ) =
291271 getindex (components (va), i)
@@ -294,35 +274,28 @@ Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int, j::Int) =
294274
295275# 2-tensors
296276const Axis2Tensor{T, A, S} = AxisTensor{T, 2 , A, S}
297- Axis2Tensor (
298- axes:: Tuple{AbstractAxis, AbstractAxis} ,
299- components:: AbstractMatrix ,
300- ) = AxisTensor (axes, components)
277+ Axis2Tensor (axes:: Tuple{AbstractAxis, AbstractAxis} , components:: AbstractMatrix ) =
278+ AxisTensor (axes, components)
301279
302280const AdjointAxis2Tensor{T, A, S} = Adjoint{T, Axis2Tensor{T, A, S}}
303281
304- const Axis2TensorOrAdj{T, A, S} =
305- Union{Axis2Tensor{T, A, S}, AdjointAxis2Tensor{T, A, S}}
282+ const Axis2TensorOrAdj{T, A, S} = Union{Axis2Tensor{T, A, S}, AdjointAxis2Tensor{T, A, S}}
306283
307284@inline + (
308285 a:: Axis2Tensor{Ta, Tuple{A1, A2}, Sa} ,
309286 b:: Adjoint{Tb, Axis2Tensor{Tb, Tuple{A2, A2}, Sb}} ,
310- ) where {Ta, Tb, A1, A2, Sa, Sb} =
311- AxisTensor (a. axes, components (a) + components (b))
287+ ) where {Ta, Tb, A1, A2, Sa, Sb} = AxisTensor (a. axes, components (a) + components (b))
312288
313289@inline + (
314290 a:: Adjoint{Ta, Axis2Tensor{Ta, Tuple{A1, A2}, Sa}} ,
315291 b:: Axis2Tensor{Tb, Tuple{A2, A2}, Sb} ,
316- ) where {Ta, Tb, A1, A2, Sa, Sb} =
317- AxisTensor (b. axes, components (a) + components (b))
292+ ) where {Ta, Tb, A1, A2, Sa, Sb} = AxisTensor (b. axes, components (a) + components (b))
318293
319294# based on 1st dimension
320295const Covariant2Tensor{T, A, S} =
321296 Axis2Tensor{T, A, S} where {T, A <: Tuple{CovariantAxis, AbstractAxis} , S}
322297const Contravariant2Tensor{T, A, S} = Axis2Tensor{
323- T,
324- A,
325- S,
298+ T, A, S,
326299} where {T, A <: Tuple{ContravariantAxis, AbstractAxis} , S}
327300const Cartesian2Tensor{T, A, S} =
328301 Axis2Tensor{T, A, S} where {T, A <: Tuple{CartesianAxis, AbstractAxis} , S}
@@ -363,8 +336,7 @@ check_axes(ax1, ax2) = throw(DimensionMismatch("$ax1 and $ax2 do not match"))
363336
364337check_dual (ax1, ax2) = _check_dual (ax1, ax2, dual (ax2))
365338_check_dual (:: A , _, :: A ) where {A} = nothing
366- _check_dual (ax1, ax2, _) =
367- throw (DimensionMismatch (" $ax1 is not dual with $ax2 " ))
339+ _check_dual (ax1, ax2, _) = throw (DimensionMismatch (" $ax1 is not dual with $ax2 " ))
368340
369341
370342function LinearAlgebra. dot (x:: AxisVector , y:: AxisVector )
@@ -432,15 +404,13 @@ end
432404
433405
434406function _transform (
435- ato:: Ato ,
436- x:: AxisVector{T, Afrom, SVector{N, T}} ,
407+ ato:: Ato , x:: AxisVector{T, Afrom, SVector{N, T}} ,
437408) where {Ato <: AbstractAxis{I} , Afrom <: AbstractAxis{I} } where {I, T, N}
438409 x
439410end
440411
441412function _project (
442- ato:: Ato ,
443- x:: AxisVector{T, Afrom, SVector{N, T}} ,
413+ ato:: Ato , x:: AxisVector{T, Afrom, SVector{N, T}} ,
444414) where {Ato <: AbstractAxis{I} , Afrom <: AbstractAxis{I} } where {I, T, N}
445415 x
446416end
@@ -501,51 +471,34 @@ end
501471end
502472
503473function _transform (
504- ato:: Ato ,
505- x:: Axis2Tensor{T, Tuple{Afrom, A2}} ,
474+ ato:: Ato , x:: Axis2Tensor{T, Tuple{Afrom, A2}} ,
506475) where {
507- Ato <: AbstractAxis{I} ,
508- Afrom <: AbstractAxis{I} ,
509- A2 <: AbstractAxis{J} ,
476+ Ato <: AbstractAxis{I} , Afrom <: AbstractAxis{I} , A2 <: AbstractAxis{J} ,
510477} where {I, J, T}
511478 x
512479end
513480
514481function _project (
515- ato:: Ato ,
516- x:: Axis2Tensor{T, Tuple{Afrom, A2}} ,
482+ ato:: Ato , x:: Axis2Tensor{T, Tuple{Afrom, A2}} ,
517483) where {
518- Ato <: AbstractAxis{I} ,
519- Afrom <: AbstractAxis{I} ,
520- A2 <: AbstractAxis{J} ,
484+ Ato <: AbstractAxis{I} , Afrom <: AbstractAxis{I} , A2 <: AbstractAxis{J} ,
521485} where {I, J, T}
522486 x
523487end
524488
525489@inline transform (ato:: CovariantAxis , v:: CovariantTensor ) = project (ato, v)
526- @inline transform (ato:: ContravariantAxis , v:: ContravariantTensor ) =
527- project (ato, v)
490+ @inline transform (ato:: ContravariantAxis , v:: ContravariantTensor ) = project (ato, v)
528491@inline transform (ato:: CartesianAxis , v:: CartesianTensor ) = project (ato, v)
529492@inline transform (ato:: LocalAxis , v:: LocalTensor ) = project (ato, v)
530493
531- @inline function project (
532- ato_l:: AbstractAxis ,
533- v:: Axis2Tensor ,
534- ato_r:: AbstractAxis ,
535- )
494+ @inline function project (ato_l:: AbstractAxis , v:: Axis2Tensor , ato_r:: AbstractAxis )
536495 @assert symbols .(axes (v)) == symbols .((ato_l, ato_r)) " Axes do not match"
537496 T = eltype (v)
538497 Ifrom_l, Ifrom_r = axis_indices .(axes (v))
539498 product_to = unrolled_product (axis_indices (ato_l), axis_indices (ato_r))
540499 vals = unrolled_map (product_to) do (m_l, m_r)
541- n_l = unrolled_findfirst (
542- ((n_l, ifrom),) -> m_l == ifrom,
543- enumerate (Ifrom_l),
544- )
545- n_r = unrolled_findfirst (
546- ((n_r, ifrom),) -> m_r == ifrom,
547- enumerate (Ifrom_r),
548- )
500+ n_l = unrolled_findfirst (((n_l, ifrom),) -> m_l == ifrom, enumerate (Ifrom_l))
501+ n_r = unrolled_findfirst (((n_r, ifrom),) -> m_r == ifrom, enumerate (Ifrom_r))
549502 isnothing (n_l) || isnothing (n_r) ? zero (T) : v[n_l, n_r]
550503 end
551504 S = SMatrix {length(ato_l), length(ato_r)} (vals... )
0 commit comments