Description
I suggest a mechanism to customize splatting behavior such that code like
+((x .- y).^2...) / length(x) # compute MSE
can be executed efficiently without any allocations.
The idea is inspired by the discussion in #29114 by @c42f et al.
Idea
As with dot-call syntax, I suggest lowering splatting to a series of overloadable function calls. One possibility is:
# f(a...) is expanded to:
apply(f, Arguments(VA(splattable(a))))
# f(a, b..., c, d...) is expanded to:
apply(f, Arguments(a, VA(splattable(b)), c, VA(splattable(d))))
(To avoid recursion, this lowering should happen only when the call includes splatting.)
When dot-call syntax appears in the splatted operand, I suggest not materializing the dot-call. For example, op(f.(xs)...)
is lowered to
bc = broadcasted(f, xs) # `bc` not materialized
apply(op, Arguments(VA(splattable(bc))))
This lets us evaluate op(f.(xs)...) without any allocation once reduce/foldl supports Broadcasted objects (#31020 started tackling this).
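For the MSE example from the description, a rough hand-written approximation of what this lowering would allow is to keep the dot-call as a Broadcasted object and reduce over it directly. The function name mse_lazy is hypothetical. The sketch relies only on Base.Broadcast.broadcasted, instantiate and reduce, and whether the reduction itself is allocation-free depends on how well the installed Julia version supports reducing Broadcasted objects.

```julia
using Base.Broadcast: broadcasted, instantiate

# Hypothetical helper: evaluates `+((x .- y).^2...) / length(x)` the way the
# proposed lowering would, i.e. without materializing the dot-call.
function mse_lazy(x, y)
    # Lazy `(x .- y).^2`: a Broadcasted object, no result array is built.
    bc = broadcasted(^, broadcasted(-, x, y), 2)
    # `instantiate` fixes the axes so the object can be iterated;
    # `reduce(+, ...)` then consumes the elements one at a time.
    return reduce(+, instantiate(bc)) / length(x)
end

x = [1.0, 2.0, 3.0]
y = [1.5, 2.0, 2.0]
mse_lazy(x, y) ≈ sum((x .- y) .^ 2) / length(x)   # true
```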
Interface
The lowering above requires the following interface functions and types:
function apply end
splattable(x) = x
struct VA{T}
    args::T
end
struct Arguments{T <: Tuple}
    args::T
    Arguments(args...) = new{typeof(args)}(args)
end
- apply must be dispatched on the first argument type and may be dispatched on the second argument type.
- Per-vararg processing should be done via splattable. This is analogous to broadcastable (@yurivish suggested this in a comment on #29114, "Performance of splatting a number").
- The type VA (whose name can/should be improved) must be used only for defining apply; its constructor must not be overloaded. This makes it hard to break splatting semantics.
- The constructor for type Arguments must not be overloaded, for the same reason.
Using the current Core._apply, the default apply can be implemented as:
apply(f, x) = Core._apply(Core._apply, (f,), _default_splattables(x.args))
_default_splattables(args) = map(x -> x isa VA ? materialize(x.args) : (x,), args)
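Until lowering changes, the interface can be exercised by writing the lowered calls out by hand. The sketch below restates the interface and, as an assumption for illustration, implements the default apply with ordinary splatting over Iterators.flatten instead of Core._apply. It shows the intended semantics, not an allocation-free implementation.

```julia
using Base.Broadcast: broadcasted, materialize

# Interface from the proposal, restated so this snippet runs on its own.
splattable(x) = x
struct VA{T}
    args::T
end
struct Arguments{T <: Tuple}
    args::T
    Arguments(args...) = new{typeof(args)}(args)
end

# Turn each argument into an iterable: splatted ones are materialized
# (a no-op unless they are lazy Broadcasted objects), plain ones become 1-tuples.
_default_splattables(args) = map(x -> x isa VA ? materialize(x.args) : (x,), args)

# Default `apply`: concatenate all pieces and splat them into `f`.
# (Ordinary splatting is used here instead of `Core._apply`.)
apply(f, x::Arguments) = f(Iterators.flatten(_default_splattables(x.args))...)

# Hand-lowered `+(1, [2, 3]..., 4, (5, 6)...)`:
apply(+, Arguments(1, VA(splattable([2, 3])), 4, VA(splattable((5, 6)))))  # 21
# Hand-lowered `max(abs.([-3, 1, 2])...)`, with the dot-call left lazy:
apply(max, Arguments(VA(splattable(broadcasted(abs, [-3, 1, 2])))))       # 3
```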
Example overloads
Associative binary operators
Many useful operations can be expressed by splatting into associative operators:
+(xs...) == sum(xs)
*(xs...) == prod(xs)
min(xs...) == minimum(xs)
max(xs...) == maximum(xs)
or, with "mapped splatting",
+(xs.^2...) == norm(xs)^2
+(xs .* ys...) == dot(xs, ys)
(This is reminiscent of the "big operator" in Fortress.)
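These identities can be checked against current Base (the left-hand sides already work today; they simply go through generic varargs splatting). Note that the sum of squares equals the squared norm, and only up to floating-point rounding:

```julia
using LinearAlgebra: norm, dot

xs = [1.0, 2.0, 3.0]
ys = [4.0, 5.0, 6.0]

@assert +(xs...) == sum(xs)
@assert *(xs...) == prod(xs)
@assert min(xs...) == minimum(xs)
@assert max(xs...) == maximum(xs)
# The sum of squares is the *squared* 2-norm; compare up to rounding.
@assert +((xs .^ 2)...) ≈ norm(xs)^2
@assert +((xs .* ys)...) == dot(xs, ys)
```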
There are also various other associative binary operators in Base. Invoking reduce when they are called with splatting could be useful:
const AssociativeOperator = Union{
    typeof(*),
    typeof(+),
    typeof(&),
    typeof(|),
    typeof(min),
    typeof(max),
    typeof(intersect),
    typeof(union),
    typeof(vcat),
    typeof(hcat),
    typeof(merge),
# what else?
}
apply(op::AssociativeOperator, args::Arguments{<:Tuple{VA}}) =
    reduce(op, args.args[1].args)
For example, vcat(vectors...) would then concatenate vectors efficiently, thanks to #27188.
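A minimal sketch of this overload with the lowered calls written out by hand, assuming a reduced AssociativeOperator union and the interface types restated locally. The parameter is written Arguments{<:Tuple{VA}} because Arguments is invariant in its type parameter: Arguments{Tuple{<:VA}} would only match an Arguments whose parameter is exactly the UnionAll type Tuple{<:VA}, so a concrete Arguments{Tuple{VA{Vector{Int}}}} would not dispatch to it.

```julia
# Minimal stand-ins for the proposed interface types.
struct VA{T}; args::T; end
struct Arguments{T <: Tuple}
    args::T
    Arguments(args...) = new{typeof(args)}(args)
end

# A subset of the proposed AssociativeOperator union.
const AssociativeOperator = Union{typeof(+), typeof(*), typeof(min), typeof(max), typeof(vcat)}

# `op(xs...)` with a single splatted argument becomes `reduce(op, xs)`.
# `<:Tuple{VA}` is needed because `Arguments` is invariant in its parameter.
apply(op::AssociativeOperator, args::Arguments{<:Tuple{VA}}) =
    reduce(op, args.args[1].args)

vectors = [[1, 2], [3], [4, 5, 6]]
apply(vcat, Arguments(VA(vectors)))   # hand-lowered `vcat(vectors...)`; [1, 2, 3, 4, 5, 6]
apply(max, Arguments(VA(3:7)))        # hand-lowered `max((3:7)...)`; 7
```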
It is also possible to support, for op::AssociativeOperator,
op(a, bs..., cs...)
such that it is computed as
op(op(a, reduce(op, bs)), reduce(op, cs))
This may be implemented as
apply(op::AssociativeOperator, args::Arguments) = mapreduce(x -> _apply1(op, x), op, args.args)
_apply1(op::AssociativeOperator, args::VA) = reduce(op, args.args)
_apply1(op, x) = x
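A self-contained sketch of the mixed form, passing _apply1 to mapreduce through a closure that captures op. The interface types and a smaller operator union are restated so the snippet runs on its own:

```julia
struct VA{T}; args::T; end
struct Arguments{T <: Tuple}
    args::T
    Arguments(args...) = new{typeof(args)}(args)
end
const AssociativeOperator = Union{typeof(+), typeof(*), typeof(min), typeof(max)}

# Each splatted group is reduced on its own; plain arguments pass through.
_apply1(op::AssociativeOperator, x::VA) = reduce(op, x.args)
_apply1(op, x) = x

# Combine the per-argument results with `op` once more.
apply(op::AssociativeOperator, args::Arguments) =
    mapreduce(x -> _apply1(op, x), op, args.args)

bs = [2, 3]
cs = (4, 5, 6)
apply(+, Arguments(1, VA(bs), VA(cs)))   # hand-lowered `+(1, bs..., cs...)`; 21
```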
Note that since reduce falls back to foldl when the input is not an array (or, after #31020, not a Broadcasted object), we can also use it to fuse filtering with reduction:
+((x for x in xs if x > 0)...)
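Under the overload above, the filtered call would become a reduce over a generator. The snippet shows that reduction directly, next to today's splatting form for comparison:

```julia
xs = [3, -1, 4, -1, 5]

# What `+((x for x in xs if x > 0)...)` would reduce to under the associative
# overload: `reduce` over a generator proceeds like `foldl`, so the filter is
# fused into the reduction and no filtered array is built.
total = reduce(+, (x for x in xs if x > 0))   # 12

# The same result through today's splatting, which collects the generator first.
total == +((x for x in xs if x > 0)...)       # true
```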
Non-associative binary functions
Splatting is useful for non-associative binary functions:
const BinaryFunction = Union{
    typeof(/),
    typeof(-),
    typeof(intersect!),
    typeof(union!),
    typeof(merge!),
    typeof(append!),
    typeof(push!),
# what else?
}
function apply(op::BinaryFunction, args::Arguments{<:Tuple{Any, VA}})
    @assert !(args.args[1] isa VA)
    return foldl(op, args.args[2].args; init=args.args[1])
end
or more generally
function apply(op::BinaryFunction, args::Arguments)
    @assert !(args.args[1] isa VA)
    return foldl(op, Iterators.flatten(x isa VA ? x.args : (x,) for x in args.args[2:end]); init=args.args[1])
end
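A runnable sketch of the general form, with a smaller BinaryFunction union and the lowered calls written out by hand (the interface types are restated so the snippet stands alone):

```julia
struct VA{T}; args::T; end
struct Arguments{T <: Tuple}
    args::T
    Arguments(args...) = new{typeof(args)}(args)
end
const BinaryFunction = Union{typeof(/), typeof(-), typeof(push!), typeof(append!)}

# Left-fold `op` over all remaining arguments (splatted ones flattened),
# starting from the first, non-splatted argument.
function apply(op::BinaryFunction, args::Arguments)
    @assert !(args.args[1] isa VA)
    rest = Iterators.flatten(x isa VA ? x.args : (x,) for x in args.args[2:end])
    return foldl(op, rest; init=args.args[1])
end

# Hand-lowered `-(100, xs...)`, i.e. ((100 - 1) - 2) - 3:
xs = [1, 2, 3]
apply(-, Arguments(100, VA(xs)))            # 94
# Hand-lowered `push!(v, 4, (5, 6)...)`:
v = [1, 2, 3]
apply(push!, Arguments(v, 4, VA((5, 6))))   # v is now [1, 2, 3, 4, 5, 6]
```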
Matrix-vector multiplications
Not sure how many people need this, but *(matrices..., vector) can be evaluated (somewhat) efficiently, as a chain of matrix-vector products, by defining
apply(::typeof(*), args::Arguments{<:Tuple{VA, AbstractVector}}) =
    foldr(*, args.args[1].args; init=args.args[2])
(Of course, allocations could be reduced much further if we really wanted this.)
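A quick check of the right-to-left fold, again with the lowered call written by hand. Each intermediate product is a vector, so no matrix-matrix product is formed:

```julia
struct VA{T}; args::T; end
struct Arguments{T <: Tuple}
    args::T
    Arguments(args...) = new{typeof(args)}(args)
end

# `*(matrices..., vector)`: fold from the right, starting with the vector,
# so the result is A1 * (A2 * (... * v)).
apply(::typeof(*), args::Arguments{<:Tuple{VA, AbstractVector}}) =
    foldr(*, args.args[1].args; init=args.args[2])

A = [1 2; 3 4]
B = [0 1; 1 0]
v = [1, 1]
apply(*, Arguments(VA([A, B]), v))   # [3, 7], same as A * B * v
```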
A similar optimization can be done for ∘(fs...), but I'm not sure about the exact use case.
Higher-order functions (map(f, iters...), etc.)
Since this mechanism lets any function optimize splatting, higher-order functions that splat their arguments into a given function can be optimized by defining their own apply specialization. For example, map(f, iters...) can be specialized as
function apply(::typeof(map), args::Arguments{<:Tuple{Any, VA}})
    f = args.args[1]
    iters = args.args[2].args
    return map(splat(f), _zipsplat(iters))
end
where _zipsplat(iters) behaves like zip(iters...) but its element type does not have to be a Tuple.
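A runnable sketch of this specialization. Because _zipsplat is only described, not defined, it is approximated here by zip (so its elements are still tuples), and splat(f) is written as an explicit closure; both are assumptions for illustration:

```julia
struct VA{T}; args::T; end
struct Arguments{T <: Tuple}
    args::T
    Arguments(args...) = new{typeof(args)}(args)
end

# Stand-in for the hypothetical `_zipsplat`: plain `zip`, whose elements are tuples.
_zipsplat(iters) = zip(iters...)

function apply(::typeof(map), args::Arguments{<:Tuple{Any, VA}})
    f = args.args[1]
    iters = args.args[2].args
    # Apply `f` to each zipped element by splatting the (small) tuple.
    return map(t -> f(t...), _zipsplat(iters))
end

# Hand-lowered `map(+, cols...)`:
cols = ([1, 2], [10, 20], [100, 200])
apply(map, Arguments(+, VA(cols)))   # [111, 222]
```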
Defining a similar overload for broadcasted may be possible, provided that the object returned by _zipsplat(iters) is indexable. This lets us nest a reduction inside a mapping and avoid allocation in some cases:
vector .= +.(eachcol(matrix)...)
Other map-like functions, including map! and foreach, can also implement this overload.
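For vector .= +.(eachcol(matrix)...), such a broadcasted overload could boil down to a loop that reduces across the columns at each index instead of splatting all columns into one call. broadcast_reduce! is a hypothetical helper for illustration, not part of the proposed interface:

```julia
# Hypothetical helper: computes `out .= op.(cols...)` by reducing over the
# columns at each index, instead of splatting every column into one call.
function broadcast_reduce!(op, out::AbstractVector, cols)
    for i in eachindex(out)
        out[i] = mapreduce(c -> c[i], op, cols)
    end
    return out
end

matrix = [1 2 3; 4 5 6]
vector = zeros(Int, 2)
broadcast_reduce!(+, vector, eachcol(matrix))   # [6, 15], same as vector .= +.(eachcol(matrix)...)
```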
print-like functions
print, println and write can be invoked with varargs. We can make, e.g., println(xs...) more compiler-friendly and efficient when xs is a generic iterator. Note that apply(string, Arguments(VA(xs))) can also be implemented in terms of apply(print, Arguments(io, VA(xs))).
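One way the print specialization could work for a generic iterator is to print element by element instead of splatting; string can then be built on top of it through an IOBuffer. print_splat and string_splat are hypothetical names standing in for the corresponding apply methods:

```julia
# Hypothetical stand-in for `apply(print, Arguments(io, VA(xs)))`:
# print each element of the iterator in turn, without splatting it.
print_splat(io::IO, xs) = foreach(x -> print(io, x), xs)

# Hypothetical stand-in for `apply(string, Arguments(VA(xs)))`, built on print_splat.
function string_splat(xs)
    io = IOBuffer()
    print_splat(io, xs)
    return String(take!(io))
end

string_splat(x^2 for x in 1:4)   # "14916"
```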
splattable
splattable may be used to solve the performance problem discussed in #29114:
splattable(x::Number) = (x,)
splattable(x::StaticArray) = Tuple(x)
In the case of Broadcasted, it can be used to call instantiate:
splattable(x::Broadcasted) = instantiate(x)
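The splattable methods above can be tried directly. StaticArray is left out so that the snippet needs only Base:

```julia
using Base.Broadcast: Broadcasted, broadcasted, instantiate, materialize

# Default: leave the argument unchanged.
splattable(x) = x
# Wrap a number in a 1-tuple rather than splatting the number itself
# (the performance issue discussed in #29114).
splattable(x::Number) = (x,)
# Compute the axes of a lazy broadcast once, up front.
splattable(x::Broadcasted) = instantiate(x)

splattable(3)                                # (3,)
bc = splattable(broadcasted(+, [1, 2], 1))
materialize(bc)                              # [2, 3]
```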