---
title: "40-modeling"
output:
pdf_document: default
html_document: default
---
This notebook will contain the generalized methods that we'll use for all of the models. Keeping them in one place makes the codebase easier to maintain and helps us avoid copy/paste replication errors.
```{r modeling imports, results='hide'}
#load previous notebook data
source(knitr::purl("10-load-data.Rmd"))
fs::file_delete("10-load-data.R")
#load the packages required for this notebook
pacman::p_load(tidymodels, usemodels, doParallel, doFuture)
#registerDoParallel()
registerDoFuture() #accre compat
plan(future::multisession, workers=11)
```
# Split the data
In this step, we perform a standard data split: 75% to training and the other 25% to testing, stratified by class. Note that we want to abstain from looking at performance on the test data until the _very, VERY_ end. Once the code is developed, we might also want to stop fixing this seed and instead record whatever seed is randomly generated. That would let the split vary between runs while keeping each run reproducible.
```{r ml data split}
set.seed(2434)
data_split <- initial_split(artifact_data, prop=3/4, strata=particle_class)
train_data <- training(data_split)
test_data <- testing(data_split)
```
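The idea of recording a randomly generated seed could be sketched roughly as follows. This is only an illustration and is not evaluated when knitting: the name `run_seed` is hypothetical, and where the seed gets recorded (here, just a message) is an assumption.

```{r random seed sketch, eval=FALSE}
# Draw a seed at random instead of hard-coding one (run_seed is a hypothetical name)
run_seed <- sample.int(.Machine$integer.max, size=1)
# Record the seed so this exact run can be reproduced later
message("Seed used for this run: ", run_seed)
set.seed(run_seed)
# The split would then follow as before, e.g.:
# data_split <- initial_split(artifact_data, prop=3/4, strata=particle_class)
```

Re-running with `set.seed()` on the recorded value should reproduce the same split.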
# Cross-validation splits
Here, I'll add the cross-validation splits for tuning. We'll use 5-fold cross-validation, and we'll stratify by particle class once again.
```{r xval data split}
cv_folds <- vfold_cv(train_data, v=5, strata=particle_class)
cv_folds
```
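To see what stratification does to the folds, a small check like the one below may help. It is a sketch on the built-in `iris` data rather than `artifact_data`, it is not evaluated when knitting, and the names `demo_folds` and `class_counts` are hypothetical.

```{r xval stratification sketch, eval=FALSE}
# Assumption: rsample is installed (it is part of tidymodels)
library(rsample)
set.seed(2434)
# Stratified 5-fold split on a stand-in dataset
demo_folds <- vfold_cv(iris, v=5, strata=Species)
# Count the classes in the held-out (assessment) portion of each fold
class_counts <- lapply(demo_folds$splits, function(s) table(assessment(s)$Species))
class_counts
```

With stratification, the class proportions in each fold should roughly mirror those of the full training set.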
# Model saving and loading
This section provides a generalized function for saving and loading (moving) models and their related data.
```{r save model}
move_model_info <- function(model_prefix, move_type='save', file_post=NULL, box_dir=NULL) {
  #Function move_model_info: loads or saves model data from/to Box
  #Inputs: model_prefix: String prefix of the model variable names (e.g., 'glmnet', 'nb', etc.)
  #        move_type: String 'load' or 'save'. Whether to load model-related data from Box or save it to Box
  #        file_post: String of information to add to the end of the filename
  #        box_dir: String base directory to save to or load from. By default, the shared Box drive is used.
  #Outputs: 'load' loads the saved variables into the global environment. 'save' changes nothing in the global env.
  #Returns: String path of the RData file loaded from/saved to
  #find the base save directory
  if(is.null(box_dir)){
    box_dir <- tryCatch({
      get('box_base', envir=.GlobalEnv)
    }, error = function(e){
      #the handler must accept the condition object, otherwise tryCatch itself fails
      message("Didn't find box_base in the global environment...using ~/../Box/DSI_AncientArtifacts/...")
      path.expand('~/../Box/DSI_AncientArtifacts/')
    })
  }
  #build the filename
  rdata_name <- str_c(box_dir, 'RData/', model_prefix, '_modeling_info', file_post, '.RData')
  #load from or save to the file, using the global environment
  if(move_type=='load'){
    load(rdata_name, envir=.GlobalEnv)
  }
  else if (move_type=='save'){
    #names of the three objects saved for each model
    save_params <- c(str_c(model_prefix, '_fold_metrics'),
                     str_c('best_', model_prefix, '_params'),
                     str_c(model_prefix, '_final_fit'))
    save(list=save_params, file=rdata_name, envir=.GlobalEnv)
  }
  else {
    stop("'load' and 'save' are the only two acceptable values for the move_type.")
  }
  return(rdata_name)
}
```
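The save/load round trip that `move_model_info` relies on can be illustrated with base R alone. This is a minimal sketch, not evaluated when knitting, under a few assumptions: placeholder strings stand in for the real fold metrics, parameters, and fitted model; a temporary file replaces the Box directory; and scratch environments (`demo_env`, `restored_env`, both hypothetical names) replace the global environment.

```{r move model sketch, eval=FALSE}
# Names of the three objects that would be saved for the prefix 'glmnet'
model_prefix <- 'glmnet'
save_params <- c(paste0(model_prefix, '_fold_metrics'),
                 paste0('best_', model_prefix, '_params'),
                 paste0(model_prefix, '_final_fit'))
# Assumption: placeholder values in a scratch environment instead of real results in .GlobalEnv
demo_env <- new.env()
for (nm in save_params) assign(nm, paste('placeholder for', nm), envir=demo_env)
# Assumption: a temporary file instead of a path under the Box directory
demo_file <- tempfile(fileext='.RData')
save(list=save_params, file=demo_file, envir=demo_env)
# Load into a fresh environment; load() returns the names of the restored objects
restored_env <- new.env()
restored <- load(demo_file, envir=restored_env)
restored
```

Because the objects are saved by name, loading restores them under the same names, which is why the function only needs the model prefix to find everything again.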