aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <ken@kenkellner.com>2023-03-07 14:41:20 -0500
committerKen Kellner <ken@kenkellner.com>2023-03-07 14:55:23 -0500
commite68ea989f7510991aaaf635140f1507d06b1c2d7 (patch)
treeb0131b88568e54fa49ce9dc4a4ec9f479c936214
parent8bf6bf43ca45808a5839e091534c6b4a4b494741 (diff)
Add support for terra:rast in predict newdata
-rw-r--r--DESCRIPTION2
-rw-r--r--R/predict.R45
-rw-r--r--tests/testthat/test_predict.R38
3 files changed, 67 insertions, 18 deletions
diff --git a/DESCRIPTION b/DESCRIPTION
index 5454b14..39e668c 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -32,7 +32,7 @@ Imports:
stats,
TMB (>= 1.7.18),
utils
-Suggests: knitr, rmarkdown, pkgdown, raster, shiny, testthat
+Suggests: knitr, rmarkdown, pkgdown, raster, shiny, terra, testthat
Description: Fits hierarchical models of animal abundance and occurrence to data collected using survey methods such as point counts, site occupancy sampling, distance sampling, removal sampling, and double observer sampling. Parameters governing the state and observation processes can be modeled as functions of covariates. Reference: Fiske and Chandler (2011) <doi:10.18637/jss.v043.i10>.
License: GPL (>=3)
LazyLoad: yes
diff --git a/R/predict.R b/R/predict.R
index 68efbb7..37b9f87 100644
--- a/R/predict.R
+++ b/R/predict.R
@@ -25,8 +25,7 @@ setMethod("predict", "unmarkedFit",
orig_formula <- get_formula(object, type)
# 2. If newdata is raster, get newdata from raster as data.frame
- if(inherits(newdata, c("RasterLayer","RasterStack"))){
- if(!require(raster)) stop("raster package required", call.=FALSE)
+ if(inherits(newdata, c("RasterLayer","RasterStack","SpatRaster"))){
is_raster <- TRUE
orig_raster <- newdata
newdata <- newdata_from_raster(newdata, all.vars(orig_formula))
@@ -95,8 +94,8 @@ setMethod("check_predict_arguments", "unmarkedFit",
check_type(object, type)
# Check newdata class
- if(!inherits(newdata, c("unmarkedFrame", "data.frame", "RasterLayer", "RasterStack"))){
- stop("newdata must be unmarkedFrame, data.frame, RasterLayer, or RasterStack", call.=FALSE)
+ if(!inherits(newdata, c("unmarkedFrame", "data.frame", "RasterLayer", "RasterStack", "SpatRaster"))){
+ stop("newdata must be unmarkedFrame, data.frame, RasterLayer, RasterStack, or SpatRaster", call.=FALSE)
}
invisible(TRUE)
})
@@ -261,27 +260,39 @@ setMethod("predict_by_chunk", "unmarkedFit",
# Raster handling functions----------------------------------------------------
# Convert a raster into a data frame to use as newdata
-newdata_from_raster <- function(rst, vars){
- nd <- raster::as.data.frame(rst)
- # Handle factor rasters
- is_fac <- raster::is.factor(rst)
- rem_string <- paste(paste0("^",names(rst),"_"), collapse="|")
- names(nd)[is_fac] <- gsub(rem_string, "", names(nd)[is_fac])
+newdata_from_raster <- function(object, vars){
+ if(inherits(object, "Raster")){
+ if(!requireNamespace("raster", quietly=TRUE)) stop("raster package required", call.=FALSE)
+ nd <- raster::as.data.frame(object)
+ # Handle factor rasters
+ is_fac <- raster::is.factor(object)
+ rem_string <- paste(paste0("^",names(object),"_"), collapse="|")
+ names(nd)[is_fac] <- gsub(rem_string, "", names(nd)[is_fac])
+ } else if(inherits(object, "SpatRaster")){
+ if(!requireNamespace("terra", quietly=TRUE)) stop("terra package required", call.=FALSE)
+ nd <- terra::as.data.frame(object)
+ }
# Check if variables are missing
no_match <- vars[! vars %in% names(nd)]
if(length(no_match) > 0){
- stop(paste0("Variable(s) ",paste(no_match, collapse=", "), " not found in raster stack"),
+ stop(paste0("Variable(s) ",paste(no_match, collapse=", "), " not found in raster(s)"),
call.=FALSE)
}
return(nd)
}
-# Convert predict output into a raster
-raster_from_predict <- function(pr, orig_rst, appendData){
- new_rast <- data.frame(raster::coordinates(orig_rst), pr)
- new_rast <- raster::stack(raster::rasterFromXYZ(new_rast))
- raster::crs(new_rast) <- raster::crs(orig_rst)
- if(appendData) new_rast <- raster::stack(new_rast, orig_rst)
+raster_from_predict <- function(pr, object, appendData){
+ if(inherits(object, "Raster")){
+ new_rast <- data.frame(raster::coordinates(object), pr)
+ new_rast <- raster::stack(raster::rasterFromXYZ(new_rast))
+ raster::crs(new_rast) <- raster::crs(object)
+ if(appendData) new_rast <- raster::stack(new_rast, object)
+ } else if(inherits(object, "SpatRaster")){
+ new_rast <- data.frame(terra::crds(object), pr)
+ new_rast <- terra::rast(new_rast, type="xyz")
+ terra::crs(new_rast) <- terra::crs(object)
+ if(appendData) new_rast <- c(new_rast, object)
+ }
new_rast
}
diff --git a/tests/testthat/test_predict.R b/tests/testthat/test_predict.R
index e92add1..686da9b 100644
--- a/tests/testthat/test_predict.R
+++ b/tests/testthat/test_predict.R
@@ -138,3 +138,41 @@ test_that("predicting from raster works",{
nd_2 <- nd_raster[[1]]
expect_error(predict(mod, 'state', newdata=nd_2))
})
+
+test_that("predicting from terra::rast works",{
+
+ skip_if(!require(terra), "terra package unavailable")
+
+ set.seed(123)
+ # Create rasters
+ # Elevation
+ r_elev <- data.frame(x=rep(1:10, 10), y=rep(1:10, each=10), z=rnorm(100))
+ r_elev <- terra::rast(r_elev, type="xyz")
+
+ #Group
+ r_group <- data.frame(x=rep(1:10, 10), y=rep(1:10, each=10),
+ z=sample(1:length(levels(umf@siteCovs$group)), 100, replace=T))
+ # Convert to 'factor' raster
+ r_group <- terra::as.factor(terra::rast(r_group, type="xyz"))
+ levels(r_group) <- data.frame(ID=terra::levels(r_group)[[1]]$ID, group=levels(umf@siteCovs$group))
+
+ # Stack
+ nd_raster <- c(r_elev, r_group)
+ names(nd_raster) <- c("elev", "group")
+ terra::crs(nd_raster) <- "epsg:32616"
+
+ pr <- predict(mod, 'state', newdata=nd_raster)
+ expect_is(pr, 'SpatRaster')
+ expect_equal(names(pr), c("Predicted","SE","lower","upper"))
+ expect_equivalent(pr[1,1][1], 0.3675313, tol=1e-5)
+ expect_equal(crs(pr), crs(nd_raster))
+
+ #append data
+ pr <- predict(mod, 'state', newdata=nd_raster, appendData=TRUE)
+ expect_is(pr, 'SpatRaster')
+ expect_equal(names(pr)[5:6], c("elev","group"))
+
+ # Missing levels are handled
+ nd_2 <- nd_raster[[1]]
+ expect_error(predict(mod, 'state', newdata=nd_2))
+})