diff options
author | Ken Kellner <ken@kenkellner.com> | 2023-03-07 14:41:20 -0500 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2023-03-07 14:55:23 -0500 |
commit | e68ea989f7510991aaaf635140f1507d06b1c2d7 (patch) | |
tree | b0131b88568e54fa49ce9dc4a4ec9f479c936214 | |
parent | 8bf6bf43ca45808a5839e091534c6b4a4b494741 (diff) |
Add support for terra:rast in predict newdata
-rw-r--r-- | DESCRIPTION | 2 | ||||
-rw-r--r-- | R/predict.R | 45 | ||||
-rw-r--r-- | tests/testthat/test_predict.R | 38 |
3 files changed, 67 insertions, 18 deletions
diff --git a/DESCRIPTION b/DESCRIPTION index 5454b14..39e668c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -32,7 +32,7 @@ Imports: stats, TMB (>= 1.7.18), utils -Suggests: knitr, rmarkdown, pkgdown, raster, shiny, testthat +Suggests: knitr, rmarkdown, pkgdown, raster, shiny, terra, testthat Description: Fits hierarchical models of animal abundance and occurrence to data collected using survey methods such as point counts, site occupancy sampling, distance sampling, removal sampling, and double observer sampling. Parameters governing the state and observation processes can be modeled as functions of covariates. Reference: Fiske and Chandler (2011) <doi:10.18637/jss.v043.i10>. License: GPL (>=3) LazyLoad: yes diff --git a/R/predict.R b/R/predict.R index 68efbb7..37b9f87 100644 --- a/R/predict.R +++ b/R/predict.R @@ -25,8 +25,7 @@ setMethod("predict", "unmarkedFit", orig_formula <- get_formula(object, type) # 2. If newdata is raster, get newdata from raster as data.frame - if(inherits(newdata, c("RasterLayer","RasterStack"))){ - if(!require(raster)) stop("raster package required", call.=FALSE) + if(inherits(newdata, c("RasterLayer","RasterStack","SpatRaster"))){ is_raster <- TRUE orig_raster <- newdata newdata <- newdata_from_raster(newdata, all.vars(orig_formula)) @@ -95,8 +94,8 @@ setMethod("check_predict_arguments", "unmarkedFit", check_type(object, type) # Check newdata class - if(!inherits(newdata, c("unmarkedFrame", "data.frame", "RasterLayer", "RasterStack"))){ - stop("newdata must be unmarkedFrame, data.frame, RasterLayer, or RasterStack", call.=FALSE) + if(!inherits(newdata, c("unmarkedFrame", "data.frame", "RasterLayer", "RasterStack", "SpatRaster"))){ + stop("newdata must be unmarkedFrame, data.frame, RasterLayer, RasterStack, or SpatRaster", call.=FALSE) } invisible(TRUE) }) @@ -261,27 +260,39 @@ setMethod("predict_by_chunk", "unmarkedFit", # Raster handling functions---------------------------------------------------- # Convert a raster into a data frame to use as newdata -newdata_from_raster <- function(rst, vars){ - nd <- raster::as.data.frame(rst) - # Handle factor rasters - is_fac <- raster::is.factor(rst) - rem_string <- paste(paste0("^",names(rst),"_"), collapse="|") - names(nd)[is_fac] <- gsub(rem_string, "", names(nd)[is_fac]) +newdata_from_raster <- function(object, vars){ + if(inherits(object, "Raster")){ + if(!requireNamespace("raster", quietly=TRUE)) stop("raster package required", call.=FALSE) + nd <- raster::as.data.frame(object) + # Handle factor rasters + is_fac <- raster::is.factor(object) + rem_string <- paste(paste0("^",names(object),"_"), collapse="|") + names(nd)[is_fac] <- gsub(rem_string, "", names(nd)[is_fac]) + } else if(inherits(object, "SpatRaster")){ + if(!requireNamespace("terra", quietly=TRUE)) stop("terra package required", call.=FALSE) + nd <- terra::as.data.frame(object) + } # Check if variables are missing no_match <- vars[! vars %in% names(nd)] if(length(no_match) > 0){ - stop(paste0("Variable(s) ",paste(no_match, collapse=", "), " not found in raster stack"), + stop(paste0("Variable(s) ",paste(no_match, collapse=", "), " not found in raster(s)"), call.=FALSE) } return(nd) } -# Convert predict output into a raster -raster_from_predict <- function(pr, orig_rst, appendData){ - new_rast <- data.frame(raster::coordinates(orig_rst), pr) - new_rast <- raster::stack(raster::rasterFromXYZ(new_rast)) - raster::crs(new_rast) <- raster::crs(orig_rst) - if(appendData) new_rast <- raster::stack(new_rast, orig_rst) +raster_from_predict <- function(pr, object, appendData){ + if(inherits(object, "Raster")){ + new_rast <- data.frame(raster::coordinates(object), pr) + new_rast <- raster::stack(raster::rasterFromXYZ(new_rast)) + raster::crs(new_rast) <- raster::crs(object) + if(appendData) new_rast <- raster::stack(new_rast, object) + } else if(inherits(object, "SpatRaster")){ + new_rast <- data.frame(terra::crds(object), pr) + new_rast <- terra::rast(new_rast, type="xyz") + terra::crs(new_rast) <- terra::crs(object) + if(appendData) new_rast <- c(new_rast, object) + } new_rast } diff --git a/tests/testthat/test_predict.R b/tests/testthat/test_predict.R index e92add1..686da9b 100644 --- a/tests/testthat/test_predict.R +++ b/tests/testthat/test_predict.R @@ -138,3 +138,41 @@ test_that("predicting from raster works",{ nd_2 <- nd_raster[[1]] expect_error(predict(mod, 'state', newdata=nd_2)) }) + +test_that("predicting from terra::rast works",{ + + skip_if(!require(terra), "terra package unavailable") + + set.seed(123) + # Create rasters + # Elevation + r_elev <- data.frame(x=rep(1:10, 10), y=rep(1:10, each=10), z=rnorm(100)) + r_elev <- terra::rast(r_elev, type="xyz") + + #Group + r_group <- data.frame(x=rep(1:10, 10), y=rep(1:10, each=10), + z=sample(1:length(levels(umf@siteCovs$group)), 100, replace=T)) + # Convert to 'factor' raster + r_group <- terra::as.factor(terra::rast(r_group, type="xyz")) + levels(r_group) <- data.frame(ID=terra::levels(r_group)[[1]]$ID, group=levels(umf@siteCovs$group)) + + # Stack + nd_raster <- c(r_elev, r_group) + names(nd_raster) <- c("elev", "group") + terra::crs(nd_raster) <- "epsg:32616" + + pr <- predict(mod, 'state', newdata=nd_raster) + expect_is(pr, 'SpatRaster') + expect_equal(names(pr), c("Predicted","SE","lower","upper")) + expect_equivalent(pr[1,1][1], 0.3675313, tol=1e-5) + expect_equal(crs(pr), crs(nd_raster)) + + #append data + pr <- predict(mod, 'state', newdata=nd_raster, appendData=TRUE) + expect_is(pr, 'SpatRaster') + expect_equal(names(pr)[5:6], c("elev","group")) + + # Missing levels are handled + nd_2 <- nd_raster[[1]] + expect_error(predict(mod, 'state', newdata=nd_2)) +}) |