Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLJ integration via Tables.jl interface #84

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,11 @@ version = "0.0.1"

[deps]
ArchGDAL = "c9ce4bd3-c3d5-55b8-8973-c0e20141b8c3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Fauxcurrences = "a2d61402-033a-4ca9-aef4-652d70cf7c9c"
GBIF = "ee291a33-5a6c-5552-a3c8-0f29a1181037"
GDAL = "add2ef01-049f-52c4-9ee2-e494f65e021a"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
GeoMakie = "db073c08-6b98-4ee5-b6a4-5efafb3259c6"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SimpleSDMDatasets = "2c7d61d0-5c73-410d-85b2-d2e7fbbdcefa"
SimpleSDMLayers = "2c645270-77db-11e9-22c3-0f302a89c64c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
GeoMakie = "db073c08-6b98-4ee5-b6a4-5efafb3259c6"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
126 changes: 126 additions & 0 deletions docs/src/vignettes/09_data_preparation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# # Preparing data for prediction

using SpeciesDistributionToolkit
using CairoMakie

#

spatial_extent = (left = 5.0, bottom = 57.5, right = 10.0, top = 62.7)

#

rangifer = taxon("Rangifer tarandus tarandus"; strict = false)
query = [
"occurrenceStatus" => "PRESENT",
"hasCoordinate" => true,
"decimalLatitude" => (spatial_extent.bottom, spatial_extent.top),
"decimalLongitude" => (spatial_extent.left, spatial_extent.right),
"limit" => 300,
]
presences = occurrences(rangifer, query...)
for i in 1:3
occurrences!(presences)
end

#

dataprovider = RasterData(CHELSA1, BioClim)

varnames = layerdescriptions(dataprovider)

#

layers = [
convert(
SimpleSDMResponse,
1.0SimpleSDMPredictor(dataprovider; spatial_extent..., layer = lname),
) for
lname in keys(varnames)
]

#

originallayers = deepcopy(layers)

#

presenceonly = mask(layers[1], presences, Bool)
absenceonly = SpeciesDistributionToolkit.sample(
pseudoabsencemask(SurfaceRangeEnvelope, presenceonly),
250,
)
replace!(presenceonly, false => nothing)
replace!(absenceonly, false => nothing)
for cell in absenceonly
presenceonly[cell.longitude, cell.latitude] = false
end

for i in eachindex(layers)
keys_to_void = setdiff(keys(layers[i]), keys(presenceonly))
for k in keys_to_void
layers[i][k] = nothing
end
end

layers

#

refs = Ref.([layers..., presenceonly])

datastack = SimpleSDMStack([values(varnames)..., "Presence"], refs)

predictionstack = SimpleSDMStack([values(varnames)...], Ref.(originallayers))

#


using DataFrames
DataFrame(datastack)

#

using MLJ

#

y, X = unpack(select(DataFrame(datastack), Not([:longitude, :latitude])), ==(:Presence));
y = coerce(y, Continuous)

#

Standardizer = @load Standardizer pkg = MLJModels add = true verbosity = 0
LM = @load LinearRegressor pkg = MLJLinearModels add = true verbosity = 0
model = Standardizer() |> LM()

#

mach = machine(model, X, y) |> fit!

#

perf_measures = [mcc, f1score, accuracy, balanced_accuracy]
evaluate!(
mach;
resampling = CV(; nfolds = 3, shuffle = true, rng = Xoshiro(234)),
measure = perf_measures,
)

#

value = predict(mach, select(DataFrame(predictionstack), Not([:longitude, :latitude])));

#

prediction = select(DataFrame(predictionstack), [:longitude, :latitude]);
prediction.value = value;

#

output = Tables.materializer(SimpleSDMResponse)(prediction)

#

heatmap(sprinkle(output)...; colormap = :viridis)
scatter!(longitudes(presences), latitudes(presences))
current_figure()
7 changes: 6 additions & 1 deletion src/SpeciesDistributionToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ const _distance_function = Distances.Haversine(6371.0)

import StatsBase

import Tables

# We make ample use of re-export
using Reexport

Expand All @@ -29,8 +31,11 @@ include("io/geotiff.jl")
include("io/ascii.jl")
include("io/read_write.jl")

# Stack for data export
include("stack.jl")
export SimpleSDMStack

# Tables interface
import Tables
include("tables.jl")

# Functions for pseudo-absence generation
Expand Down
77 changes: 77 additions & 0 deletions src/stack.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
SimpleSDMStack

Stores multiple _references_ to layers alongside with their names. This is mostly useful because it provides an interface that we can use as a Tables.jl provider.
"""
struct SimpleSDMStack
names::Vector{String}
layers::Vector{Base.RefValue}
function SimpleSDMStack(names::Vector{String}, layers::Vector{Base.RefValue})
# As many names as layers
@assert length(names) == length(layers)
# Layers have the correct type
@assert all([typeof(layer.x) <: SimpleSDMLayer for layer in layers])
# Layers are all compatible
@assert all([
SimpleSDMLayers._layers_are_compatible(first(layers).x, layer.x) for
layer in layers
])
# Layers all have the same keys
@assert all([
sort(keys(first(layers).x)) == sort(keys(layer.x)) for layer in layers
])
# Return if all pass
return new(names, layers)
end
end

Base.length(s::T) where {T <: SimpleSDMStack} = length(first(s.layers).x)
Base.names(s::T) where {T <: SimpleSDMStack} = s.names
SimpleSDMLayers.latitudes(s::T) where {T <: SimpleSDMStack} = latitudes(first(s.layers).x)
SimpleSDMLayers.longitudes(s::T) where {T <: SimpleSDMStack} = longitudes(first(s.layers).x)
SimpleSDMLayers.boundingbox(s::T) where {T <: SimpleSDMStack} =
boundingbox(first(s.layers).x)

Base.IteratorSize(::T) where {T <: SimpleSDMStack} = Base.HasLength()
function Base.IteratorEltype(s::T) where {T <: SimpleSDMStack}
varnames = [:longitude, :latitude, Symbol.(names(s))...]
vartypes = [
eltype(longitudes(s)),
eltype(latitudes(s)),
[SimpleSDMLayers._inner_type(l.x) for l in s.layers]...,
]
return NamedTuple{tuple(varnames...), Tuple{vartypes...}}
end

function Base.iterate(s::SimpleSDMStack)
position = findfirst(!isnothing, s.layers[1].x.grid)
lon = longitudes(s)[last(position.I)]
lat = latitudes(s)[first(position.I)]
vals = [l.x[lon, lat] for l in s.layers]
varnames = [:longitude, :latitude, Symbol.(names(s))...]
return (NamedTuple{tuple(varnames...)}(tuple(lon, lat, vals...)), position)
end

function Base.iterate(s::SimpleSDMStack, state)
newstate = LinearIndices(s.layers[1].x.grid)[state] + 1
newstate > prod(size(s.layers[1].x.grid)) && return nothing
position = findnext(
!isnothing,
s.layers[1].x.grid,
CartesianIndices(s.layers[1].x.grid)[newstate],
)
isnothing(position) && return nothing
lon = longitudes(s)[last(position.I)]
lat = latitudes(s)[first(position.I)]
vals = [l.x[lon, lat] for l in s.layers]
varnames = [:longitude, :latitude, Symbol.(names(s))...]
return (NamedTuple{tuple(varnames...)}(tuple(lon, lat, vals...)), position)
end

Tables.istable(::Type{SimpleSDMStack}) = true
Tables.rowaccess(::Type{SimpleSDMStack}) = true
function Tables.schema(s::SimpleSDMStack)
tp = first(s)
sc = Tables.Schema(keys(tp), typeof.(values(tp)))
return sc
end