Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for new release #61

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ examples_x64
^doc$
^Meta$

^\.vscode$
vignettes.awk
_pkgdown.yml
^_pkgdown\.yml$
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ revdep/
CRAN-SUBMISSION

# Other files
.vscode
src/symbols.rds
*.o
*.so
Expand Down
10 changes: 6 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: dtwSat
Type: Package
Title: Time-Weighted Dynamic Time Warping for Satellite Image Time Series Analysis
Version: 1.0-1
Date: 2023-09-25
Date: 2023-10-18
Authors@R:
c(person(given = "Victor",
family = "Maus",
Expand Down Expand Up @@ -32,8 +32,8 @@ Description: Provides a robust approach to land use mapping using multi-dimensio
while also requiring minimal training sets. The package includes tools for training the 1-NN-TWDTW model,
visualizing temporal patterns, producing land use maps, and visualizing the results.
License: GPL (>= 3)
URL: https://github.com/vwmaus/dtwSat/
BugReports: https://github.com/vwmaus/dtwSat/issues/
URL: https://github.com/r-spatial/dtwSat/
BugReports: https://github.com/r-spatial/dtwSat/issues/
Maintainer: Victor Maus <vwmaus1@gmail.com>
VignetteBuilder:
knitr
Expand All @@ -46,11 +46,13 @@ Depends:
stars,
ggplot2
Imports:
mgcv,
stats,
utils,
methods,
tidyr,
proxy
Suggests:
mgcv,
knitr,
rmarkdown,
testthat (>= 3.0.0)
Expand Down
10 changes: 4 additions & 6 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@
S3method(plot,twdtw_knn1)
S3method(predict,twdtw_knn1)
S3method(print,twdtw_knn1)
export(shift_dates)
export(twdtw_knn1)
import(ggplot2)
import(sf)
import(stars)
import(twdtw)
importFrom(mgcv,gam)
importFrom(mgcv,predict.gam)
importFrom(mgcv,s)
importFrom(methods,as)
importFrom(proxy,dist)
importFrom(stats,as.formula)
importFrom(stats,model.frame)
importFrom(stats,predict)
importFrom(stats,setNames)
importFrom(stats,resid)
importFrom(tidyr,nest)
importFrom(tidyr,pivot_longer)
importFrom(tidyr,pivot_wider)
importFrom(tidyr,unnest)
importFrom(utils,methods)
6 changes: 4 additions & 2 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ predict.twdtw_knn1 <- function(object, newdata, ...){
newdata_ts <- prepare_time_series(newdata)

# Compute TWDTW distances
distances <- sapply(object$data$observations, function(pattern){
sapply(newdata_ts$observations, function(ts) {
distances <- sapply(seq_along(object$data$observations), function(i){
pattern <- object$data$observations[[i]]
sapply(seq_along(newdata_ts$observations), function(j) {
ts <- newdata_ts$observations[[j]]
do.call(proxy::dist, c(list(x = as.data.frame(ts), y = as.data.frame(pattern), method = 'twdtw'), object$twdtw_args))
})
})
Expand Down
3 changes: 2 additions & 1 deletion R/prepare_time_series.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
#' @return A nested tibble in wide format. Each row of the tibble corresponds to a unique 'ts_id' that maintains the order from the original stars object.
#' The nested structure contains observations (time series) for each 'ts_id', including the 'time' of each observation, and individual bands are presented as separate columns.
#'
#'
#' @noRd
#' @keywords internal
prepare_time_series <- function(x) {

# Remove the 'geom' column if it exists
Expand Down
3 changes: 2 additions & 1 deletion R/shift_dates.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
#'
#' shift_dates(x)
#'
#' @export
#' @noRd
#' @keywords internal
shift_dates <- function(x, origin = "1970-01-01") {

# Convert the input dates to Date objects
Expand Down
164 changes: 112 additions & 52 deletions R/train.R
Original file line number Diff line number Diff line change
@@ -1,37 +1,50 @@
#'
#' Train a KNN-1 TWDTW model with optional GAM resampling
#' Train a KNN-1 TWDTW model
#'
#' This function prepares a KNN-1 model with the Time Warp Dynamic Time Warping (TWDTW) algorithm.
#' If a formula is provided, the training samples are resampled using Generalized Additive Models (GAM).
#'
#' @param x A three-dimensional stars object (x, y, time) with bands as attributes.
#' @param y An sf object with the coordinates of the training points.
#' @param time_weight A numeric vector with length two (steepness and midpoint of logistic weight) or a function.
#' See details in \link[twdtw]{twdtw}.
#' @param cycle_length The length of the cycle, e.g. phenological cycles. Details in \link[twdtw]{twdtw}.
#' @param time_scale Specifies the time scale for the observations. Details in \link[twdtw]{twdtw}.
#' @param formula Either NULL or a formula to reduce samples of the same label using Generalized Additive Models (GAM).
#' Default is \code{band ~ s(time)}. See details.
#' @param resampling_fun a function specifying how to create temporal patterns using the samples.
#' If not defined, it will keep all samples. Note that reducing the samples to patterns can significantly
#' improve computational time of predictions. The resampling function must receive a single data frame as
#' argument an return a model. See details.
#' @param start_column Name of the column in y that indicates the start date. Default is 'start_date'.
#' @param end_column Name of the column in y that indicates the end date. Default is 'end_date'.
#' @param label_colum Name of the column in y containing land use labels. Default is 'label'.
#' @param sampling_freq The time frequency for sampling, including the unit (e.g., '16 day').
#' If NULL, the function will infer the frequency. This parameter is only used if a formula is provided.
#' @param ... Additional arguments passed to the \link[mgcv]{gam} function and to \link[twdtw]{twdtw} function.
#' @param resampling_freq The time for sampling the time series if `resampling_fun` is given.
#' If NULL, the function will infer the frequency of observations in `x`.
#' @param ... Additional arguments passed to \link[twdtw]{twdtw}.
#'
#' @details If \code{formula} is NULL, the KNN-1 model will retain all training samples. If a formula is passed (e.g., \code{band ~ \link[mgcv]{s}(time)}),
#' then samples of the same label (land cover class) will be resampled using GAM.
#' Resampling can significantly reduce prediction processing time.
#' @details If \code{resampling_fun} not informed, the KNN-1 model will retain all training samples.
#'
#' If a custom smoothing function is passed to `resampling_fun`, the function will be used to
#' resample values of samples sharing the same label (land cover class).
#'
#' The custom smoothing function takes a single data frame and returns a model.
#' The data frame has two named columns:
#' \itemize{
#' \item `x` is the first column representing the independent variable (typically time).
#' \item `y` is the second column representing the dependent variable (e.g., band values) corresponding to each coordinate in `x`.
#' }
#'
#' Smooting the samples can significantly reduce the processing time for prediction using `twdtw_knn1` model.
#'
#' See the examples section for further clarity.
#'
#' @return A 'twdtw_knn1' model containing the trained model information and the data used.
#'
#' @examples
#' \dontrun{
#'
#' # Read training samples
#' samples_path <-
#' system.file("mato_grosso_brazil/samples.gpkg", package = "dtwSat")
#'
#' samples_path <-
# ' system.file("mato_grosso_brazil/samples.gpkg", package = "dtwSat")
#'
#' samples <- st_read(samples_path, quiet = TRUE)
#'
#' # Get satellite image time sereis files
Expand All @@ -51,12 +64,13 @@
#' dc <- split(dc, c("band"))
#'
#' # Create a knn1-twdtw model
#' m <- twdtw_knn1(x = dc,
#' y = samples,
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50),
#' formula = band ~ s(time))
#' m <- twdtw_knn1(
#' x = dc,
#' y = samples,
#' resampling_fun = function(data) mgcv::gam(y ~ s(x), data = data),
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50))
#'
#' print(m)
#'
Expand All @@ -71,12 +85,24 @@
#' geom_stars(data = lu) +
#' theme_minimal()
#'
#'
#' # Create a knn1-twdtw model with custom smoothing function
#'
#' m <- twdtw_knn1(x = dc,
#' y = samples,
#' resampling_fun = function(data) lm(y ~ factor(x), data = data),
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50))
#'
#' plot(m)
#'
#' }
#' @export
twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,
formula = NULL, start_column = 'start_date',
end_column = 'end_date', label_colum = 'label',
sampling_freq = NULL, ...){
twdtw_knn1 <- function(x, y, resampling_fun = NULL, resampling_freq = NULL,
time_weight, cycle_length, time_scale,
start_column = 'start_date', end_column = 'end_date',
label_colum = 'label', ...){

# Check if x is a stars object with a time dimension
if (!inherits(x, "stars") || dim(x)['time'] < 1 || length(dim(x)) != 3) {
Expand Down Expand Up @@ -116,17 +142,8 @@ twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,
ts_data <- prepare_time_series(as.data.frame(ts_data))
ts_data$ts_id <- NULL

if(!is.null(formula)) {

# Check if formula has two
if(length(all.vars(formula)) != 2) {
stop("The formula should have only one predictor!")
}

# Determine sampling frequency
if (is.null(sampling_freq)) {
sampling_freq <- get_time_series_freq(ts_data)
}
smooth_models <- NULL
if(!is.null(resampling_fun)) {

# Shift dates
ts_data$observations <- lapply(ts_data$observations, shift_ts_dates)
Expand All @@ -135,29 +152,61 @@ twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,
ts_data <- unnest(ts_data, cols = 'observations')
ts_data <- nest(ts_data, .by = 'label', .key = "observations")

# Define GAM function
gam_fun <- function(band, t, pred_t, formula, ...){
df <- setNames(list(band, as.numeric(t)), all.vars(formula))
pred_t[[all.vars(formula)[2]]] <- as.numeric(pred_t[[all.vars(formula)[2]]])
fit <- mgcv::gam(data = df, formula = formula, ...)
predict(fit, newdata = pred_t)
}
# Apply smooth function
smooth_models <- lapply(ts_data$observations, function(ts) {

# Get timeline
y_time <- ts$time
ts$time <- NULL

# Fit smooth model to each band
smooth_models <- lapply(as.list(ts), function(band) {
resampling_fun(data.frame(x = as.numeric(y_time), y = band))
})

return(smooth_models)

# Apply GAM function
ts_data$observations <- lapply(ts_data$observations, function(ts){
})

names(smooth_models) <- ts_data$label

ts_data$observations <- lapply(seq_along(ts_data$observations), function(l) {

# Get timeline
ts <- ts_data$observations[[l]]
y_time <- ts$time
ts$time <- NULL
pred_time <- setNames(list(seq(min(y_time), max(y_time), by = sampling_freq)), all.vars(formula)[2])
cbind(pred_time, as.data.frame(sapply(as.list(ts), function(band) {
gam_fun(band, y_time, pred_time, formula, ...)
})))

# Determine time for resampling time sereis
if (is.null(resampling_freq)) {
pred_time <- unique(y_time)
} else {
pred_time <- seq(min(y_time), max(y_time), by = resampling_freq)
}

smoothed_data <- sapply(smooth_models[[l]], function(m) {
# Determine target class
target_class <- class(model.frame(m)[, 2])
# Convert pred_time based on target class
pred_points <- if (target_class == "factor") {
factor(as.numeric(pred_time), levels = levels(model.frame(m)[, 2]))
} else {
as(as.numeric(pred_time), target_class)
}
predict(m, newdata = data.frame(x = pred_points))
})

# Bind time and smoothed data into a data frame
result_df <- data.frame(time = pred_time, smoothed_data)

return(result_df)
})

}

model <- list()
model$call <- match.call()
model$formula <- formula
model$smooth_models <- smooth_models
model$data <- ts_data
# add twdtw arguments to model
model$twdtw_args <- list(time_weight = time_weight,
Expand All @@ -183,7 +232,7 @@ twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,
#'
#' @param x An object of class `twdtw_knn1`.
#' @param ... ignored
#'
#'
#' @return Invisible `twdtw_knn1` object.
#'
#' @export
Expand All @@ -195,9 +244,16 @@ print.twdtw_knn1 <- function(x, ...) {
cat("Call:\n")
print(x$call)

# Printing the formula, if available
cat("\nFormula:\n")
print(x$formula)
# Printing the smooth_fun, if available
cat("\nRoot Mean Squared Error (RMSE) of smooth models:\n")
if(is.null(x$smooth_models)){
print(NULL)
} else {
print(sapply(x$smooth_models, function(m1) sapply(m1, function(m2){
residuals <- try(resid(m2))
ifelse(is.numeric(residuals), sqrt(mean(residuals^2)), NA)
})))
}

# Printing the data summary
cat("\nData:\n")
Expand All @@ -223,6 +279,8 @@ print.twdtw_knn1 <- function(x, ...) {
#' pretty_arguments(formals(twdtw_knn1))
#' }
#'
#' @noRd
#' @keywords internal
pretty_arguments <- function(args) {

if (is.null(args)) {
Expand Down Expand Up @@ -259,6 +317,8 @@ pretty_arguments <- function(args) {
#'
#' @return A difftime object representing the most common time difference between consecutive samples.
#'
#' @noRd
#' @keywords internal
get_time_series_freq <- function(x) {

# Extract the time dimension
Expand Down
7 changes: 4 additions & 3 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
#' @import sf
#' @import stars
#' @import ggplot2
#' @importFrom stats as.formula predict setNames
#' @importFrom mgcv gam s predict.gam
#' @importFrom stats predict model.frame resid
#' @importFrom methods as
#' @importFrom utils methods
#' @importFrom tidyr pivot_longer pivot_wider nest unnest
#' @importFrom proxy dist
#'
#'
NULL
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<!-- badges: start -->
[![License](https://img.shields.io/badge/license-GPL%20%28%3E=%202%29-brightgreen.svg?style=flat)](https://www.gnu.org/licenses/gpl-3.0.html)
[![R-CMD-check](https://github.com/vwmaus/dtwSat/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/vwmaus/dtwSat/actions/workflows/R-CMD-check.yaml)
[![R-CMD-check](https://github.com/r-spatial/dtwSat/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/r-spatial/dtwSat/actions/workflows/R-CMD-check.yaml)
[![Coverage Status](https://img.shields.io/codecov/c/github/vwmaus/dtwSat/main.svg)](https://app.codecov.io/gh/vwmaus/dtwSat)
[![CRAN](https://www.r-pkg.org/badges/version/dtwSat)](https://cran.r-project.org/package=dtwSat)
[![Downloads](https://cranlogs.r-pkg.org/badges/dtwSat?color=brightgreen)](https://www.r-pkg.org/pkg/dtwSat)
Expand Down Expand Up @@ -33,7 +33,7 @@ install.packages("dtwSat")
Alternatively, you can install the development version from GitHub:

``` r
devtools::install_github("vwmaus/dtwSat")
devtools::install_github("r-spatial/dtwSat")
```

After installation, you can read the vignette for a quick start guide:
Expand Down
Loading
Loading