From 1b542bca424b34908fbcf8c634d8abccef3b7146 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 27 Jun 2024 19:11:16 +0100 Subject: [PATCH 01/16] testing out n-arity implementation with SizedVector --- Project.toml | 1 + src/DynamicExpressions.jl | 2 +- src/Node.jl | 162 ++++++++++++++------------------ src/NodeUtils.jl | 20 ++-- src/OperatorEnumConstruction.jl | 8 +- src/ParametricExpression.jl | 10 +- 6 files changed, 90 insertions(+), 113 deletions(-) diff --git a/Project.toml b/Project.toml index 20faf9f0..bb7c8e90 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 0eebeadc..be8c76f6 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -107,5 +107,5 @@ end @ignore include("../test/runtests.jl") include("precompile.jl") -do_precompilation(; mode=:precompile) +# do_precompilation(; mode=:precompile) end diff --git a/src/Node.jl b/src/Node.jl index 303c8a93..eeb5e246 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -1,6 +1,7 @@ module NodeModule using DispatchDoctor: @unstable +using StaticArrays: SizedVector import ..OperatorEnumModule: AbstractOperatorEnum import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined @@ -8,25 +9,30 @@ import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined const DEFAULT_NODE_TYPE = Float32 """ - AbstractNode + AbstractNode{D,shared} -Abstract type for binary trees. Must have the following fields: +Abstract type for D-arity trees. If `shared`, the node type +permits graph-like structures. Must have the following fields: - `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1, then `l` needs to be defined as the left child. If 2, then `r` also needs to be defined as the right child. -- `l::AbstractNode`: Left child of the current node. Should only be +- `children`: A collection of D children nodes. + +# Deprecated fields + +- `l::AbstractNode{D}`: Left child of the current node. Should only be defined if `degree >= 1`; otherwise, leave it undefined (see the the constructors of [`Node{T}`](@ref) for an example). Don't use `nothing` to represent an undefined value as it will incur a large performance penalty. -- `r::AbstractNode`: Right child of the current node. Should only +- `r::AbstractNode{D}`: Right child of the current node. Should only be defined if `degree == 2`. """ -abstract type AbstractNode end +abstract type AbstractNode{D,shared} end """ - AbstractExpressionNode{T} <: AbstractNode + AbstractExpressionNode{T,D} <: AbstractNode{D} Abstract type for nodes that represent an expression. Along with the fields required for `AbstractNode`, @@ -67,11 +73,25 @@ You likely do not need to, but you could choose to override the following: - `with_type_parameters` """ -abstract type AbstractExpressionNode{T} <: AbstractNode end +abstract type AbstractExpressionNode{T,D,shared} <: AbstractNode{D,shared} end + +mutable struct GeneralNode{T,D,shared} <: AbstractExpressionNode{T,D,shared} + degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. 
+ constant::Bool # false if variable + val::T # If is a constant, this stores the actual value + feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. + op::UInt8 # If operator, this is the index of the operator in the degree-specific operator enum + children::SizedVector{D,GeneralNode{T,D,shared}} # Children nodes + + ################# + ## Constructors: + ################# + GeneralNode{_T,_D,_shared}() where {_T,_D,_shared} = new{_T,_D,_shared}() +end #! format: off """ - Node{T} <: AbstractExpressionNode{T} + Node{T} <: AbstractExpressionNode{T,2} Node defines a symbolic expression stored in a binary tree. A single `Node` instance is one "node" of this tree, and @@ -81,58 +101,37 @@ nodes, you can evaluate or print a given expression. # Fields - `degree::UInt8`: Degree of the node. 0 for constants, 1 for - unary operators, 2 for binary operators. + unary operators, 2 for binary operators, etc. Maximum of `D`. - `constant::Bool`: Whether the node is a constant. - `val::T`: Value of the node. If `degree==0`, and `constant==true`, this is the value of the constant. It has a type specified by the overall type of the `Node` (e.g., `Float64`). - `feature::UInt16`: Index of the feature to use in the - case of a feature node. Only used if `degree==0` and `constant==false`. - Only defined if `degree == 0 && constant == false`. + case of a feature node. Only defined if `degree == 0 && constant == false`. - `op::UInt8`: If `degree==1`, this is the index of the operator in `operators.unaops`. If `degree==2`, this is the index of the operator in `operators.binops`. In other words, this is an enum of the operators, and is dependent on the specific `OperatorEnum` object. Only defined if `degree >= 1` -- `l::Node{T}`: Left child of the node. Only defined if `degree >= 1`. - Same type as the parent node. -- `r::Node{T}`: Right child of the node. Only defined if `degree == 2`. - Same type as the parent node. This is to be passed as the right - argument to the binary operator. +- `children::SizedArray{D,Node{T,D}}`: Children of the node. Only defined up to `degree` # Constructors - Node([T]; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator) - Node{T}(; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator) + Node([T]; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator) + Node{T}(; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator) Create a new node in an expression tree. If `T` is not specified in either the type or the -first argument, it will be inferred from the value of `val` passed or `l` and/or `r`. -If it cannot be inferred from these, it will default to `Float32`. - -The `children` keyword can be used instead of `l` and `r` and should be a tuple of children. This -is to permit the use of splatting in constructors. +first argument, it will be inferred from the value of `val` passed or the children. +The `children` keyword is used to pass in a collection of children nodes. You may also construct nodes via the convenience operators generated by creating an `OperatorEnum`. You may also choose to specify a default memory allocator for the node other than simply `Node{T}()` in the `allocator` keyword argument. """ -mutable struct Node{T} <: AbstractExpressionNode{T} - degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. 
- constant::Bool # false if variable - val::T # If is a constant, this stores the actual value - # ------------------- (possibly undefined below) - feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. - op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops - l::Node{T} # Left child node. Only defined for degree=1 or degree=2. - r::Node{T} # Right child node. Only defined for degree=2. +const Node{T} = GeneralNode{T,2,false} - ################# - ## Constructors: - ################# - Node{_T}() where {_T} = new{_T}() -end """ GraphNode{T} <: AbstractExpressionNode{T} @@ -146,7 +145,7 @@ be performed with this assumption, to preserve structure of the graph. ```julia julia> operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos, sin] - ); + ); julia> x = GraphNode(feature=1) x1 @@ -165,18 +164,7 @@ This has the same constructors as [`Node{T}`](@ref). Shared nodes are created simply by using the same node in multiple places when constructing or setting properties. """ -mutable struct GraphNode{T} <: AbstractExpressionNode{T} - degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. - constant::Bool # false if variable - val::T # If is a constant, this stores the actual value - # ------------------- (possibly undefined below) - feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. - op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops - l::GraphNode{T} # Left child node. Only defined for degree=1 or degree=2. - r::GraphNode{T} # Right child node. Only defined for degree=2. - - GraphNode{_T}() where {_T} = new{_T}() -end +const GraphNode{T} = GeneralNode{T,2,true} ################################################################################ #! format: on @@ -184,49 +172,41 @@ end Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T Base.eltype(::AbstractExpressionNode{T}) where {T} = T -@unstable constructorof(::Type{N}) where {N<:AbstractNode} = Base.typename(N).wrapper +function max_degree(::Type{N}) where {N<:AbstractExpressionNode} + return (N isa UnionAll ? N.body : N).parameters[2] +end + @unstable constructorof(::Type{<:Node}) = Node @unstable constructorof(::Type{<:GraphNode}) = GraphNode -function with_type_parameters(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} - return constructorof(N){T} -end with_type_parameters(::Type{<:Node}, ::Type{T}) where {T} = Node{T} with_type_parameters(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T} -function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} - return with_type_parameters(N, T)() -end default_allocator(::Type{<:Node}, ::Type{T}) where {T} = Node{T}() default_allocator(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}() """Trait declaring whether nodes share children or not.""" preserve_sharing(::Union{Type{<:AbstractNode},AbstractNode}) = false -preserve_sharing(::Union{Type{<:Node},Node}) = false -preserve_sharing(::Union{Type{<:GraphNode},GraphNode}) = true +function preserve_sharing( + ::Union{Type{<:G},G} +) where {shared,G<:GeneralNode{T,D,shared} where {T,D}} + return shared +end include("base.jl") #! 
format: off @inline function (::Type{N})( - ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator, + ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, children=nothing, allocator::F=default_allocator, ) where {T1,N<:AbstractExpressionNode,F} - validate_not_all_defaults(N, val, feature, op, l, r, children) - if children !== nothing - @assert l === nothing && r === nothing - if length(children) == 1 - return node_factory(N, T1, val, feature, op, only(children), nothing, allocator) - else - return node_factory(N, T1, val, feature, op, children..., allocator) - end - end - return node_factory(N, T1, val, feature, op, l, r, allocator) + validate_not_all_defaults(N, val, feature, op, children) + return node_factory(N, T1, val, feature, op, children, allocator) end -function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {N<:AbstractExpressionNode} +function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {N<:AbstractExpressionNode} return nothing end -function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {T,N<:AbstractExpressionNode{T}} - if val === nothing && feature === nothing && op === nothing && l === nothing && r === nothing && children === nothing +function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {T,N<:AbstractExpressionNode{T}} + if val === nothing && feature === nothing && op === nothing && children === nothing error( "Encountered the call for $N() inside the generic constructor. " * "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?" @@ -236,7 +216,7 @@ function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) end """Create a constant leaf.""" @inline function node_factory( - ::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, ::Nothing, allocator::F, + ::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,T2,F} T = node_factory_type(N, T1, T2) n = allocator(N, T) @@ -247,7 +227,7 @@ end end """Create a variable leaf, to store data.""" @inline function node_factory( - ::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, ::Nothing, allocator::F, + ::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,F} T = node_factory_type(N, T1, DEFAULT_NODE_TYPE) n = allocator(N, T) @@ -256,28 +236,22 @@ end n.feature = feature return n end -"""Create a unary operator node.""" +"""Create an operator node.""" @inline function node_factory( - ::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, ::Nothing, allocator::F, -) where {N,T1,T2,F} - @assert l isa N - T = T2 # Always prefer existing nodes, so we don't mess up references from conversion + ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::NTuple{D2}, allocator::F, +) where {N,F,D2} + D = max_degree(N) + @assert D2 <= D + T = promote_type(map(eltype, children)...) 
# Always prefer existing nodes, so we don't mess up references from conversion + NT = with_type_parameters(N, T) n = allocator(N, T) - n.degree = 1 + n.degree = D2 n.op = op - n.l = l - return n -end -"""Create a binary operator node.""" -@inline function node_factory( - ::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, r::AbstractExpressionNode{T3}, allocator::F, -) where {N,T1,T2,T3,F} - T = promote_type(T2, T3) - n = allocator(N, T) - n.degree = 2 - n.op = op - n.l = T2 === T ? l : convert(with_type_parameters(N, T), l) - n.r = T3 === T ? r : convert(with_type_parameters(N, T), r) + ar = SizedVector{D,NT}(undef) + for i in 1:D2 + ar[i] = children[i] + end + n.children = ar return n end diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index ae331f26..17b4e898 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -1,9 +1,11 @@ module NodeUtilsModule +using StaticArrays: MVector import Compat: Returns import ..NodeModule: AbstractNode, AbstractExpressionNode, + GeneralNode, Node, preserve_sharing, constructorof, @@ -98,18 +100,18 @@ end ## Assign index to nodes of a tree # This will mirror a Node struct, rather # than adding a new attribute to Node. -struct NodeIndex{T} <: AbstractNode +struct NodeIndex{T,D} <: AbstractNode{D,false} degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. val::T # If is a constant, this stores the actual value # ------------------- (possibly undefined below) - l::NodeIndex{T} # Left child node. Only defined for degree=1 or degree=2. - r::NodeIndex{T} # Right child node. Only defined for degree=2. - - NodeIndex(::Type{_T}) where {_T} = new{_T}(0, zero(_T)) - NodeIndex(::Type{_T}, val) where {_T} = new{_T}(0, convert(_T, val)) - NodeIndex(::Type{_T}, l::NodeIndex) where {_T} = new{_T}(1, zero(_T), l) - function NodeIndex(::Type{_T}, l::NodeIndex, r::NodeIndex) where {_T} - return new{_T}(2, zero(_T), l, r) + children::MVector{D,NodeIndex{T,D}} + + NodeIndex(::Type{_T}, ::Type{_D}) where {_T,_D} = new{_T,_D}(0, zero(_T)) + NodeIndex(::Type{_T}, ::Type{_D}, val) where {_T,_D} = new{_T,_D}(0, convert(_T, val)) + function NodeIndex(::Type{_T}, ::Type{_D}, children::Vararg{Any,_D2}) where {_T,_D,_D2} + _children = MVector{_D,NodeIndex{_T,_D}}(undef) + _children[begin:_D2] = children + return new{_T,_D}(1, zero(_T), _children) end end # Sharing is never needed for NodeIndex, diff --git a/src/OperatorEnumConstruction.jl b/src/OperatorEnumConstruction.jl index 919dea10..3c17ecb7 100644 --- a/src/OperatorEnumConstruction.jl +++ b/src/OperatorEnumConstruction.jl @@ -141,7 +141,7 @@ function _extend_unary_operator(f::Symbol, type_requirements, internal) $_constructorof(N)(T; val=$($f)(l.val)) else latest_op_idx = $($lookup_op)($($f), Val(1)) - $_constructorof(N)(; op=latest_op_idx, l) + $_constructorof(N)(; op=latest_op_idx, children=(l,)) end end end @@ -168,7 +168,7 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters, $_constructorof(N)(T; val=$($f)(l.val, r.val)) else latest_op_idx = $($lookup_op)($($f), Val(2)) - $_constructorof(N)(; op=latest_op_idx, l, r) + $_constructorof(N)(; op=latest_op_idx, children=(l, r)) end end function $($f)( @@ -179,7 +179,7 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters, else latest_op_idx = $($lookup_op)($($f), Val(2)) $_constructorof(N)(; - op=latest_op_idx, l, r=$_constructorof(N)(T; val=r) + op=latest_op_idx, children=(l, $_constructorof(N)(T; val=r)) ) end end @@ -191,7 +191,7 @@ function 
_extend_binary_operator(f::Symbol, type_requirements, build_converters, else latest_op_idx = $($lookup_op)($($f), Val(2)) $_constructorof(N)(; - op=latest_op_idx, l=$_constructorof(N)(T; val=l), r + op=latest_op_idx, children=($_constructorof(N)(T; val=l), r) ) end end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 383af13d..d7960943 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -1,6 +1,7 @@ module ParametricExpressionModule using DispatchDoctor: @stable, @unstable +using StaticArrays: MVector using ..OperatorEnumModule: AbstractOperatorEnum using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce @@ -29,7 +30,7 @@ import ..ExpressionModule: import ..ParseModule: parse_leaf """A type of expression node that also stores a parameter index""" -mutable struct ParametricNode{T} <: AbstractExpressionNode{T} +mutable struct ParametricNode{T,D,shared} <: AbstractExpressionNode{T,D,shared} degree::UInt8 constant::Bool # if true => constant; if false, then check `is_parameter` val::T @@ -39,11 +40,10 @@ mutable struct ParametricNode{T} <: AbstractExpressionNode{T} parameter::UInt16 # Stores index of per-class parameter op::UInt8 - l::ParametricNode{T} - r::ParametricNode{T} + children::MVector{D,ParametricNode{T,D}} # Children nodes - function ParametricNode{_T}() where {_T} - n = new{_T}() + function ParametricNode{_T,_D,_shared}() where {_T,_D,_shared} + n = new{_T,_D,_shared}() n.is_parameter = false n.parameter = UInt16(0) return n From ab7c65ca810672fcacc65ba826b899e3bab21ad1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jul 2024 11:52:01 +0100 Subject: [PATCH 02/16] wip --- Project.toml | 1 - src/Node.jl | 108 +++++++++++++++++++++++++++------------------------ 2 files changed, 57 insertions(+), 52 deletions(-) diff --git a/Project.toml b/Project.toml index b2216e58..3f5df1a5 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" [weakdeps] diff --git a/src/Node.jl b/src/Node.jl index eeb5e246..d1102168 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -1,7 +1,6 @@ module NodeModule using DispatchDoctor: @unstable -using StaticArrays: SizedVector import ..OperatorEnumModule: AbstractOperatorEnum import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined @@ -9,15 +8,14 @@ import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined const DEFAULT_NODE_TYPE = Float32 """ - AbstractNode{D,shared} + AbstractNode{D} -Abstract type for D-arity trees. If `shared`, the node type -permits graph-like structures. Must have the following fields: +Abstract type for D-arity trees. Must have the following fields: - `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1, then `l` needs to be defined as the left child. If 2, then `r` also needs to be defined as the right child. -- `children`: A collection of D children nodes. +- `children`: A collection of D references to children nodes. # Deprecated fields @@ -29,7 +27,7 @@ permits graph-like structures. Must have the following fields: - `r::AbstractNode{D}`: Right child of the current node. Should only be defined if `degree == 2`. 
""" -abstract type AbstractNode{D,shared} end +abstract type AbstractNode{D} end """ AbstractExpressionNode{T,D} <: AbstractNode{D} @@ -73,25 +71,27 @@ You likely do not need to, but you could choose to override the following: - `with_type_parameters` """ -abstract type AbstractExpressionNode{T,D,shared} <: AbstractNode{D,shared} end - -mutable struct GeneralNode{T,D,shared} <: AbstractExpressionNode{T,D,shared} - degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. - constant::Bool # false if variable - val::T # If is a constant, this stores the actual value - feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. - op::UInt8 # If operator, this is the index of the operator in the degree-specific operator enum - children::SizedVector{D,GeneralNode{T,D,shared}} # Children nodes - - ################# - ## Constructors: - ################# - GeneralNode{_T,_D,_shared}() where {_T,_D,_shared} = new{_T,_D,_shared}() +abstract type AbstractExpressionNode{T,D} <: AbstractNode{D} end + +for N in (:Node, :GraphNode) + @eval mutable struct $N{T,D} <: AbstractExpressionNode{T,D} + degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. + constant::Bool # false if variable + val::T # If is a constant, this stores the actual value + feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index. + op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum + children::NTuple{D,Base.RefValue{$N{T,D}}} # Children nodes + + ################# + ## Constructors: + ################# + $N{_T,_D}() where {_T,_D} = new{_T,_D}() + end end #! format: off """ - Node{T} <: AbstractExpressionNode{T,2} + Node{T,D} <: AbstractExpressionNode{T,D} Node defines a symbolic expression stored in a binary tree. A single `Node` instance is one "node" of this tree, and @@ -113,7 +113,7 @@ nodes, you can evaluate or print a given expression. operator in `operators.binops`. In other words, this is an enum of the operators, and is dependent on the specific `OperatorEnum` object. Only defined if `degree >= 1` -- `children::SizedArray{D,Node{T,D}}`: Children of the node. Only defined up to `degree` +- `children::NTuple{D,Base.RefValue{Node{T,D}}}`: Children of the node. Only defined up to `degree` # Constructors @@ -130,13 +130,13 @@ You may also construct nodes via the convenience operators generated by creating You may also choose to specify a default memory allocator for the node other than simply `Node{T}()` in the `allocator` keyword argument. """ -const Node{T} = GeneralNode{T,2,false} +Node """ - GraphNode{T} <: AbstractExpressionNode{T} + GraphNode{T,D} <: AbstractExpressionNode{T,D} -Exactly the same as [`Node{T}`](@ref), but with the assumption that some +Exactly the same as [`Node{T,D}`](@ref), but with the assumption that some nodes will be shared. All copies of this graph-like structure will be performed with this assumption, to preserve structure of the graph. @@ -164,7 +164,7 @@ This has the same constructors as [`Node{T}`](@ref). Shared nodes are created simply by using the same node in multiple places when constructing or setting properties. """ -const GraphNode{T} = GeneralNode{T,2,true} +GraphNode ################################################################################ #! 
format: on @@ -177,30 +177,41 @@ function max_degree(::Type{N}) where {N<:AbstractExpressionNode} end @unstable constructorof(::Type{<:Node}) = Node +@unstable constructorof(::Type{<:Node{T,D} where T}) where {D} = Node{T,D} where T @unstable constructorof(::Type{<:GraphNode}) = GraphNode +@unstable constructorof(::Type{<:GraphNode{T,D} where T}) where {D} = GraphNode{T,D} where T -with_type_parameters(::Type{<:Node}, ::Type{T}) where {T} = Node{T} -with_type_parameters(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T} +with_type_parameters(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D} +with_type_parameters(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D} -default_allocator(::Type{<:Node}, ::Type{T}) where {T} = Node{T}() -default_allocator(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}() +default_allocator(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D}() +default_allocator(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D}() """Trait declaring whether nodes share children or not.""" preserve_sharing(::Union{Type{<:AbstractNode},AbstractNode}) = false -function preserve_sharing( - ::Union{Type{<:G},G} -) where {shared,G<:GeneralNode{T,D,shared} where {T,D}} - return shared -end +preserve_sharing(::Union{Type{<:GraphNode},GraphNode}) = true include("base.jl") #! format: off @inline function (::Type{N})( - ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, children=nothing, allocator::F=default_allocator, + ::Type{T1}=Undefined; kws... ) where {T1,N<:AbstractExpressionNode,F} - validate_not_all_defaults(N, val, feature, op, children) - return node_factory(N, T1, val, feature, op, children, allocator) +end +@inline function (::Type{N})( + ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator, +) where {T1,D,N<:AbstractExpressionNode{T,D} where T,F} + _children = if l !== nothing && r === nothing + @assert children === nothing + (l,) + elseif l !== nothing && r !== nothing + @assert children === nothing + (l, r) + else + children + end + validate_not_all_defaults(N, val, feature, op, _children) + return node_factory(N, T1, val, feature, op, _children, allocator) end function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {N<:AbstractExpressionNode} return nothing @@ -209,7 +220,7 @@ function validate_not_all_defaults(::Type{N}, val, feature, op, children) where if val === nothing && feature === nothing && op === nothing && children === nothing error( "Encountered the call for $N() inside the generic constructor. " - * "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?" + * "Did you forget to define `$(Base.typename(N).wrapper){T,D}() where {T,D} = new{T,D}()`?" 
) end return nothing @@ -219,7 +230,7 @@ end ::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,T2,F} T = node_factory_type(N, T1, T2) - n = allocator(N, T) + n = allocator(N, T, D) n.degree = 0 n.constant = true n.val = convert(T, val) @@ -230,7 +241,7 @@ end ::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,F} T = node_factory_type(N, T1, DEFAULT_NODE_TYPE) - n = allocator(N, T) + n = allocator(N, T, D) n.degree = 0 n.constant = false n.feature = feature @@ -239,19 +250,14 @@ end """Create an operator node.""" @inline function node_factory( ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::NTuple{D2}, allocator::F, -) where {N,F,D2} - D = max_degree(N) - @assert D2 <= D +) where {D,N<:AbstractExpressionNode{T where T,D},F,D2} T = promote_type(map(eltype, children)...) # Always prefer existing nodes, so we don't mess up references from conversion - NT = with_type_parameters(N, T) - n = allocator(N, T) + NT = with_type_parameters(N, T, D) + n = allocator(N, T, D) n.degree = D2 n.op = op - ar = SizedVector{D,NT}(undef) - for i in 1:D2 - ar[i] = children[i] - end - n.children = ar + n.children + # map(Ref, children) return n end From 3ed6b41067c48f8e374f0c198ce3d6203fab8185 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jul 2024 12:38:39 +0100 Subject: [PATCH 03/16] fix: various aspects of degree interface --- src/Node.jl | 95 +++++++++++++++++++++++++------------ src/NodeUtils.jl | 33 +++++++------ src/ParametricExpression.jl | 9 ++-- 3 files changed, 86 insertions(+), 51 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index d1102168..b3731f85 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -85,7 +85,7 @@ for N in (:Node, :GraphNode) ################# ## Constructors: ################# - $N{_T,_D}() where {_T,_D} = new{_T,_D}() + $N{_T,_D}() where {_T,_D} = new{_T,_D::Int}() end end @@ -166,26 +166,62 @@ when constructing or setting properties. """ GraphNode +@inline function Base.getproperty(n::Union{Node,GraphNode}, k::Symbol) + if k == :l + # TODO: Should a depwarn be raised here? Or too slow? + return getfield(n, :children)[1][] + elseif k == :r + return getfield(n, :children)[2][] + else + return getfield(n, k) + end +end +@inline function Base.setproperty!(n::Union{Node,GraphNode}, k::Symbol, v) + if k == :l + getfield(n, :children)[1][] = v + elseif k == :r + getfield(n, :children)[2][] = v + elseif k == :degree + setfield!(n, :degree, convert(UInt8, v)) + elseif k == :constant + setfield!(n, :constant, convert(Bool, v)) + elseif k == :feature + setfield!(n, :feature, convert(UInt16, v)) + elseif k == :op + setfield!(n, :op, convert(UInt8, v)) + elseif k == :val + setfield!(n, :val, convert(eltype(n), v)) + elseif k == :children + setfield!(n, :children, v) + else + error("Invalid property: $k") + end +end + ################################################################################ #! format: on Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T Base.eltype(::AbstractExpressionNode{T}) where {T} = T -function max_degree(::Type{N}) where {N<:AbstractExpressionNode} - return (N isa UnionAll ? 
N.body : N).parameters[2] -end +max_degree(::Type{<:AbstractNode}) = 2 # Default +max_degree(::Type{<:AbstractNode{D}}) where {D} = D + +@unstable constructorof(::Type{N}) where {N<:Node} = Node{T,max_degree(N)} where {T} +@unstable constructorof(::Type{N}) where {N<:GraphNode} = + GraphNode{T,max_degree(N)} where {T} -@unstable constructorof(::Type{<:Node}) = Node -@unstable constructorof(::Type{<:Node{T,D} where T}) where {D} = Node{T,D} where T -@unstable constructorof(::Type{<:GraphNode}) = GraphNode -@unstable constructorof(::Type{<:GraphNode{T,D} where T}) where {D} = GraphNode{T,D} where T +with_type_parameters(::Type{N}, ::Type{T}) where {N<:Node,T} = Node{T,max_degree(N)} +function with_type_parameters(::Type{N}, ::Type{T}) where {N<:GraphNode,T} + return GraphNode{T,max_degree(N)} +end -with_type_parameters(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D} -with_type_parameters(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D} +# with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D} +# with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D} -default_allocator(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D}() -default_allocator(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D}() +function default_allocator(::Type{N}, ::Type{T}) where {N<:Union{Node,GraphNode},T} + return with_type_parameters(N, T)() +end """Trait declaring whether nodes share children or not.""" preserve_sharing(::Union{Type{<:AbstractNode},AbstractNode}) = false @@ -194,13 +230,9 @@ preserve_sharing(::Union{Type{<:GraphNode},GraphNode}) = true include("base.jl") #! format: off -@inline function (::Type{N})( - ::Type{T1}=Undefined; kws... -) where {T1,N<:AbstractExpressionNode,F} -end @inline function (::Type{N})( ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator, -) where {T1,D,N<:AbstractExpressionNode{T,D} where T,F} +) where {T1,N<:AbstractExpressionNode{T} where T,F} _children = if l !== nothing && r === nothing @assert children === nothing (l,) @@ -230,7 +262,7 @@ end ::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,T2,F} T = node_factory_type(N, T1, T2) - n = allocator(N, T, D) + n = allocator(N, T) n.degree = 0 n.constant = true n.val = convert(T, val) @@ -241,7 +273,7 @@ end ::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,F} T = node_factory_type(N, T1, DEFAULT_NODE_TYPE) - n = allocator(N, T, D) + n = allocator(N, T) n.degree = 0 n.constant = false n.feature = feature @@ -249,15 +281,16 @@ end end """Create an operator node.""" @inline function node_factory( - ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::NTuple{D2}, allocator::F, -) where {D,N<:AbstractExpressionNode{T where T,D},F,D2} + ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::Tuple, allocator::F, +) where {N<:AbstractExpressionNode,F} T = promote_type(map(eltype, children)...) # Always prefer existing nodes, so we don't mess up references from conversion - NT = with_type_parameters(N, T, D) - n = allocator(N, T, D) + D2 = length(children) + @assert D2 <= max_degree(N) + NT = with_type_parameters(N, T) + n = allocator(N, T) n.degree = D2 n.op = op - n.children - # map(Ref, children) + n.children = ntuple(i -> i <= D2 ? 
Ref(convert(NT, children[i])) : Ref{NT}(), Val(max_degree(N))) return n end @@ -298,14 +331,14 @@ function (::Type{N})( return N(; feature=i) end -function Base.promote_rule(::Type{Node{T1}}, ::Type{Node{T2}}) where {T1,T2} - return Node{promote_type(T1, T2)} +function Base.promote_rule(::Type{Node{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D} + return Node{promote_type(T1, T2),D} end -function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{Node{T2}}) where {T1,T2} - return GraphNode{promote_type(T1, T2)} +function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D} + return GraphNode{promote_type(T1, T2),D} end -function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {T1,T2} - return GraphNode{promote_type(T1, T2)} +function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{GraphNode{T2,D}}) where {T1,T2,D} + return GraphNode{promote_type(T1, T2),D} end # TODO: Verify using this helps with garbage collection diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index 37e60596..2b49441c 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -1,11 +1,9 @@ module NodeUtilsModule -using StaticArrays: MVector import Compat: Returns import ..NodeModule: AbstractNode, AbstractExpressionNode, - GeneralNode, Node, preserve_sharing, constructorof, @@ -145,17 +143,20 @@ end ## Assign index to nodes of a tree # This will mirror a Node struct, rather # than adding a new attribute to Node. -struct NodeIndex{T,D} <: AbstractNode{D,false} +struct NodeIndex{T,D} <: AbstractNode{D} degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. val::T # If is a constant, this stores the actual value # ------------------- (possibly undefined below) - children::MVector{D,NodeIndex{T,D}} - - NodeIndex(::Type{_T}, ::Type{_D}) where {_T,_D} = new{_T,_D}(0, zero(_T)) - NodeIndex(::Type{_T}, ::Type{_D}, val) where {_T,_D} = new{_T,_D}(0, convert(_T, val)) - function NodeIndex(::Type{_T}, ::Type{_D}, children::Vararg{Any,_D2}) where {_T,_D,_D2} - _children = MVector{_D,NodeIndex{_T,_D}}(undef) - _children[begin:_D2] = children + children::NTuple{D,Base.RefValue{NodeIndex{T,D}}} + + NodeIndex(::Type{_T}, ::Val{_D}) where {_T,_D} = new{_T,_D}(0, zero(_T)) + NodeIndex(::Type{_T}, ::Val{_D}, val) where {_T,_D} = new{_T,_D}(0, convert(_T, val)) + function NodeIndex( + ::Type{_T}, ::Val{_D}, children::Vararg{NodeIndex{_T,_D},_D2} + ) where {_T,_D,_D2} + _children = ntuple( + i -> i <= _D2 ? Ref(children[i]) : Ref{NodeIndex{_T,_D}}(), Val(_D) + ) return new{_T,_D}(1, zero(_T), _children) end end @@ -163,20 +164,22 @@ end # as we trace over the node we are indexing on. preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false -function index_constant_nodes(tree::AbstractExpressionNode, ::Type{T}=UInt16) where {T} +function index_constant_nodes( + tree::AbstractExpressionNode{Ti,D} where {Ti}, ::Type{T}=UInt16 +) where {D,T} # Essentially we copy the tree, replacing the values # with indices constant_index = Ref(T(0)) return tree_mapreduce( t -> if t.constant - NodeIndex(T, (constant_index[] += T(1))) + NodeIndex(T, Val(D), (constant_index[] += T(1))) else - NodeIndex(T) + NodeIndex(T, Val(D)) end, t -> nothing, - (_, c...) -> NodeIndex(T, c...), + (_, c...) 
-> NodeIndex(T, Val(D), c...), tree, - NodeIndex{T}; + NodeIndex{T,D}; ) end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index b67f0490..4b6d4b50 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -1,7 +1,6 @@ module ParametricExpressionModule using DispatchDoctor: @stable, @unstable -using StaticArrays: MVector using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum @@ -35,7 +34,7 @@ import ..ValueInterfaceModule: count_scalar_constants, pack_scalar_constants!, unpack_scalar_constants """A type of expression node that also stores a parameter index""" -mutable struct ParametricNode{T,D,shared} <: AbstractExpressionNode{T,D,shared} +mutable struct ParametricNode{T,D} <: AbstractExpressionNode{T,D} degree::UInt8 constant::Bool # if true => constant; if false, then check `is_parameter` val::T @@ -45,10 +44,10 @@ mutable struct ParametricNode{T,D,shared} <: AbstractExpressionNode{T,D,shared} parameter::UInt16 # Stores index of per-class parameter op::UInt8 - children::MVector{D,ParametricNode{T,D}} # Children nodes + children::NTuple{D,Base.RefValue{ParametricNode{T,D}}} # Children nodes - function ParametricNode{_T,_D,_shared}() where {_T,_D,_shared} - n = new{_T,_D,_shared}() + function ParametricNode{_T,_D}() where {_T,_D} + n = new{_T,_D}() n.is_parameter = false n.parameter = UInt16(0) return n From b5285f7c355129deed855d12632749d32f00915f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jul 2024 21:50:12 +0100 Subject: [PATCH 04/16] fix: segfault in NodeIndex See https://github.com/JuliaLang/julia/issues/55076 for details --- src/DynamicExpressions.jl | 2 +- src/Node.jl | 2 +- src/NodeUtils.jl | 13 +++++++++---- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 927adb78..c2d71489 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -117,5 +117,5 @@ end @ignore include("../test/runtests.jl") include("precompile.jl") -# do_precompilation(; mode=:precompile) +do_precompilation(; mode=:precompile) end diff --git a/src/Node.jl b/src/Node.jl index b3731f85..0d1ab048 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -3,7 +3,7 @@ module NodeModule using DispatchDoctor: @unstable import ..OperatorEnumModule: AbstractOperatorEnum -import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined +import ..UtilsModule: deprecate_varmap, Undefined const DEFAULT_NODE_TYPE = Float32 diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index 2b49441c..87c4ba60 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -143,23 +143,28 @@ end ## Assign index to nodes of a tree # This will mirror a Node struct, rather # than adding a new attribute to Node. -struct NodeIndex{T,D} <: AbstractNode{D} +mutable struct NodeIndex{T,D} <: AbstractNode{D} degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. 
val::T # If is a constant, this stores the actual value # ------------------- (possibly undefined below) children::NTuple{D,Base.RefValue{NodeIndex{T,D}}} - NodeIndex(::Type{_T}, ::Val{_D}) where {_T,_D} = new{_T,_D}(0, zero(_T)) - NodeIndex(::Type{_T}, ::Val{_D}, val) where {_T,_D} = new{_T,_D}(0, convert(_T, val)) + function NodeIndex(::Type{_T}, ::Val{_D}, val) where {_T,_D} + return new{_T,_D}( + 0, convert(_T, val), ntuple(_ -> Ref{NodeIndex{_T,_D}}(), Val(_D)) + ) + end function NodeIndex( ::Type{_T}, ::Val{_D}, children::Vararg{NodeIndex{_T,_D},_D2} ) where {_T,_D,_D2} _children = ntuple( i -> i <= _D2 ? Ref(children[i]) : Ref{NodeIndex{_T,_D}}(), Val(_D) ) - return new{_T,_D}(1, zero(_T), _children) + return new{_T,_D}(convert(UInt8, _D2), zero(_T), _children) end end +NodeIndex(::Type{T}, ::Val{D}) where {T,D} = NodeIndex(T, Val(D), zero(T)) + # Sharing is never needed for NodeIndex, # as we trace over the node we are indexing on. preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false From 8707d24726e42d1f4d0d36339062d95070b4bb9d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jul 2024 21:50:45 +0100 Subject: [PATCH 05/16] refactor: no more need for `memoize_on` --- src/Utils.jl | 97 --------------------------------------------- src/base.jl | 83 +++++++++++++++++++++++++++----------- test/test_graphs.jl | 70 -------------------------------- 3 files changed, 60 insertions(+), 190 deletions(-) diff --git a/src/Utils.jl b/src/Utils.jl index bd3326e2..691de70a 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -12,103 +12,6 @@ macro return_on_false2(flag, retval, retval2) ) end -""" - @memoize_on tree [postprocess] function my_function_on_tree(tree::AbstractExpressionNode) - ... - end - -This macro takes a function definition and creates a second version of the -function with an additional `id_map` argument. When passed this argument (an -IdDict()), it will use use the `id_map` to avoid recomputing the same value -for the same node in a tree. Use this to automatically create functions that -work with trees that have shared child nodes. - -Can optionally take a `postprocess` function, which will be applied to the -result of the function before returning it, taking the result as the -first argument and a boolean for whether the result was memoized as the -second argument. This is useful for functions that need to count the number -of unique nodes in a tree, for example. -""" -macro memoize_on(tree, args...) - if length(args) ∉ (1, 2) - error("Expected 2 or 3 arguments to @memoize_on") - end - postprocess = length(args) == 1 ? :((r, _) -> r) : args[1] - def = length(args) == 1 ? args[1] : args[2] - idmap_def = _memoize_on(tree, postprocess, def) - - return quote - $(esc(def)) # The normal function - $(esc(idmap_def)) # The function with an id_map argument - end -end -function _memoize_on(tree::Symbol, postprocess, def) - sdef = splitdef(def) - - # Add an id_map argument - push!(sdef[:args], :(id_map::AbstractDict)) - - f_name = sdef[:name] - - # Forward id_map argument to all calls of the same function - # within the function body: - sdef[:body] = postwalk(sdef[:body]) do ex - if @capture(ex, f_(args__)) - if f == f_name - return Expr(:call, f, args..., :id_map) - end - end - return ex - end - - # Wrap the function body in a get!(id_map, tree) do ... 
end block: - @gensym key is_memoized result body - sdef[:body] = quote - $key = objectid($tree) - $is_memoized = haskey(id_map, $key) - function $body() - return $(sdef[:body]) - end - $result = if $is_memoized - @inbounds(id_map[$key]) - else - id_map[$key] = $body() - end - return $postprocess($result, $is_memoized) - end - - return combinedef(sdef) -end - -""" - @with_memoize(call, id_map) - -This simple macro simply puts the `id_map` -into the call, to be consistent with the `@memoize_on` macro. - -``` -@with_memoize(_copy_node(tree), IdDict{Any,Any}()) -```` - -is converted to - -``` -_copy_node(tree, IdDict{Any,Any}()) -``` - -""" -macro with_memoize(def, id_map) - idmap_def = _add_idmap_to_call(def, id_map) - return quote - $(esc(idmap_def)) - end -end - -function _add_idmap_to_call(def::Expr, id_map::Union{Symbol,Expr}) - @assert def.head == :call - return Expr(:call, def.args[1], def.args[2:end]..., id_map) -end - @inline function fill_similar(value::T, array, args...) where {T} out_array = similar(array, args...) fill!(out_array, value) diff --git a/src/base.jl b/src/base.jl index 7d0c0041..32d67fd6 100644 --- a/src/base.jl +++ b/src/base.jl @@ -25,7 +25,7 @@ import Base: using DispatchDoctor: @unstable using Compat: @inline, Returns -using ..UtilsModule: @memoize_on, @with_memoize, Undefined +using ..UtilsModule: Undefined """ tree_mapreduce( @@ -89,38 +89,76 @@ function tree_mapreduce( f_leaf::F1, f_branch::F2, op::G, - tree::AbstractNode, + tree::AbstractNode{D}, result_type::Type{RT}=Undefined; f_on_shared::H=(result, is_shared) -> result, - break_sharing::Val=Val(false), -) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT} - - # Trick taken from here: - # https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5 - # to speed up recursive closure - @memoize_on t f_on_shared function inner(inner, t) - if t.degree == 0 - return @inline(f_leaf(t)) - elseif t.degree == 1 - return @inline(op(@inline(f_branch(t)), inner(inner, t.l))) - else - return @inline(op(@inline(f_branch(t)), inner(inner, t.l), inner(inner, t.r))) - end - end - - sharing = preserve_sharing(typeof(tree)) && break_sharing === Val(false) + break_sharing::Val{BS}=Val(false), +) where {F1<:Function,F2<:Function,G<:Function,D,H<:Function,RT,BS} + sharing = preserve_sharing(typeof(tree)) && !break_sharing RT == Undefined && sharing && throw(ArgumentError("Need to specify `result_type` if nodes are shared..")) if sharing && RT != Undefined - d = allocate_id_map(tree, RT) - return @with_memoize inner(inner, tree) d + id_map = allocate_id_map(tree, RT) + reducer = TreeMapreducer(Val(D), id_map, f_leaf, f_branch, op, f_on_shared) + return reducer(tree) + else + reducer = TreeMapreducer(Val(D), nothing, f_leaf, f_branch, op, f_on_shared) + return reducer(tree) + end +end + +struct TreeMapreducer{D,ID,F1<:Function,F2<:Function,G<:Function,H<:Function} + max_degree::Val{D} + id_map::ID + f_leaf::F1 + f_branch::F2 + op::G + f_on_shared::H +end + +@generated function (mapreducer::TreeMapreducer{MAX_DEGREE,ID})( + tree::AbstractNode +) where {MAX_DEGREE,ID} + base_expr = quote + d = tree.degree + Base.Cartesian.@nif( + $(MAX_DEGREE + 1), + d_p_one -> (d_p_one - 1) == d, + d_p_one -> if d_p_one == 1 + mapreducer.f_leaf(tree) + else + mapreducer.op( + mapreducer.f_branch(tree), + Base.Cartesian.@ntuple( + d_p_one - 1, i -> mapreducer(tree.children[i][]) + )..., + ) + end + ) + end + if ID <: Nothing + # No sharing of nodes (is a tree, not a graph) + return base_expr else - return 
inner(inner, tree) + # Otherwise, we need to cache results in `id_map` + # according to `objectid` of the node + return quote + key = objectid(tree) + is_cached = haskey(mapreducer.id_map, key) + if is_cached + return mapreducer.f_on_shared(@inbounds(mapreducer.id_map[key]), true) + else + res = $base_expr + mapreducer.id_map[key] = res + return mapreducer.f_on_shared(res, false) + end + end end end + function allocate_id_map(tree::AbstractNode, ::Type{RT}) where {RT} d = Dict{UInt,RT}() # Preallocate maximum storage (counting with duplicates is fast) @@ -128,7 +166,6 @@ function allocate_id_map(tree::AbstractNode, ::Type{RT}) where {RT} sizehint!(d, N) return d end - # TODO: Raise Julia issue for this. # Surprisingly Dict{UInt,RT} is faster than IdDict{Node{T},RT} here! # I think it's because `setindex!` is declared with `@nospecialize` in IdDict. diff --git a/test/test_graphs.jl b/test/test_graphs.jl index 2f31c4ed..c25a3ab6 100644 --- a/test/test_graphs.jl +++ b/test/test_graphs.jl @@ -120,76 +120,6 @@ end @test expr_eql(ex, true_ex) end - - @testset "@memoize_on" begin - ex = @macroexpand DynamicExpressions.UtilsModule.@memoize_on tree ((x, _) -> x) function _copy_node( - tree::Node{T} - )::Node{T} where {T} - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l)) - else - Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r)) - end - end - true_ex = quote - function _copy_node(tree::Node{T})::Node{T} where {T} - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l)) - else - Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r)) - end - end - function _copy_node(tree::Node{T}, id_map::AbstractDict;)::Node{T} where {T} - key = objectid(tree) - is_memoized = haskey(id_map, key) - function body() - return begin - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l, id_map)) - else - Node( - copy(tree.op), - _copy_node(tree.l, id_map), - _copy_node(tree.r, id_map), - ) - end - end - end - result = if is_memoized - begin - $(Expr(:inbounds, true)) - local val = id_map[key] - $(Expr(:inbounds, :pop)) - val - end - else - id_map[key] = body() - end - return (((x, _) -> begin - x - end)(result, is_memoized)) - end - end - @test expr_eql(ex, true_ex) - end end @testset "Operations on graphs" begin From 2a0bd054578c88f47b70b61ad0141e40c8e6ce47 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jul 2024 22:11:29 +0100 Subject: [PATCH 06/16] fix: various aspects of degree interface --- src/DynamicExpressions.jl | 1 + src/Node.jl | 2 +- src/NodeUtils.jl | 11 +++++++++++ src/base.jl | 2 +- test/test_base.jl | 4 ++-- test/test_custom_node_type.jl | 22 +++++++++++++--------- test/test_equality.jl | 4 ++-- test/test_extra_node_fields.jl | 25 ++++++++++++++++--------- test/test_graphs.jl | 13 +------------ test/test_parse.jl | 4 ++-- 10 files changed, 50 insertions(+), 38 deletions(-) diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index c2d71489..b9aa4b4f 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -47,6 +47,7 @@ import .NodeModule: constructorof, with_type_parameters, preserve_sharing, + max_degree, leaf_copy, branch_copy, leaf_hash, diff --git 
a/src/Node.jl b/src/Node.jl index 0d1ab048..275fb496 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -219,7 +219,7 @@ end # with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D} # with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D} -function default_allocator(::Type{N}, ::Type{T}) where {N<:Union{Node,GraphNode},T} +function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} return with_type_parameters(N, T)() end diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index 87c4ba60..392df3b7 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -165,6 +165,17 @@ mutable struct NodeIndex{T,D} <: AbstractNode{D} end NodeIndex(::Type{T}, ::Val{D}) where {T,D} = NodeIndex(T, Val(D), zero(T)) +@inline function Base.getproperty(n::NodeIndex, k::Symbol) + if k == :l + # TODO: Should a depwarn be raised here? Or too slow? + return getfield(n, :children)[1][] + elseif k == :r + return getfield(n, :children)[2][] + else + return getfield(n, k) + end +end + # Sharing is never needed for NodeIndex, # as we trace over the node we are indexing on. preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false diff --git a/src/base.jl b/src/base.jl index 32d67fd6..6f2a8fbd 100644 --- a/src/base.jl +++ b/src/base.jl @@ -94,7 +94,7 @@ function tree_mapreduce( f_on_shared::H=(result, is_shared) -> result, break_sharing::Val{BS}=Val(false), ) where {F1<:Function,F2<:Function,G<:Function,D,H<:Function,RT,BS} - sharing = preserve_sharing(typeof(tree)) && !break_sharing + sharing = preserve_sharing(typeof(tree)) && !BS RT == Undefined && sharing && diff --git a/test/test_base.jl b/test/test_base.jl index b14894b1..f7e7a483 100644 --- a/test/test_base.jl +++ b/test/test_base.jl @@ -32,11 +32,11 @@ end @testset "collect" begin ctree = copy(tree) - @test typeof(first(collect(ctree))) == Node{Float64} + @test typeof(first(collect(ctree))) <: Node{Float64} @test objectid(first(collect(ctree))) == objectid(ctree) @test objectid(first(collect(ctree))) == objectid(ctree) @test objectid(first(collect(ctree))) == objectid(ctree) - @test typeof(collect(ctree)) == Vector{Node{Float64}} + @test typeof(collect(ctree)) <: Vector{<:Node{Float64}} @test length(collect(ctree)) == 24 @test sum((t -> (t.degree == 0 && t.constant) ? 
t.val : 0.0).(collect(ctree))) ≈ 11.6 end diff --git a/test/test_custom_node_type.jl b/test/test_custom_node_type.jl index 3fc333bc..57a3706c 100644 --- a/test/test_custom_node_type.jl +++ b/test/test_custom_node_type.jl @@ -1,16 +1,21 @@ using DynamicExpressions using Test -mutable struct MyCustomNode{A,B} <: AbstractNode +mutable struct MyCustomNode{A,B} <: AbstractNode{2} degree::Int val1::A val2::B - l::MyCustomNode{A,B} - r::MyCustomNode{A,B} + children::NTuple{2,Base.RefValue{MyCustomNode{A,B}}} MyCustomNode(val1, val2) = new{typeof(val1),typeof(val2)}(0, val1, val2) - MyCustomNode(val1, val2, l) = new{typeof(val1),typeof(val2)}(1, val1, val2, l) - MyCustomNode(val1, val2, l, r) = new{typeof(val1),typeof(val2)}(2, val1, val2, l, r) + function MyCustomNode(val1, val2, l) + return new{typeof(val1),typeof(val2)}( + 1, val1, val2, (Ref(l), Ref{MyCustomNode{typeof(val1),typeof(val2)}}()) + ) + end + function MyCustomNode(val1, val2, l, r) + return new{typeof(val1),typeof(val2)}(2, val1, val2, (Ref(l), Ref(r))) + end end node1 = MyCustomNode(1.0, 2) @@ -24,7 +29,7 @@ node2 = MyCustomNode(1.5, 3, node1) @test typeof(node2) == MyCustomNode{Float64,Int} @test node2.degree == 1 -@test node2.l.degree == 0 +@test node2.children[1][].degree == 0 @test count_depth(node2) == 2 @test count_nodes(node2) == 2 @@ -37,14 +42,13 @@ node2 = MyCustomNode(1.5, 3, node1, node1) @test count(t -> t.degree == 0, node2) == 2 # If we have a bad definition, it should get caught with a helpful message -mutable struct MyCustomNode2{T} <: AbstractExpressionNode{T} +mutable struct MyCustomNode2{T} <: AbstractExpressionNode{T,2} degree::UInt8 constant::Bool val::T feature::UInt16 op::UInt8 - l::MyCustomNode2{T} - r::MyCustomNode2{T} + children::NTuple{2,Base.RefValue{MyCustomNode2{T}}} end @test_throws ErrorException MyCustomNode2() diff --git a/test/test_equality.jl b/test/test_equality.jl index 220e63c3..7e9b845b 100644 --- a/test/test_equality.jl +++ b/test/test_equality.jl @@ -45,8 +45,8 @@ modified_tree5 = 1.5 * cos(x2 * x1) + x1 + x2 * x3 - log(x2 * 3.2) f64_tree = GraphNode{Float64}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1)) f32_tree = GraphNode{Float32}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1)) -@test typeof(f64_tree) == GraphNode{Float64} -@test typeof(f32_tree) == GraphNode{Float32} +@test typeof(f64_tree) <: GraphNode{Float64} +@test typeof(f32_tree) <: GraphNode{Float32} @test convert(GraphNode{Float64}, f32_tree) == f64_tree diff --git a/test/test_extra_node_fields.jl b/test/test_extra_node_fields.jl index 467c6226..60b35595 100644 --- a/test/test_extra_node_fields.jl +++ b/test/test_extra_node_fields.jl @@ -2,24 +2,31 @@ using Test using DynamicExpressions -using DynamicExpressions: constructorof +using DynamicExpressions: constructorof, max_degree -mutable struct FrozenNode{T} <: AbstractExpressionNode{T} +mutable struct FrozenNode{T,D} <: AbstractExpressionNode{T,D} degree::UInt8 constant::Bool val::T frozen::Bool # Extra field! 
feature::UInt16 op::UInt8 - l::FrozenNode{T} - r::FrozenNode{T} + children::NTuple{D,Base.RefValue{FrozenNode{T,D}}} - function FrozenNode{_T}() where {_T} - n = new{_T}() + function FrozenNode{_T,_D}() where {_T,_D} + n = new{_T,_D}() n.frozen = false return n end end +function DynamicExpressions.constructorof(::Type{N}) where {N<:FrozenNode} + return FrozenNode{T,max_degree(N)} where {T} +end +function DynamicExpressions.with_type_parameters( + ::Type{N}, ::Type{T} +) where {T,N<:FrozenNode} + return FrozenNode{T,max_degree(N)} +end function DynamicExpressions.leaf_copy(t::FrozenNode{T}) where {T} out = if t.constant constructorof(typeof(t))(; val=t.val) @@ -56,7 +63,7 @@ function DynamicExpressions.leaf_equal(a::FrozenNode, b::FrozenNode) end end -n = let n = FrozenNode{Float64}() +n = let n = FrozenNode{Float64,2}() n.degree = 0 n.constant = true n.val = 0.0 @@ -92,5 +99,5 @@ ex = parse_expression( @test string_tree(ex) == "x + sin(y + 2.1)" @test ex.tree.frozen == false -@test ex.tree.r.frozen == true -@test ex.tree.r.l.frozen == false +@test ex.tree.children[2][].frozen == true +@test ex.tree.children[2][].children[1][].frozen == false diff --git a/test/test_graphs.jl b/test/test_graphs.jl index c25a3ab6..55ab4d79 100644 --- a/test/test_graphs.jl +++ b/test/test_graphs.jl @@ -109,17 +109,6 @@ end :(_convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}())), ) end - - @testset "@with_memoize" begin - ex = @macroexpand DynamicExpressions.UtilsModule.@with_memoize( - _convert(Node{T1}, tree), IdDict{Node{T2},Node{T1}}() - ) - true_ex = quote - _convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}()) - end - - @test expr_eql(ex, true_ex) - end end @testset "Operations on graphs" begin @@ -283,7 +272,7 @@ end x = GraphNode(Float32; feature=1) tree = x + 1.0 @test tree.l === x - @test typeof(tree) === GraphNode{Float32} + @test typeof(tree) <: GraphNode{Float32} # Detect error from Float32(1im) @test_throws InexactError x + 1im diff --git a/test/test_parse.jl b/test/test_parse.jl index c9b40d0c..8d9c351d 100644 --- a/test/test_parse.jl +++ b/test/test_parse.jl @@ -108,7 +108,7 @@ end variable_names = ["x"], ) - @test typeof(ex.tree) === Node{Any} + @test typeof(ex.tree) <: Node{Any} @test typeof(ex.metadata.operators) <: GenericOperatorEnum s = sprint((io, e) -> show(io, MIME("text/plain"), e), ex) @test s == "[1, 2, 3] * tan(cos(5.0 + x))" @@ -184,7 +184,7 @@ end s = sprint((io, e) -> show(io, MIME("text/plain"), e), ex) @test s == "(x * 2.5) - cos(y)" end - @test contains(logged_out, "Node{Float32}") + @test contains(logged_out, "Node{Float32") end @testitem "Helpful errors for missing operator" begin From 60579fc32ce5c877cc8b16d909f0e27e23d897d8 Mon Sep 17 00:00:00 2001 From: Robert Dancer Date: Tue, 9 Jul 2024 15:58:15 +0100 Subject: [PATCH 07/16] Add topological sort --- src/DynamicExpressions.jl | 5 +++- src/Node.jl | 55 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index a708dd9d..2988757b 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -33,7 +33,9 @@ macro ignore(args...) end set_node!, tree_mapreduce, filter_map, - filter_map! 
+ filter_map!, + topological_sort, + randomised_topological_sort import .NodeModule: constructorof, with_type_parameters, @@ -71,6 +73,7 @@ import .ExpressionModule: get_tree, get_operators, get_variable_names, Metadata @reexport import .ParseModule: @parse_expression, parse_expression import .ParseModule: parse_leaf + function __init__() @require_extensions end diff --git a/src/Node.jl b/src/Node.jl index 21f32e59..8c2747db 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -2,6 +2,7 @@ module NodeModule import ..OperatorEnumModule: AbstractOperatorEnum import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined +using Random: default_rng, AbstractRNG const DEFAULT_NODE_TYPE = Float32 @@ -163,14 +164,15 @@ when constructing or setting properties. mutable struct GraphNode{T} <: AbstractExpressionNode{T} degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. constant::Bool # false if variable - val::T # If is a constant, this stores the actual value + val::T # If is a constant, this stores the actual value, otherwise stores calculated values during evaluation # ------------------- (possibly undefined below) feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops l::GraphNode{T} # Left child node. Only defined for degree=1 or degree=2. r::GraphNode{T} # Right child node. Only defined for degree=2. + visited::Bool # used in dfs toposort - GraphNode{_T}() where {_T} = new{_T}() + GraphNode{_T}() where {_T} = (x = new{_T}(); x.visited = false; x) end ################################################################################ @@ -358,4 +360,53 @@ function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNod return nothing end +"""Topological sort of the graph following a depth-first search""" +function topological_sort(graph::GraphNode{T}) where {T} + order = Vector{GraphNode{T}}() + _rec_toposort(graph, order) + for node in order + node.visited = false + end + return order +end + +"""Topological sort of the graph following a randomised depth-first search""" +function randomised_topological_sort(graph::GraphNode{T}, rng::AbstractRNG=default_rng()) where {T} + order = Vector{GraphNode{T}}() + _rec_randomised_toposort(graph, order, rng) + for node in order + node.visited = false + end + return order +end + +function _rec_toposort(gnode::GraphNode{T}, order::Vector{GraphNode{T}}) where {T} + if gnode.visited return end + gnode.visited = true + if gnode.degree == 1 + _rec_toposort(gnode.l, order) + elseif gnode.degree == 2 + _rec_toposort(gnode.l, order) + _rec_toposort(gnode.r, order) + end + push!(order, gnode) end + +function _rec_randomised_toposort(gnode::GraphNode{T}, order::Vector{GraphNode{T}}, rng::AbstractRNG) where {T} + if gnode.visited return end + gnode.visited = true + if gnode.degree == 1 + _rec_randomised_toposort(gnode.l, order, rng) + elseif gnode.degree == 2 + if rand(rng, Bool) + _rec_randomised_toposort(gnode.l, order, rng) + _rec_randomised_toposort(gnode.r, order, rng) + else + _rec_randomised_toposort(gnode.r, order, rng) + _rec_randomised_toposort(gnode.l, order, rng) + end + end + push!(order, gnode) +end + +end \ No newline at end of file From 516a130a5fc0586c672d310e45fc8706a2783b73 Mon Sep 17 00:00:00 2001 From: Robert Dancer Date: Wed, 10 Jul 2024 09:02:00 +0100 Subject: [PATCH 08/16] wip --- src/Node.jl | 45 +++++++++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 
14 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index 30578db2..ebf5061c 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -4,6 +4,7 @@ using DispatchDoctor: @unstable import ..OperatorEnumModule: AbstractOperatorEnum import ..UtilsModule: deprecate_varmap, Undefined +using Random: default_rng, AbstractRNG const DEFAULT_NODE_TYPE = Float32 @@ -73,20 +74,34 @@ You likely do not need to, but you could choose to override the following: """ abstract type AbstractExpressionNode{T,D} <: AbstractNode{D} end -for N in (:Node, :GraphNode) - @eval mutable struct $N{T,D} <: AbstractExpressionNode{T,D} - degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. - constant::Bool # false if variable - val::T # If is a constant, this stores the actual value - feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index. - op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum - children::NTuple{D,Base.RefValue{$N{T,D}}} # Children nodes - - ################# - ## Constructors: - ################# - $N{_T,_D}() where {_T,_D} = new{_T,_D::Int}() - end +mutable struct Node{T,D} <: AbstractExpressionNode{T,D} + degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. + constant::Bool # false if variable + val::T # If is a constant, this stores the actual value + feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index. + op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum + children::NTuple{D,Base.RefValue{Node{T,D}}} # Children nodes + + ################# + ## Constructors: + ################# + #Node{_T,_D}() where {_T,_D} = new{_T,_D::Int}() + Node{_T,_D}() where {_T,_D} = (x = new{_T,_D::Int}(); x.children = ntuple(i -> Ref{Node{_T,_D}}(), Val(max_degree(Node))); x) +end + +mutable struct GraphNode{T,D} <: AbstractExpressionNode{T,D} + degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. + constant::Bool # false if variable + val::T # If is a constant, this stores the actual value + feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index. + op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum + children::NTuple{D,Base.RefValue{GraphNode{T,D}}} # Children nodes + visited::Bool # search accounting, initialised to false + + ################# + ## Constructors: + ################# + GraphNode{_T,_D}() where {_T,_D} = (x = new{_T,_D::Int}(); x.visited = false; x.children = ntuple(i -> Ref{GraphNode{_T,_D}}(), Val(max_degree(GraphNode))); x) end #! 
format: off @@ -193,6 +208,8 @@ end setfield!(n, :val, convert(eltype(n), v)) elseif k == :children setfield!(n, :children, v) + elseif k == :visited && typeof(n) <: GraphNode + setfield!(n, :visited, v) else error("Invalid property: $k") end From a83bbe315c60f305899221aed05bb340c8810340 Mon Sep 17 00:00:00 2001 From: Robert Dancer Date: Sat, 13 Jul 2024 14:42:14 +0100 Subject: [PATCH 09/16] Fix toposort vector bug --- src/Node.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index ebf5061c..e075b0a8 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -394,8 +394,8 @@ function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNod end """Topological sort of the graph following a depth-first search""" -function topological_sort(graph::GraphNode{T}) where {T} - order = Vector{GraphNode{T}}() +function topological_sort(graph::GraphNode) + order = Vector{GraphNode}() _rec_toposort(graph, order) for node in order node.visited = false @@ -404,8 +404,8 @@ function topological_sort(graph::GraphNode{T}) where {T} end """Topological sort of the graph following a randomised depth-first search""" -function randomised_topological_sort(graph::GraphNode{T}, rng::AbstractRNG=default_rng()) where {T} - order = Vector{GraphNode{T}}() +function randomised_topological_sort(graph::GraphNode, rng::AbstractRNG=default_rng()) + order = Vector{GraphNode}() _rec_randomised_toposort(graph, order, rng) for node in order node.visited = false @@ -413,7 +413,7 @@ function randomised_topological_sort(graph::GraphNode{T}, rng::AbstractRNG=defau return order end -function _rec_toposort(gnode::GraphNode{T}, order::Vector{GraphNode{T}}) where {T} +function _rec_toposort(gnode::GraphNode, order::Vector{GraphNode}) if gnode.visited return end gnode.visited = true if gnode.degree == 1 @@ -425,7 +425,7 @@ function _rec_toposort(gnode::GraphNode{T}, order::Vector{GraphNode{T}}) where { push!(order, gnode) end -function _rec_randomised_toposort(gnode::GraphNode{T}, order::Vector{GraphNode{T}}, rng::AbstractRNG) where {T} +function _rec_randomised_toposort(gnode::GraphNode, order::Vector{GraphNode}, rng::AbstractRNG) if gnode.visited return end gnode.visited = true if gnode.degree == 1 From 8f146cd5942da5068ba5bd27dcec5ed3d322accb Mon Sep 17 00:00:00 2001 From: Robert Dancer Date: Mon, 15 Jul 2024 00:54:45 +0100 Subject: [PATCH 10/16] wip new evaluator --- src/DynamicExpressions.jl | 2 +- src/Evaluate.jl | 71 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 0d2c8e9b..17c78c70 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -71,7 +71,7 @@ import .NodeModule: @reexport import .OperatorEnumModule: AbstractOperatorEnum @reexport import .OperatorEnumConstructionModule: OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names! -@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array +@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array, eval_tree_array_graph @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array @reexport import .ChainRulesModule: NodeTangent, extract_gradient @reexport import .SimplifyModule: combine_operators, simplify_tree! 
diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 88999e62..4d62ecee 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -2,7 +2,7 @@ module EvaluateModule using DispatchDoctor: @unstable -import ..NodeModule: AbstractExpressionNode, constructorof +import ..NodeModule: AbstractExpressionNode, constructorof, GraphNode, topological_sort import ..StringsModule: string_tree import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum import ..UtilsModule: fill_similar, counttuple, ResultOk @@ -769,4 +769,73 @@ end end end +# Parametric arguments don't use dynamic dispatch, not all calls will resolve properly + +""" +function eval_tree_array_graph( + graph::GraphNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum; + turbo::Val{false}=Val(false), + bumper::Val{false}=Val(false) +) where {T} + order = topological_sort(graph) + res = Vector{T}(undef, size(cX, 2)) + @inbounds for sampleindex in axes(cX, 2) + @inbounds for node in order + if node.degree != 0 || !node.constant + if node.degree == 0 && !node.constant + node.val = cX[node.feature, sampleindex] + elseif node.degree == 1 + node.val = operators.unaops[node.op](node.children[1][].val) + elseif node.degree == 2 + node.val = operators.binops[node.op](node.children[1][].val, node.children[2][].val) + else + error("n-ary operator evaluation not implemented") + end + end + if !is_valid(node.val) + return (res, false) + end + end + res[sampleindex] = last(order).val + end + return (res, is_valid_array(res)) end +""" + +function eval_tree_array_graph( + node::GraphNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum +) where {T} + if node.degree == 0 + if node.constant + return fill(node.val, axes(cX, 2)) + else + return cX[node.feature, :] + end + #elseif node.degree == 1 + # return map(x -> operators.unaops[node.op](x), eval_tree_array_graph(node.l, cX, operators)) + #else + # return map(tp -> operators.binops[node.op](tp...), zip(eval_tree_array_graph(node.l, cX, operators), eval_tree_array_graph(node.r, cX, operators))) + #end + elseif node.degree == 1 + cl = eval_tree_array_graph(node.l, cX, operators) + op = operators.unaops[node.op] + @inbounds @simd for j in eachindex(cl) + cl[j] = op(cl[j])::T + end + return cl + else + cl = eval_tree_array_graph(node.l, cX, operators) + cr = eval_tree_array_graph(node.r, cX, operators) + op = operators.binops[node.op] + @inbounds @simd for j in eachindex(cl) + cl[j] = op(cl[j], cr[j])::T + end + return cl + end +end + +end \ No newline at end of file From 2cb14b3fd31cbae28b19bec66f59172f1a857c08 Mon Sep 17 00:00:00 2001 From: Robert Dancer Date: Tue, 16 Jul 2024 16:16:10 +0100 Subject: [PATCH 11/16] working eval_tree_array_graph, make loopvectorization a full dependency --- Project.toml | 5 +-- src/Evaluate.jl | 100 +++++++++++++++++++++--------------------------- src/Node.jl | 3 ++ 3 files changed, 48 insertions(+), 60 deletions(-) diff --git a/Project.toml b/Project.toml index 3f5df1a5..ed2a6f9a 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -17,14 +18,12 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" [weakdeps] Bumper = 
"8ce10254-0962-460f-a3d8-1f77fea1446e" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Optim = "429524aa-4258-5aef-a3af-852621145aeb" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DynamicExpressionsBumperExt = "Bumper" -DynamicExpressionsLoopVectorizationExt = "LoopVectorization" DynamicExpressionsOptimExt = "Optim" DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils" DynamicExpressionsZygoteExt = "Zygote" @@ -35,7 +34,6 @@ ChainRulesCore = "1" Compat = "3.37, 4" DispatchDoctor = "0.4" Interfaces = "0.3" -LoopVectorization = "0.12" MacroTools = "0.4, 0.5" Optim = "0.19, 1" PackageExtensionCompat = "1" @@ -47,7 +45,6 @@ julia = "1.6" [extras] Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Optim = "429524aa-4258-5aef-a3af-852621145aeb" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 4d62ecee..e3dd4229 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -1,6 +1,7 @@ module EvaluateModule using DispatchDoctor: @unstable +using LoopVectorization import ..NodeModule: AbstractExpressionNode, constructorof, GraphNode, topological_sort import ..StringsModule: string_tree @@ -769,72 +770,59 @@ end end end -# Parametric arguments don't use dynamic dispatch, not all calls will resolve properly +# Parametric arguments don't use dynamic dispatch, calls with turbo/bumper won't resolve properly -""" function eval_tree_array_graph( - graph::GraphNode{T}, + root::GraphNode{T}, cX::AbstractMatrix{T}, - operators::OperatorEnum; - turbo::Val{false}=Val(false), - bumper::Val{false}=Val(false) + operators::OperatorEnum, ) where {T} - order = topological_sort(graph) - res = Vector{T}(undef, size(cX, 2)) - @inbounds for sampleindex in axes(cX, 2) - @inbounds for node in order - if node.degree != 0 || !node.constant - if node.degree == 0 && !node.constant - node.val = cX[node.feature, sampleindex] - elseif node.degree == 1 - node.val = operators.unaops[node.op](node.children[1][].val) - elseif node.degree == 2 - node.val = operators.binops[node.op](node.children[1][].val, node.children[2][].val) + + # vmap is faster with small cX sizes + # vmapnt (non-temporal) is faster with larger cX sizes (too big so not worth caching?) 
+ + order = topological_sort(root) + for node in order + if node.degree == 0 && !node.constant + node.cache = view(cX, node.feature, :) + elseif node.degree == 1 + if node.l.constant + node.constant = true + node.val = operators.unaops[node.op](node.l.val) + if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end + else + node.constant = false + node.cache = vmapnt(operators.unaops[node.op], node.l.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + elseif node.degree == 2 + if node.l.constant + if node.r.constant + node.constant = true + node.val = operators.binops[node.op](node.l.val, node.r.val) + if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end else - error("n-ary operator evaluation not implemented") + node.constant = false + node.cache = vmapnt(Base.Fix1(operators.binops[node.op], node.l.val), node.r.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + else + if node.r.constant + node.constant = false + node.cache = vmapnt(Base.Fix2(operators.binops[node.op], node.r.val), node.l.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + else + node.constant = false + node.cache = vmapnt(operators.binops[node.op], node.l.cache, node.r.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end end - end - if !is_valid(node.val) - return (res, false) end end - res[sampleindex] = last(order).val end - return (res, is_valid_array(res)) -end -""" - -function eval_tree_array_graph( - node::GraphNode{T}, - cX::AbstractMatrix{T}, - operators::OperatorEnum -) where {T} - if node.degree == 0 - if node.constant - return fill(node.val, axes(cX, 2)) - else - return cX[node.feature, :] - end - #elseif node.degree == 1 - # return map(x -> operators.unaops[node.op](x), eval_tree_array_graph(node.l, cX, operators)) - #else - # return map(tp -> operators.binops[node.op](tp...), zip(eval_tree_array_graph(node.l, cX, operators), eval_tree_array_graph(node.r, cX, operators))) - #end - elseif node.degree == 1 - cl = eval_tree_array_graph(node.l, cX, operators) - op = operators.unaops[node.op] - @inbounds @simd for j in eachindex(cl) - cl[j] = op(cl[j])::T - end - return cl + if root.constant + return ResultOk(fill(root.val, size(cX, 2)), true) else - cl = eval_tree_array_graph(node.l, cX, operators) - cr = eval_tree_array_graph(node.r, cX, operators) - op = operators.binops[node.op] - @inbounds @simd for j in eachindex(cl) - cl[j] = op(cl[j], cr[j])::T - end - return cl + return ResultOk(root.cache, true) end end diff --git a/src/Node.jl b/src/Node.jl index e075b0a8..9e9d261f 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -97,6 +97,7 @@ mutable struct GraphNode{T,D} <: AbstractExpressionNode{T,D} op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum children::NTuple{D,Base.RefValue{GraphNode{T,D}}} # Children nodes visited::Bool # search accounting, initialised to false + cache::AbstractArray{T} ################# ## Constructors: @@ -210,6 +211,8 @@ end setfield!(n, :children, v) elseif k == :visited && typeof(n) <: GraphNode setfield!(n, :visited, v) + elseif k == :cache && typeof(n) <: GraphNode + setfield!(n, :cache, v) else error("Invalid property: $k") end From 7330be7d6446fba0a328e94cf9c35f645aa9688f Mon Sep 17 00:00:00 2001 From: Robert Dancer Date: Thu, 18 Jul 2024 22:01:51 +0100 Subject: [PATCH 12/16] Add graph visualizer --- Project.toml | 2 ++ src/DynamicExpressions.jl 
| 4 ++- src/Evaluate.jl | 2 +- src/Visualize.jl | 76 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 src/Visualize.jl diff --git a/Project.toml b/Project.toml index ed2a6f9a..7acc1bc2 100644 --- a/Project.toml +++ b/Project.toml @@ -7,10 +7,12 @@ version = "0.18.6" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" +GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73" Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 17c78c70..5bc263f2 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -20,6 +20,7 @@ using DispatchDoctor: @stable, @unstable include("Random.jl") include("Parse.jl") include("ParametricExpression.jl") + include("Visualize.jl") end import PackageExtensionCompat: @require_extensions @@ -71,7 +72,7 @@ import .NodeModule: @reexport import .OperatorEnumModule: AbstractOperatorEnum @reexport import .OperatorEnumConstructionModule: OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names! -@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array, eval_tree_array_graph +@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array @reexport import .ChainRulesModule: NodeTangent, extract_gradient @reexport import .SimplifyModule: combine_operators, simplify_tree! @@ -85,6 +86,7 @@ import .ExpressionModule: @reexport import .ParseModule: @parse_expression, parse_expression import .ParseModule: parse_leaf @reexport import .ParametricExpressionModule: ParametricExpression, ParametricNode +@reexport import .VisualizeModule: visualize @stable default_mode = "disable" begin include("Interfaces.jl") diff --git a/src/Evaluate.jl b/src/Evaluate.jl index e3dd4229..c71e354f 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -772,7 +772,7 @@ end # Parametric arguments don't use dynamic dispatch, calls with turbo/bumper won't resolve properly -function eval_tree_array_graph( +function eval_tree_array( root::GraphNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, diff --git a/src/Visualize.jl b/src/Visualize.jl new file mode 100644 index 00000000..80e10984 --- /dev/null +++ b/src/Visualize.jl @@ -0,0 +1,76 @@ +module VisualizeModule + +using Plots, GraphRecipes +using ..NodeModule: GraphNode, topological_sort +using ..OperatorEnumModule: AbstractOperatorEnum +using ..StringsModule: get_op_name + +# visualization probably best as an extension if pulled into master + +function visualize( + graph::GraphNode, + operators::AbstractOperatorEnum, + show = true +) + @info "Generating graph visualization" + + order = reverse(topological_sort(graph)) + + # multigraph adjacency list + g = map( + node -> convert(Vector{Int64}, map( + cindex -> findfirst(x -> x === node.children[cindex][], order), + 1:node.degree + )), + order + ) + + # node labels + n = map(x -> + if x.degree == 0 + x.constant ? 
x.val : 'x' * string(x.feature) + elseif x.degree == 1 + join(get_op_name(operators.unaops[x.op])) + elseif x.degree == 2 + join(get_op_name(operators.binops[x.op])) + else + @warn "Can't label operator node with degree > 2" + end, + order + ) + + # edge labels (specifies parameter no.) + e = Dict{Tuple{Int64, Int64, Int64}, String}() + for (index, node) in enumerate(order) + edge_count = Dict{Int64, Int64}() # count number of edges to each child node + for cindex in 1:node.degree + order_cindex = findfirst(x -> x === node.children[cindex][], order) + get!( + e, + ( + index, # source + order_cindex, # dest + get!(edge_count, order_cindex, pop!(edge_count, order_cindex, 0)+1) # edge no. + ), + string(cindex) + ) + end + end + + # node colours + c = map(x -> x == 1 ? 2 : 1, eachindex(order)) + + return graphplot( + g, + names = n, + edgelabel = e, + nodecolor = c, + show = show, + nodeshape=:circle, + edge_label_box = false, + edgelabel_offset = 0.015, + nodesize=0.15 + ) +end + +end \ No newline at end of file From 789ae573a4d44e3753850a1f59c6d72cce5373c6 Mon Sep 17 00:00:00 2001 From: Robert Dancer Date: Thu, 18 Jul 2024 22:13:40 +0100 Subject: [PATCH 13/16] Add topological sort implementation for trees --- src/Node.jl | 40 ++++++++++++++++++++++++++++++++++++++++ src/Visualize.jl | 2 +- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/Node.jl b/src/Node.jl index 9e9d261f..86a55eb4 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -445,4 +445,44 @@ function _rec_randomised_toposort(gnode::GraphNode, order::Vector{GraphNode}, rn push!(order, gnode) end + +"""Topological sort of the tree following a depth-first search""" +function topological_sort(tree::Node) + order = Vector{Node}() + _rec_toposort(tree, order) + return order +end + +"""Topological sort of the tree following a randomised depth-first search""" +function randomised_topological_sort(tree::Node, rng::AbstractRNG=default_rng()) + order = Vector{Node}() + _rec_randomised_toposort(tree, order, rng) + return order +end + +function _rec_toposort(tnode::Node, order::Vector{Node}) + if tnode.degree == 1 + _rec_toposort(tnode.l, order) + elseif tnode.degree == 2 + _rec_toposort(tnode.l, order) + _rec_toposort(tnode.r, order) + end + push!(order, tnode) +end + +function _rec_randomised_toposort(tnode::Node, order::Vector{Node}, rng::AbstractRNG) + if tnode.degree == 1 + _rec_randomised_toposort(tnode.l, order, rng) + elseif tnode.degree == 2 + if rand(rng, Bool) + _rec_randomised_toposort(tnode.l, order, rng) + _rec_randomised_toposort(tnode.r, order, rng) + else + _rec_randomised_toposort(tnode.r, order, rng) + _rec_randomised_toposort(tnode.l, order, rng) + end + end + push!(order, gnode) +end + end \ No newline at end of file diff --git a/src/Visualize.jl b/src/Visualize.jl index 80e10984..16c5e611 100644 --- a/src/Visualize.jl +++ b/src/Visualize.jl @@ -8,7 +8,7 @@ using ..StringsModule: get_op_name # visualization probably best as an extension if pulled into master function visualize( - graph::GraphNode, + graph::Union{GraphNode,Node}, # types accepted by topological_sort operators::AbstractOperatorEnum, show = true ) From d219f046f4a76ce3c7b6d4102c86f76f73ea638f Mon Sep 17 00:00:00 2001 From: Robert Dancer Date: Fri, 19 Jul 2024 12:57:32 +0100 Subject: [PATCH 14/16] imports fix --- src/Visualize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Visualize.jl b/src/Visualize.jl index 16c5e611..82ea9653 100644 --- a/src/Visualize.jl +++ b/src/Visualize.jl @@ -1,7 +1,7 @@ module VisualizeModule 
using Plots, GraphRecipes -using ..NodeModule: GraphNode, topological_sort +using ..NodeModule: GraphNode, Node, topological_sort using ..OperatorEnumModule: AbstractOperatorEnum using ..StringsModule: get_op_name From 514829375d8e5fe7b6c75b7ca0027d5d05a6169e Mon Sep 17 00:00:00 2001 From: Robert Dancer Date: Wed, 31 Jul 2024 00:11:05 +0100 Subject: [PATCH 15/16] wip graph differentiation --- src/DynamicExpressions.jl | 2 +- src/Evaluate.jl | 74 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 5bc263f2..b482a7e2 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -72,7 +72,7 @@ import .NodeModule: @reexport import .OperatorEnumModule: AbstractOperatorEnum @reexport import .OperatorEnumConstructionModule: OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names! -@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array +@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array, eval_graph_array_diff, eval_graph_single @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array @reexport import .ChainRulesModule: NodeTangent, extract_gradient @reexport import .SimplifyModule: combine_operators, simplify_tree! diff --git a/src/Evaluate.jl b/src/Evaluate.jl index c71e354f..cfba95e6 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -826,4 +826,78 @@ function eval_tree_array( end end +function eval_graph_array_diff( + root::GraphNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum, +) where {T} + + # vmap is faster with small cX sizes + # vmapnt (non-temporal) is faster with larger cX sizes (too big so not worth caching?) 
+ dp = Dict{GraphNode, AbstractArray{T}}() + order = topological_sort(root) + for node in order + if node.degree == 0 && !node.constant + dp[node] = view(cX, node.feature, :) + elseif node.degree == 1 + if node.l.constant + node.constant = true + node.val = operators.unaops[node.op](node.l.val) + if !is_valid(node.val) return false end + else + node.constant = false + dp[node] = map(operators.unaops[node.op], dp[node.l]) + if !is_valid_array(dp[node]) return false end + end + elseif node.degree == 2 + if node.l.constant + if node.r.constant + node.constant = true + node.val = operators.binops[node.op](node.l.val, node.r.val) + if !is_valid(node.val) return false end + else + node.constant = false + dp[node] = map(Base.Fix1(operators.binops[node.op], node.l.val), dp[node.r]) + if !is_valid_array(dp[node]) return false end + end + else + if node.r.constant + node.constant = false + dp[node] = map(Base.Fix2(operators.binops[node.op], node.r.val), dp[node.l]) + if !is_valid_array(dp[node]) return false end + else + node.constant = false + dp[node] = map(operators.binops[node.op], dp[node.l], dp[node.r]) + if !is_valid_array(dp[node]) return false end + end + end + end + end + if root.constant + return fill(root.val, size(cX, 2)) + else + return dp[root] + end +end + +function eval_graph_single( + root::GraphNode{T}, + cX::AbstractArray{T}, + operators::OperatorEnum +) where {T} + order = topological_sort(root) + for node in order + if node.degree == 0 && !node.constant + node.val = cX[node.feature] + elseif node.degree == 1 + node.val = operators.unaops[node.op](node.l.val) + if !is_valid(node.val) return false end + elseif node.degree == 2 + node.val = operators.binops[node.op](node.l.val, node.r.val) + if !is_valid(node.val) return false end + end + end + return root.val +end + end \ No newline at end of file From d09db4e6077ce3921c4534d330ca55d2225a997c Mon Sep 17 00:00:00 2001 From: Robert Dancer Date: Sun, 11 Aug 2024 01:06:42 +0100 Subject: [PATCH 16/16] Change Plots, GraphRecipes, LoopVectorization to weak deps --- Project.toml | 8 +- ext/DynamicExpressionsLoopVectorizationExt.jl | 66 +++++++++++++- ext/DynamicExpressionsVisualizeExt.jl | 72 +++++++++++++++ src/DynamicExpressions.jl | 2 +- src/Evaluate.jl | 90 +++++++++++-------- src/Visualize.jl | 66 +------------- 6 files changed, 196 insertions(+), 108 deletions(-) create mode 100644 ext/DynamicExpressionsVisualizeExt.jl diff --git a/Project.toml b/Project.toml index 7acc1bc2..d0b6288c 100644 --- a/Project.toml +++ b/Project.toml @@ -7,12 +7,9 @@ version = "0.18.6" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" -GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73" Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -23,12 +20,17 @@ Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73" +LoopVectorization = 
"bdcacae8-1622-11e9-2a5c-532679323890" [extensions] DynamicExpressionsBumperExt = "Bumper" DynamicExpressionsOptimExt = "Optim" DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils" DynamicExpressionsZygoteExt = "Zygote" +DynamicExpressionsLoopVectorizationExt = "LoopVectorization" +DynamicExpressionsVisualizeExt = ["Plots","GraphRecipes"] [compat] Bumper = "0.6" diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index ec158320..c0906014 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -1,7 +1,9 @@ module DynamicExpressionsLoopVectorizationExt -using LoopVectorization: @turbo -using DynamicExpressions: AbstractExpressionNode +using DynamicExpressions + +using LoopVectorization: @turbo, vmapnt +using DynamicExpressions: AbstractExpressionNode, GraphNode, OperatorEnum using DynamicExpressions.UtilsModule: ResultOk, fill_similar using DynamicExpressions.EvaluateModule: @return_on_check import DynamicExpressions.EvaluateModule: @@ -14,6 +16,7 @@ import DynamicExpressions.EvaluateModule: deg2_r0_eval import DynamicExpressions.ExtensionInterfaceModule: _is_loopvectorization_loaded, bumper_kern1!, bumper_kern2! +import DynamicExpressions.ValueInterfaceModule: is_valid, is_valid_array _is_loopvectorization_loaded(::Int) = true @@ -212,4 +215,63 @@ function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{true}) where {F} return cumulator1 end + + +# graph eval + +function DynamicExpressions.EvaluateModule._eval_graph_array( + root::GraphNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum, + loopVectorization::Val{true} +) where {T} + + # vmap is faster with small cX sizes + # vmapnt (non-temporal) is faster with larger cX sizes (too big so not worth caching?) 
+ + order = topological_sort(root) + for node in order + if node.degree == 0 && !node.constant + node.cache = view(cX, node.feature, :) + elseif node.degree == 1 + if node.l.constant + node.constant = true + node.val = operators.unaops[node.op](node.l.val) + if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end + else + node.constant = false + node.cache = vmapnt(operators.unaops[node.op], node.l.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + elseif node.degree == 2 + if node.l.constant + if node.r.constant + node.constant = true + node.val = operators.binops[node.op](node.l.val, node.r.val) + if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end + else + node.constant = false + node.cache = vmapnt(Base.Fix1(operators.binops[node.op], node.l.val), node.r.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + else + if node.r.constant + node.constant = false + node.cache = vmapnt(Base.Fix2(operators.binops[node.op], node.r.val), node.l.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + else + node.constant = false + node.cache = vmapnt(operators.binops[node.op], node.l.cache, node.r.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + end + end + end + if root.constant + return ResultOk(fill(root.val, size(cX, 2)), true) + else + return ResultOk(root.cache, true) + end +end + end diff --git a/ext/DynamicExpressionsVisualizeExt.jl b/ext/DynamicExpressionsVisualizeExt.jl new file mode 100644 index 00000000..e422b618 --- /dev/null +++ b/ext/DynamicExpressionsVisualizeExt.jl @@ -0,0 +1,72 @@ +module DynamicExpressionsVisualizeExt + +using Plots, GraphRecipes, DynamicExpressions +using DynamicExpressions: GraphNode, Node, topological_sort, AbstractOperatorEnum, get_op_name + +function DynamicExpressions.visualize( + graph::Union{GraphNode,Node}, # types accepted by topological_sort + operators::AbstractOperatorEnum, + show = true +) + @info "Generating graph visualization" + + order = reverse(topological_sort(graph)) + + # multigraph adjacency list + g = map( + node -> convert(Vector{Int64}, map( + cindex -> findfirst(x -> x === node.children[cindex][], order), + 1:node.degree + )), + order + ) + + # node labels + n = map(x -> + if x.degree == 0 + x.constant ? x.val : 'x' * string(x.feature) + elseif x.degree == 1 + join(get_op_name(operators.unaops[x.op])) + elseif x.degree == 2 + join(get_op_name(operators.binops[x.op])) + else + @warn "Can't label operator node with degree > 2" + end, + order + ) + + # edge labels (specifies parameter no.) + e = Dict{Tuple{Int64, Int64, Int64}, String}() + for (index, node) in enumerate(order) + edge_count = Dict{Int64, Int64}() # count number of edges to each child node + for cindex in 1:node.degree + order_cindex = findfirst(x -> x === node.children[cindex][], order) + get!( + e, + ( + index, # source + order_cindex, # dest + get!(edge_count, order_cindex, pop!(edge_count, order_cindex, 0)+1) # edge no. + ), + string(cindex) + ) + end + end + + # node colours + c = map(x -> x == 1 ? 
2 : 1, eachindex(order)) + + return graphplot( + g, + names = n, + edgelabel = e, + nodecolor = c, + show = show, + nodeshape=:circle, + edge_label_box = false, + edgelabel_offset = 0.015, + nodesize=0.15 + ) +end + +end \ No newline at end of file diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index b482a7e2..1b44052f 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -68,7 +68,7 @@ import .NodeModule: count_scalar_constants, get_scalar_constants, set_scalar_constants! -@reexport import .StringsModule: string_tree, print_tree +@reexport import .StringsModule: string_tree, print_tree, get_op_name @reexport import .OperatorEnumModule: AbstractOperatorEnum @reexport import .OperatorEnumConstructionModule: OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names! diff --git a/src/Evaluate.jl b/src/Evaluate.jl index cfba95e6..1785b0d2 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -1,7 +1,6 @@ module EvaluateModule using DispatchDoctor: @unstable -using LoopVectorization import ..NodeModule: AbstractExpressionNode, constructorof, GraphNode, topological_sort import ..StringsModule: string_tree @@ -772,58 +771,73 @@ end # Parametric arguments don't use dynamic dispatch, calls with turbo/bumper won't resolve properly -function eval_tree_array( +# overwritten in ext/DynamicExpressionsLoopVectorizationExt.jl +function _eval_graph_array( root::GraphNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, + loopVectorization::Val{true} ) where {T} + error("_is_loopvectorization_loaded(0) is true but _eval_graph_array has not been overwritten") +end - # vmap is faster with small cX sizes - # vmapnt (non-temporal) is faster with larger cX sizes (too big so not worth caching?) - - order = topological_sort(root) - for node in order - if node.degree == 0 && !node.constant - node.cache = view(cX, node.feature, :) - elseif node.degree == 1 - if node.l.constant +function _eval_graph_array( + root::GraphNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum, + loopVectorization::Val{false} +) where {T} +order = topological_sort(root) +for node in order + if node.degree == 0 && !node.constant + node.cache = view(cX, node.feature, :) + elseif node.degree == 1 + if node.l.constant + node.constant = true + node.val = operators.unaops[node.op](node.l.val) + if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end + else + node.constant = false + node.cache = map(operators.unaops[node.op], node.l.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end + end + elseif node.degree == 2 + if node.l.constant + if node.r.constant node.constant = true - node.val = operators.unaops[node.op](node.l.val) + node.val = operators.binops[node.op](node.l.val, node.r.val) if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end else node.constant = false - node.cache = vmapnt(operators.unaops[node.op], node.l.cache) + node.cache = map(Base.Fix1(operators.binops[node.op], node.l.val), node.r.cache) if !is_valid_array(node.cache) return ResultOk(node.cache, false) end end - elseif node.degree == 2 - if node.l.constant - if node.r.constant - node.constant = true - node.val = operators.binops[node.op](node.l.val, node.r.val) - if !is_valid(node.val) return ResultOk(Vector{T}(undef, size(cX, 2)), false) end - else - node.constant = false - node.cache = vmapnt(Base.Fix1(operators.binops[node.op], node.l.val), node.r.cache) - if !is_valid_array(node.cache) return ResultOk(node.cache, false) 
end - end + else + if node.r.constant + node.constant = false + node.cache = map(Base.Fix2(operators.binops[node.op], node.r.val), node.l.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end else - if node.r.constant - node.constant = false - node.cache = vmapnt(Base.Fix2(operators.binops[node.op], node.r.val), node.l.cache) - if !is_valid_array(node.cache) return ResultOk(node.cache, false) end - else - node.constant = false - node.cache = vmapnt(operators.binops[node.op], node.l.cache, node.r.cache) - if !is_valid_array(node.cache) return ResultOk(node.cache, false) end - end + node.constant = false + node.cache = map(operators.binops[node.op], node.l.cache, node.r.cache) + if !is_valid_array(node.cache) return ResultOk(node.cache, false) end end end end - if root.constant - return ResultOk(fill(root.val, size(cX, 2)), true) - else - return ResultOk(root.cache, true) - end +end +if root.constant + return ResultOk(fill(root.val, size(cX, 2)), true) +else + return ResultOk(root.cache, true) +end +end + +function eval_tree_array( + root::GraphNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum, +) where {T} + return _eval_graph_array(root, cX, operators, Val(_is_loopvectorization_loaded(0))) end function eval_graph_array_diff( diff --git a/src/Visualize.jl b/src/Visualize.jl index 82ea9653..1b3f65ce 100644 --- a/src/Visualize.jl +++ b/src/Visualize.jl @@ -1,76 +1,14 @@ module VisualizeModule -using Plots, GraphRecipes -using ..NodeModule: GraphNode, Node, topological_sort +using ..NodeModule: GraphNode, Node using ..OperatorEnumModule: AbstractOperatorEnum -using ..StringsModule: get_op_name - -# visualization probably best as an extension if pulled into master function visualize( graph::Union{GraphNode,Node}, # types accepted by topological_sort operators::AbstractOperatorEnum, show = true ) - @info "Generating graph visualization" - - order = reverse(topological_sort(graph)) - - # multigraph adjacency list - g = map( - node -> convert(Vector{Int64}, map( - cindex -> findfirst(x -> x === node.children[cindex][], order), - 1:node.degree - )), - order - ) - - # node labels - n = map(x -> - if x.degree == 0 - x.constant ? x.val : 'x' * string(x.feature) - elseif x.degree == 1 - join(get_op_name(operators.unaops[x.op])) - elseif x.degree == 2 - join(get_op_name(operators.binops[x.op])) - else - @warn "Can't label operator node with degree > 2" - end, - order - ) - - # edge labels (specifies parameter no.) - e = Dict{Tuple{Int64, Int64, Int64}, String}() - for (index, node) in enumerate(order) - edge_count = Dict{Int64, Int64}() # count number of edges to each child node - for cindex in 1:node.degree - order_cindex = findfirst(x -> x === node.children[cindex][], order) - get!( - e, - ( - index, # source - order_cindex, # dest - get!(edge_count, order_cindex, pop!(edge_count, order_cindex, 0)+1) # edge no. - ), - string(cindex) - ) - end - end - - # node colours - c = map(x -> x == 1 ? 2 : 1, eachindex(order)) - - return graphplot( - g, - names = n, - edgelabel = e, - nodecolor = c, - show = show, - nodeshape=:circle, - edge_label_box = false, - edgelabel_offset = 0.015, - nodesize=0.15 - ) + error("Please load the Plots.jl and GraphRecipes.jl packages to use this feature.") end end \ No newline at end of file
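
For reference, a minimal usage sketch of the functionality added across this series follows. It is not part of any patch; it assumes the API as it stands after PATCH 16/16 (`topological_sort` reexported from `DynamicExpressions`, the `GraphNode` method of `eval_tree_array` returning a `ResultOk` rather than the `(output, ok)` tuple of the ordinary tree method, and `visualize` available only when the Plots/GraphRecipes extension is loaded). The operator set and input data are invented for illustration.

    using DynamicExpressions

    # Operator indices stored in `node.op` refer to positions in these tuples.
    operators = OperatorEnum(; binary_operators=(+, *), unary_operators=(cos,))

    x1 = GraphNode(Float64; feature=1)
    x2 = GraphNode(Float64; feature=2)

    # Build a DAG with a shared subexpression: `shared` is referenced twice,
    # so the graph has fewer nodes than the equivalent binary tree.
    shared = x1 * x2
    graph = cos(shared) + shared

    # Depth-first topological sort: children precede their parents and each
    # shared node appears exactly once, with the root last.
    order = topological_sort(graph)
    @assert order[end] === graph

    # Evaluate over a 2x100 feature matrix; per-node caches are filled in
    # topological order by `_eval_graph_array` (or its LoopVectorization
    # extension when that package is loaded).
    X = randn(Float64, 2, 100)
    result = eval_tree_array(graph, X, operators)
    result.ok && println(result.x[1:5])

    # With Plots.jl and GraphRecipes.jl loaded, the DAG can be drawn:
    # visualize(graph, operators)

Whether the `GraphNode` method of `eval_tree_array` should keep returning a `ResultOk` or be wrapped to match the `(output, ok)` tuple of the existing public API is left open by this series, so callers written against the sketch above should treat its return type as provisional.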