Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,24 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[weakdeps]
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"

[extensions]
DynamicExpressionsBumperExt = "Bumper"
DynamicExpressionsLoopVectorizationExt = "LoopVectorization"
DynamicExpressionsOptimExt = "Optim"
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
DynamicExpressionsZygoteExt = "Zygote"
DynamicExpressionsLoopVectorizationExt = "LoopVectorization"

[compat]
Bumper = "0.6"
ChainRulesCore = "1"
Compat = "3.37, 4"
DispatchDoctor = "0.4"
Interfaces = "0.3"
LoopVectorization = "0.12"
MacroTools = "0.4, 0.5"
Optim = "0.19, 1"
PackageExtensionCompat = "1"
Expand All @@ -47,7 +46,6 @@ julia = "1.6"

[extras]
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
66 changes: 64 additions & 2 deletions ext/DynamicExpressionsLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
module DynamicExpressionsLoopVectorizationExt

using LoopVectorization: @turbo
using DynamicExpressions: AbstractExpressionNode
using DynamicExpressions

using LoopVectorization: @turbo, vmapnt
using DynamicExpressions: AbstractExpressionNode, GraphNode, OperatorEnum
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions
import DynamicExpressions.EvaluateModule:
Expand All @@ -14,6 +16,7 @@ import DynamicExpressions.EvaluateModule:
deg2_r0_eval
import DynamicExpressions.ExtensionInterfaceModule:
_is_loopvectorization_loaded, bumper_kern1!, bumper_kern2!
import DynamicExpressions.ValueInterfaceModule: is_valid, is_valid_array

_is_loopvectorization_loaded(::Int) = true

Expand Down Expand Up @@ -230,4 +233,63 @@ function bumper_kern2!(
return cumulator1
end



# graph eval

function DynamicExpressions.EvaluateModule._eval_graph_array(
root::GraphNode{T},
cX::AbstractMatrix{T},
operators::OperatorEnum,
loopVectorization::Val{true}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
loopVectorization::Val{true}
::EvalOptions{true}

) where {T}

# vmap is faster with small cX sizes
# vmapnt (non-temporal) is faster with larger cX sizes (too big so not worth caching?)

order = topological_sort(root)
for node in order
if node.degree == 0 && !node.constant
node.cache = view(cX, node.feature, :)
elseif node.degree == 1
if node.l.constant
node.constant = true
node.val = operators.unaops[node.op](node.l.val)
if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end
else
node.constant = false
node.cache = vmapnt(operators.unaops[node.op], node.l.cache)
if !is_valid_array(node.cache) return ResultOk(node.cache, false) end
end
elseif node.degree == 2
if node.l.constant
if node.r.constant
node.constant = true
node.val = operators.binops[node.op](node.l.val, node.r.val)
if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end
else
node.constant = false
node.cache = vmapnt(Base.Fix1(operators.binops[node.op], node.l.val), node.r.cache)
if !is_valid_array(node.cache) return ResultOk(node.cache, false) end
end
else
if node.r.constant
node.constant = false
node.cache = vmapnt(Base.Fix2(operators.binops[node.op], node.r.val), node.l.cache)
if !is_valid_array(node.cache) return ResultOk(node.cache, false) end
else
node.constant = false
node.cache = vmapnt(operators.binops[node.op], node.l.cache, node.r.cache)
if !is_valid_array(node.cache) return ResultOk(node.cache, false) end
end
end
end
end
if root.constant
return ResultOk(fill(root.val, size(cX, 2)), true)
else
return ResultOk(root.cache, true)
end
end

end
9 changes: 6 additions & 3 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ import .ValueInterfaceModule:
set_node!,
tree_mapreduce,
filter_map,
filter_map!
filter_map!,
topological_sort,
randomised_topological_sort
import .NodeModule:
constructorof,
with_type_parameters,
preserve_sharing,
max_degree,
leaf_copy,
branch_copy,
leaf_hash,
Expand All @@ -66,8 +69,7 @@ import .NodeModule:
count_scalar_constants,
get_scalar_constants,
set_scalar_constants!
@reexport import .StringsModule: string_tree, print_tree
import .StringsModule: get_op_name
@reexport import .StringsModule: string_tree, print_tree, get_op_name
@reexport import .OperatorEnumModule: AbstractOperatorEnum
@reexport import .OperatorEnumConstructionModule:
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
Expand Down Expand Up @@ -104,6 +106,7 @@ end
import .InterfacesModule:
ExpressionInterface, NodeInterface, all_ei_methods_except, all_ni_methods_except


function __init__()
@require_extensions
end
Expand Down
157 changes: 156 additions & 1 deletion src/Evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module EvaluateModule

using DispatchDoctor: @stable, @unstable

import ..NodeModule: AbstractExpressionNode, constructorof
import ..NodeModule: AbstractExpressionNode, constructorof, GraphNode, topological_sort
import ..StringsModule: string_tree
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
import ..UtilsModule: fill_similar, counttuple, ResultOk
Expand Down Expand Up @@ -854,4 +854,159 @@ end
end
end

# Parametric arguments don't use dynamic dispatch, calls with turbo/bumper won't resolve properly

# overwritten in ext/DynamicExpressionsLoopVectorizationExt.jl
function _eval_graph_array(
root::GraphNode{T},
cX::AbstractMatrix{T},
operators::OperatorEnum,
loopVectorization::Val{true}
) where {T}
error("DynamicExpressionsLoopVectorizationExt did not overwrite _eval_graph_array")
end

function _eval_graph_array(
root::GraphNode{T},
cX::AbstractMatrix{T},
operators::OperatorEnum,
loopVectorization::Val{false}
) where {T}
order = topological_sort(root)
skip = true
for node in order
skip &= !node.modified
if skip continue end
node.modified = false
if node.degree == 0 && !node.constant
node.cache = view(cX, node.feature, :)
elseif node.degree == 1
if node.l.constant
node.constant = true
node.val = operators.unaops[node.op](node.l.val)
if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end
else
node.constant = false
node.cache = map(operators.unaops[node.op], node.l.cache)
if !is_valid_array(node.cache) return ResultOk(node.cache, false) end
end
elseif node.degree == 2
if node.l.constant
if node.r.constant
node.constant = true
node.val = operators.binops[node.op](node.l.val, node.r.val)
if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end
else
node.constant = false
node.cache = map(Base.Fix1(operators.binops[node.op], node.l.val), node.r.cache)
if !is_valid_array(cache[node]) return ResultOk(node.cache, false) end
end
else
if node.r.constant
node.constant = false
node.cache = map(Base.Fix2(operators.binops[node.op], node.r.val), node.l.cache)
if !is_valid_array(node.cache) return ResultOk(node.cache, false) end
else
node.constant = false
node.cache = map(operators.binops[node.op], node.l.cache, node.r.cache)
if !is_valid_array(node.cache) return ResultOk(node.cache, false) end
end
end
end
end
if root.constant
return ResultOk(fill(root.val, size(cX, 2)), true)
else
return ResultOk(root.cache, true)
end
end

function eval_tree_array(
root::GraphNode{T},
cX::AbstractMatrix{T},
operators::OperatorEnum,
eval_options::Union{EvalOptions,Nothing}=nothing
) where {T}

if eval_options.turbo isa Val{true} || isnothing(eval_eval_options) && _is_loopvectorization_loaded(0)
return _eval_graph_array(root, cX, operators, Val(true))
else
return _eval_graph_array(root, cX, operators, Val(false))
end
end

function eval_graph_array_diff(
root::GraphNode{T},
cX::AbstractMatrix{T},
operators::OperatorEnum,
) where {T}

# vmap is faster with small cX sizes
# vmapnt (non-temporal) is faster with larger cX sizes (too big so not worth caching?)
dp = Dict{GraphNode, AbstractArray{T}}()
order = topological_sort(root)
for node in order
if node.degree == 0 && !node.constant
dp[node] = view(cX, node.feature, :)
elseif node.degree == 1
if node.l.constant
node.constant = true
node.val = operators.unaops[node.op](node.l.val)
if !is_valid(node.val) return false end
else
node.constant = false
dp[node] = map(operators.unaops[node.op], dp[node.l])
if !is_valid_array(dp[node]) return false end
end
elseif node.degree == 2
if node.l.constant
if node.r.constant
node.constant = true
node.val = operators.binops[node.op](node.l.val, node.r.val)
if !is_valid(node.val) return false end
else
node.constant = false
dp[node] = map(Base.Fix1(operators.binops[node.op], node.l.val), dp[node.r])
if !is_valid_array(dp[node]) return false end
end
else
if node.r.constant
node.constant = false
dp[node] = map(Base.Fix2(operators.binops[node.op], node.r.val), dp[node.l])
if !is_valid_array(dp[node]) return false end
else
node.constant = false
dp[node] = map(operators.binops[node.op], dp[node.l], dp[node.r])
if !is_valid_array(dp[node]) return false end
end
end
end
end
if root.constant
return fill(root.val, size(cX, 2))
else
return dp[root]
end
end

function eval_graph_single(
root::GraphNode{T},
cX::AbstractArray{T},
operators::OperatorEnum
) where {T}
order = topological_sort(root)
for node in order
if node.degree == 0 && !node.constant
node.val = cX[node.feature]
elseif node.degree == 1
node.val = operators.unaops[node.op](node.l.val)
if !is_valid(node.val) return false end
elseif node.degree == 2
node.val = operators.binops[node.op](node.l.val, node.r.val)
if !is_valid(node.val) return false end
end
end
return root.val
end

end
Loading