Skip to content

Commit

Permalink
Adds flexible resampling function
Browse files Browse the repository at this point in the history
  • Loading branch information
vwmaus committed Dec 26, 2023
1 parent c8489d2 commit 7e46fce
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 67 deletions.
6 changes: 4 additions & 2 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ predict.twdtw_knn1 <- function(object, newdata, ...){
newdata_ts <- prepare_time_series(newdata)

# Compute TWDTW distances
distances <- sapply(object$data$observations, function(pattern){
sapply(newdata_ts$observations, function(ts) {
distances <- sapply(seq_along(object$data$observations), function(i){
pattern <- object$data$observations[[i]]
sapply(seq_along(newdata_ts$observations), function(j) {
ts <- newdata_ts$observations[[j]]
do.call(proxy::dist, c(list(x = as.data.frame(ts), y = as.data.frame(pattern), method = 'twdtw'), object$twdtw_args))
})
})
Expand Down
48 changes: 23 additions & 25 deletions R/train.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,32 @@
#' See details in \link[twdtw]{twdtw}.
#' @param cycle_length The length of the cycle, e.g. phenological cycles. Details in \link[twdtw]{twdtw}.
#' @param time_scale Specifies the time scale for the observations. Details in \link[twdtw]{twdtw}.
#' @param smooth_fun a function specifying how to create temporal patterns using the samples.
#' If not defined, it will keep all samples. Note that reducing the samples to patterns can significantly
#' improve computational time of predictions. See details.
#' @param resampling_fun a function specifying how to create temporal patterns using the samples.
#' If not defined, it will keep all samples. Note that reducing the samples to patterns can significantly
#' improve computational time of predictions. The resampling function must receive a single data frame as
#' argument an return a model. See details.
#' @param start_column Name of the column in y that indicates the start date. Default is 'start_date'.
#' @param end_column Name of the column in y that indicates the end date. Default is 'end_date'.
#' @param label_colum Name of the column in y containing land use labels. Default is 'label'.
#' @param resampling_freq The time for sampling the time series if `smooth_fun` is given.
#' @param resampling_freq The time for sampling the time series if `resampling_fun` is given.
#' If NULL, the function will infer the frequency of observations in `x`.
#' @param ... Additional arguments passed to \link[twdtw]{twdtw}.
#'
#' @details If \code{smooth_fun} not informed, the KNN-1 model will retain all training samples.
#' @details If \code{resampling_fun} not informed, the KNN-1 model will retain all training samples.
#'
#' If a custom smoothing function is passed to `smooth_fun`, the function will be used to
#' If a custom smoothing function is passed to `resampling_fun`, the function will be used to
#' resample values of samples sharing the same label (land cover class).
#'
#' The custom smoothing function takes two numeric vectors as arguments and returns a model:
#' The custom smoothing function takes a single data frame and returns a model.
#' The data frame has two named columns:
#' \itemize{
#' \item The first argument represents the independent variable (typically time).
#' \item The second argument represents the dependent variable (e.g., band values) corresponding to each coordinate in the first argument.
#' \item `x` is the first column representing the independent variable (typically time).
#' \item `y` is the second column representing the dependent variable (e.g., band values) corresponding to each coordinate in `x`.
#' }
#' See the examples section for further clarity.
#'
#' Smooting the samples can significantly reduce the processing time for prediction using `twdtw_knn1` model.
#'
#' See the examples section for further clarity.
#'
#' @return A 'twdtw_knn1' model containing the trained model information and the data used.
#'
Expand Down Expand Up @@ -62,12 +65,12 @@
#'
#' # Create a knn1-twdtw model
#' m <- twdtw_knn1(
#' x = dc,
#' y = samples,
#' smooth_fun = function(x, y) gam(y ~ s(x), data = data.frame(x = x, y = y))
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50))
#' x = dc,
#' y = samples,
#' resampling_fun = function(data) mgcv::gam(y ~ s(x), data = data),
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50))
#'
#' print(m)
#'
Expand All @@ -87,7 +90,7 @@
#'
#' m <- twdtw_knn1(x = dc,
#' y = samples,
#' smooth_fun = function(x, y) lm(y ~ factor(x), data = data.frame(x=x, y=y))
#' resampling_fun = function(data) lm(y ~ factor(x), data = data),
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50))
Expand All @@ -96,7 +99,7 @@
#'
#' }
#' @export
twdtw_knn1 <- function(x, y, smooth_fun = NULL, resampling_freq = NULL,
twdtw_knn1 <- function(x, y, resampling_fun = NULL, resampling_freq = NULL,
time_weight, cycle_length, time_scale,
start_column = 'start_date', end_column = 'end_date',
label_colum = 'label', ...){
Expand Down Expand Up @@ -140,12 +143,7 @@ twdtw_knn1 <- function(x, y, smooth_fun = NULL, resampling_freq = NULL,
ts_data$ts_id <- NULL

smooth_models <- NULL
if(!is.null(smooth_fun)) {

# Check if smooth_fun has two or three arguments
if(length(formals(smooth_fun)) != c(2)) {
stop("The smooth function should have only two arguments!")
}
if(!is.null(resampling_fun)) {

# Shift dates
ts_data$observations <- lapply(ts_data$observations, shift_ts_dates)
Expand All @@ -163,7 +161,7 @@ twdtw_knn1 <- function(x, y, smooth_fun = NULL, resampling_freq = NULL,

# Fit smooth model to each band
smooth_models <- lapply(as.list(ts), function(band) {
smooth_fun(x = as.numeric(y_time), y = band)
resampling_fun(data.frame(x = as.numeric(y_time), y = band))
})

return(smooth_models)
Expand Down
6 changes: 6 additions & 0 deletions cran-comments.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@

# REVIEWS

## v1.0-1

* Improves flexibility for user defined smooting models

*

## v1.0.0

* Major release that removes obsolete dependencies, such as raster, rgdal, and sp.
Expand Down
14 changes: 7 additions & 7 deletions man/plot.twdtw_knn1.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 7 additions & 7 deletions man/predict.twdtw_knn1.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 20 additions & 17 deletions man/twdtw_knn1.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 7 additions & 6 deletions tests/testthat/test-twdtw_classify.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ cat("Creating twdtw_knn1 model using GAM...\n")
system.time(
m <- twdtw_knn1(x = dc,
y = samples,
smooth_fun = function(x, y) gam(y ~ s(x), data = data.frame(x = x, y = y)),
resampling_fun = function(data) mgcv::gam(y ~ s(x), data = data),
cycle_length = 'year',
time_scale = 'day',
time_weight = c(steepness = 0.1, midpoint = 50))
Expand All @@ -52,11 +52,12 @@ ggplot() +
cat("Testing model with resampling frequency of 60 and GAM...\n")
m <- twdtw_knn1(x = dc,
y = samples,
resampling_fun = function(data) mgcv::gam(y ~ s(x), data = data),
resampling_freq = 60,
cycle_length = 'year',
time_scale = 'day',
time_weight = c(steepness = 0.1, midpoint = 50),
smooth_fun = function(x, y) gam(y ~ s(x), data = data.frame(x = x, y = y)),
resampling_freq = 60)
time_weight = c(steepness = 0.1, midpoint = 50)
)

cat("Visualizing resampled patterns...\n")
plot(m)
Expand Down Expand Up @@ -98,10 +99,10 @@ plot(m, bands = c('EVI', 'NDVI'))
cat("Testing custom smooth function using the average for each observation date...\n")
m <- twdtw_knn1(x = dc,
y = samples,
resampling_fun = function(data) lm(y ~ factor(x), data = data),
cycle_length = 'year',
time_scale = 'day',
time_weight = c(steepness = 0.1, midpoint = 50),
smooth_fun = function(x, y) lm(y ~ factor(x), data = data.frame(x = x, y = y)))
time_weight = c(steepness = 0.1, midpoint = 50))

cat("Printing model with custom smooth...\n")
print(m)
Expand Down
6 changes: 3 additions & 3 deletions vignettes/landuse-mapping.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ With the files and dates in hand, we can construct a stars satellite image time
```{r , echo = TRUE, eval = TRUE, warning = FALSE, message = FALSE}
# read data-cube
dc <- read_stars(tif_files,
proxy = FALSE,
along = list(time = acquisition_date),
RasterIO = list(bands = 1:6))
Expand All @@ -80,14 +79,15 @@ the 'band' dimension into attributes. This prepares the data for training the TW
There several wayis to buld the TWDTW-1NN model. The default options is to keep all samples as part of the model.
However, this implies in higher computational time for prediction. To speed the precessing,
here I show how to reduce the samples to a single pattern for land use class using
Generalized Additive Models (GAM) with cubic regression splines. That can be achieved by defining a smoothing function, such as
Generalized Additive Models (GAM) with cubic regression splines.
That can be achieved by defining a smoothing function, such as

```{r , echo = TRUE, eval = TRUE, warning = FALSE, message = FALSE}
library(mgcv)
twdtw_model <- twdtw_knn1(x = dc,
y = samples,
smooth_fun = function(x, y) gam(y ~ s(x), data = data.frame(x = x, y = y)),
resampling_fun = function(data) mgcv::gam(y ~ s(x), data = data),
cycle_length = 'year',
time_scale = 'day',
time_weight = c(steepness = 0.1, midpoint = 50))
Expand Down

0 comments on commit 7e46fce

Please sign in to comment.