-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add lazy compute operations #25
base: dev
Are you sure you want to change the base?
Changes from all commits
a0a17ea
cda2fd9
2faf11a
5043bc4
a783626
9fbbb8e
1ca659f
294fabc
1f45f1a
cc9dbe0
d8232b8
05d965b
9117ff6
850c170
c1d52af
9d44267
89bca5d
ecbfe08
14d2e00
5bb180d
aaf300a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
module ConScapeCUDSS | ||
|
||
using ConScape | ||
using CUDA | ||
using CUDA.CUSPARSE | ||
using CUDSS | ||
using SparseArrays | ||
using LinearAlgebra | ||
|
||
using ConScape: FundamentalMeasure, AbstractProblem, Grid, setup_sparse_problem | ||
|
||
struct CUDSSsolver <: Solver end | ||
|
||
function ConScape.solve(m::CUDSSsolver, cm::FundamentalMeasure, p::AbstractProblem, g::Grid) | ||
(; A, B, Pref, W) = setup_sparse_problem(g, cm) | ||
Z = zeros(T, size(B)) | ||
|
||
A_gpu = CuSparseMatrixCSR(A |> tril) | ||
Z_gpu = CuMatrix(Z) | ||
B_gpu = CuMatrix(B) | ||
|
||
solver = CudssSolver(A_gpu, "S", "L") | ||
|
||
cudss("analysis", solver, Z_gpu, B_gpu) | ||
cudss("factorization", solver, Z_gpu, B_gpu) | ||
cudss("solve", solver, Z_gpu, B_gpu) | ||
|
||
Z .= Z_gpu | ||
# TODO: maybe graph measures can run on GPU as well? | ||
grsp = GridRSP(g, cm.θ, Pref, W, Z) | ||
results = map(p.graph_measures) do gm | ||
compute(gm, p, grsp) | ||
end | ||
return _merge_to_stack(results) | ||
end | ||
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,41 @@ | ||
module ConScape | ||
|
||
using SparseArrays, LinearAlgebra | ||
using Graphs, Plots, SimpleWeightedGraphs, ProgressLogging, ArnoldiMethod | ||
using Rasters | ||
using Rasters.DimensionalData | ||
using SparseArrays, LinearAlgebra | ||
using Graphs, Plots, SimpleWeightedGraphs, ProgressLogging, ArnoldiMethod | ||
using Rasters | ||
using LinearSolve | ||
using Rasters.DimensionalData | ||
|
||
abstract type ConnectivityFunction <: Function end | ||
abstract type DistanceFunction <: ConnectivityFunction end | ||
abstract type ProximityFunction <: ConnectivityFunction end | ||
# Old funcion-based interface | ||
abstract type ConnectivityFunction <: Function end | ||
abstract type DistanceFunction <: ConnectivityFunction end | ||
abstract type ProximityFunction <: ConnectivityFunction end | ||
|
||
struct least_cost_distance <: DistanceFunction end | ||
struct expected_cost <: DistanceFunction end | ||
struct free_energy_distance <: DistanceFunction end | ||
struct least_cost_distance <: DistanceFunction end | ||
struct expected_cost <: DistanceFunction end | ||
struct free_energy_distance <: DistanceFunction end | ||
|
||
struct survival_probability <: ProximityFunction end | ||
struct power_mean_proximity <: ProximityFunction end | ||
struct survival_probability <: ProximityFunction end | ||
struct power_mean_proximity <: ProximityFunction end | ||
|
||
# Randomized shortest path algorithms | ||
include("randomizedshortestpath.jl") | ||
# Grid struct and methods | ||
include("grid.jl") | ||
# GridRSP (randomized shortest path) struct and methods | ||
include("gridrsp.jl") | ||
# IO | ||
include("io.jl") | ||
# Utilities | ||
include("utils.jl") | ||
# Need to define before loading files | ||
abstract type AbstractProblem end | ||
abstract type Solver end | ||
|
||
# Randomized shortest path algorithms | ||
include("randomizedshortestpath.jl") | ||
# Grid struct and methods | ||
include("grid.jl") | ||
# GridRSP (randomized shortest path) struct and methods | ||
include("gridrsp.jl") | ||
# IO | ||
include("io.jl") | ||
# Utilities | ||
include("utils.jl") | ||
include("graph_measure.jl") | ||
include("connectivity_measure.jl") | ||
include("problem.jl") | ||
include("solvers.jl") | ||
include("tiles.jl") | ||
|
||
end |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about having a folder There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should reorganise all the code like that, but in this PR im trying to add new things without completely reorganising the old ones so that changes to the old code diff properly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So for now its just a file with struct definitions and the methods are left in their current place. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# New type-based interface | ||
# Easier to add parameters to these | ||
abstract type ConnectivityMeasure end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you describe what is the difference between those? I feel like the type system should be as close as possible to the mathematical classification of these measures. I suggest that we could follow the classification of e.g. Graphs.jl or that of networkx: https://networkx.org/documentation/stable/reference/algorithms/index.html i.e., DistanceMeasure, ConnectivityMeasure, CentralityMeasure, etc... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds like a good idea, the GraphMeasures that we discussed would then be CentralityMeasures (connected_habitat is a weighted closeness centrality and the betweennesses are betweenness centralities) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lets do a round or two of documenting and reorganising/renaming the type structure when differences get clearer, the objects here can be seen as place holders for better names. |
||
|
||
# TODO document these groups | ||
abstract type FundamentalMeasure <: ConnectivityMeasure end | ||
abstract type DistanceMeasure <: FundamentalMeasure end | ||
|
||
struct LeastCostDistance <: ConnectivityMeasure end | ||
@kwdef struct ExpectedCost{T<:Union{Real,Nothing},CM} <: DistanceMeasure | ||
θ::T=nothing | ||
distance_transformation::CM=nothing | ||
approx::Bool=false | ||
end | ||
@kwdef struct FreeEnergyDistance{T<:Union{Real,Nothing},CM} <: DistanceMeasure | ||
θ::T=nothing | ||
distance_transformation::CM=nothing | ||
approx::Bool=false | ||
end | ||
@kwdef struct SurvivalProbability{T<:Union{Real,Nothing}} <: FundamentalMeasure | ||
θ::T=nothing | ||
approx::Bool=false | ||
end | ||
@kwdef struct PowerMeanProximity{T<:Union{Real,Nothing}} <: FundamentalMeasure | ||
θ::T=nothing | ||
approx::Bool=false | ||
end | ||
|
||
keywords(cm::ConnectivityMeasure) = _keywords(cm) | ||
|
||
# TODO remove the complexity of the connectivity_function | ||
# These methods are mostly to avoid changing the original interface for now | ||
connectivity_function(::LeastCostDistance) = least_cost_distance | ||
connectivity_function(::ExpectedCost) = expected_cost | ||
connectivity_function(::FreeEnergyDistance) = free_energy_distance | ||
connectivity_function(::SurvivalProbability) = survival_probability | ||
connectivity_function(::PowerMeanProximity) = power_mean_proximity | ||
|
||
distance_transformation(::ConnectivityMeasure) = nothing | ||
distance_transformation(cm::Union{ExpectedCost,FreeEnergyDistance}) = cm.distance_transformation | ||
|
||
# This is not used yet but could be | ||
compute(cm::ConnectivityMeasure, g; kw...) = | ||
connectivity_function(m)(g; keywords(cm)..., kw...) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
""" | ||
GraphMeasure | ||
|
||
Abstract supertype for graph measures. | ||
These are lazy definitions of conscape functions. | ||
""" | ||
abstract type GraphMeasure end | ||
|
||
keywords(o::GraphMeasure) = _keywords(o) | ||
|
||
# TODO: document/rethink these | ||
abstract type TopologicalMeasure <: GraphMeasure end | ||
abstract type BetweennessMeasure <: GraphMeasure end | ||
abstract type PerturbationMeasure <: GraphMeasure end | ||
abstract type PathDistributionMeasure <: GraphMeasure end | ||
|
||
struct BetweennessQweighted <: BetweennessMeasure end | ||
@kwdef struct BetweennessKweighted{DV} <: BetweennessMeasure | ||
diagvalue::DV=nothing | ||
end | ||
struct EdgeBetweennessQweighted <: BetweennessMeasure end | ||
@kwdef struct EdgeBetweennessKweighted{DV} <: BetweennessMeasure | ||
diagvalue::DV=nothing | ||
end | ||
|
||
@kwdef struct ConnectedHabitat{DV} <: GraphMeasure | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that this should be a topological measure |
||
diagvalue::DV=nothing | ||
end | ||
|
||
@kwdef struct Criticality{DV,AV,QT,QS} <: PerturbationMeasure | ||
diagvalue::DV=nothing | ||
avalue::AV=floatmin() | ||
qˢvalue::QS=0.0 | ||
qᵗvalue::QT=0.0 | ||
end | ||
|
||
# These maybe don't quite belong here? | ||
@kwdef struct EigMax{DV,T} <: TopologicalMeasure | ||
diagvalue::DV=nothing | ||
tol::T=1e-14 | ||
end | ||
|
||
struct MeanLeastCostKullbackLeiblerDivergence <: PathDistributionMeasure end | ||
struct MeanKullbackLeiblerDivergence <: PathDistributionMeasure end | ||
|
||
# Map structs to functions | ||
|
||
# These return Rasters | ||
graph_function(m::BetweennessKweighted) = betweenness_kweighted | ||
graph_function(m::BetweennessQweighted) = betweenness_qweighted | ||
graph_function(m::ConnectedHabitat) = connected_habitat | ||
graph_function(m::Criticality) = criticality | ||
# These return scalars | ||
graph_function(m::MeanLeastCostKullbackLeiblerDivergence) = mean_lc_kl_divergence | ||
graph_function(m::MeanKullbackLeiblerDivergence) = mean_kl_divergence | ||
# These return sparse arrays | ||
graph_function(m::EdgeBetweennessKweighted) = edge_betweenness_kweighted | ||
graph_function(m::EdgeBetweennessQweighted) = edge_betweenness_qweighted | ||
# Returns a tuple | ||
graph_function(m::EigMax) = eigmax | ||
|
||
# Map structs to function keywords, | ||
# a bit of a hack until we refactor the rest | ||
keywords(gm::GraphMeasure, p::AbstractProblem) = | ||
(; _keywords(gm)...)#, solver=solver(p)) | ||
keywords(gm::ConnectedHabitat, p::AbstractProblem) = | ||
(; _keywords(gm)..., approx=connectivity_measure(p).approx)#, solver=solver(p)) | ||
|
||
# A trait for connectivity requirement | ||
struct NeedsConnectivity end | ||
struct NoConnectivity end | ||
needs_connectivity(::GraphMeasure) = NoConnectivity() | ||
needs_connectivity(::BetweennessKweighted) = NeedsConnectivity() | ||
needs_connectivity(::EdgeBetweennessKweighted) = NeedsConnectivity() | ||
needs_connectivity(::EigMax) = NeedsConnectivity() | ||
needs_connectivity(::ConnectedHabitat) = NeedsConnectivity() | ||
needs_connectivity(::Criticality) = NeedsConnectivity() | ||
|
||
# compute | ||
# This is where things actually happen | ||
# | ||
# Add dispatch on connectivity measure | ||
compute(gm::GraphMeasure, p::AbstractProblem, g::Union{Grid,GridRSP}) = | ||
compute(needs_connectivity(gm), gm, p, g) | ||
function compute(::NeedsConnectivity, | ||
gm::GraphMeasure, | ||
p::AbstractProblem, | ||
g::Union{Grid,GridRSP} | ||
) | ||
cf = connectivity_function(p) | ||
dt = distance_transformation(p) | ||
# Handle multiple distance transformations | ||
if dt isa NamedTuple | ||
map(distance_transformation) do dtx | ||
graph_function(gm)(g; keywords(gm, p)..., distance_transformation=dtx, connectivity_function=cf) | ||
end | ||
else | ||
graph_function(gm)(g; keywords(gm, p)..., distance_transformation=dt, connectivity_function=cf) | ||
end | ||
end | ||
function compute(::NoConnectivity, | ||
gm::GraphMeasure, | ||
p::AbstractProblem, | ||
g::Union{Grid,GridRSP} | ||
) | ||
graph_function(gm)(g; keywords(gm, p)...) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not having something like
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've allowed passing solver modes to
\
in rsp, but playing around I found some situations where LinearSolve solvers were very fast on Z but stalled in RSP, so maybe we want different solvers per problem?We also want to factor out some of the solves as the same thing is being done in multiple places if you want multiple kinds of betweenness metrics in the same run.