Super Learner for Efficient Estimation of Treatment Effects in Randomized Clinical Trials.
The `sleete` package can be installed from GitHub with:

```r
devtools::install_github('ya-wang-git/sleete')
```
The `sleete` package depends on the `SuperLearner` and `survival` packages. In addition, any machine learning algorithms to be used must be made available, either by installing the corresponding packages or by writing user-defined wrapper functions.
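A user-defined wrapper is expected to follow the calling convention of the wrappers shipped with `SuperLearner`: it receives the outcome `Y`, the training covariates `X`, the covariates `newX` at which predictions are wanted, a `family` object and observation weights, and it returns a list with predictions `pred` and a fitted object `fit`, usually paired with a `predict` method. The sketch below defines a hypothetical wrapper, `SL.lm.int`, that fits a linear model with all pairwise interactions. The name and the model are illustrative assumptions and are not part of `sleete` or `SuperLearner`.

```r
# Hypothetical user-defined wrapper (not part of sleete or SuperLearner):
# a linear model with all pairwise interactions among the covariates in X.
SL.lm.int <- function(Y, X, newX, family = gaussian(),
                      obsWeights = rep(1, length(Y)), ...) {
  # Y and obsWeights are found in this function's environment, as in SL.glm
  fit <- glm(Y ~ .^2, data = X, family = family, weights = obsWeights)
  pred <- predict(fit, newdata = newX, type = "response")
  fit <- list(object = fit)
  class(fit) <- "SL.lm.int"
  list(pred = pred, fit = fit)
}

# Matching predict method, used when the fitted wrapper is applied to new data
predict.SL.lm.int <- function(object, newdata, ...) {
  predict(object$object, newdata = newdata, type = "response")
}
```

If `sleete` looks up wrapper names the way `SuperLearner` does, a wrapper defined in the workspace could then be named in `SL.method` or `SL.library`; that lookup behavior is an assumption, not something stated here.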
The `sleete` package has two public functions: `mleete()` and `sleete()`. The former uses a single specified learning algorithm to estimate the optimal augmentation, while the latter combines multiple algorithms into a super learner. Apart from that difference, the two functions serve the same purpose and have similar syntax:
```r
mleete(data, resp, event=NULL, trt, stratcov=NULL, basecov.cont=NULL,
       basecov.cat=NULL, pi=NULL, bounds=c(-Inf, Inf), method="log.HR", ...,
       SL.method="SL.lm", sample.splitting=TRUE, n.folds.cf=5)

sleete(data, resp, event=NULL, trt, stratcov=NULL, basecov.cont=NULL,
       basecov.cat=NULL, pi=NULL, bounds=c(-Inf, Inf), method="log.HR", ...,
       SL.library=c("SL.lm", "SL.gam", "SL.rpart", "SL.randomForest"),
       n.folds.cv=5, sample.splitting=TRUE, n.folds.cf=5)
```
Arguments
- The argument `data` is a data frame that contains the outcome, treatment and baseline covariate data to be used in the analysis.
- The argument `resp` provides the name of the (main) outcome variable. For a survival endpoint, `resp` is the observed event time $X$, which may be a failure time or a censoring time.
- The argument `event` provides the name of the event type indicator $\Delta$ (1 = failure, 0 = censoring) for a survival endpoint. This optional argument should be omitted for a non-survival endpoint.
- The argument `trt` provides the name of the treatment variable, a numeric or character vector with two distinct values. In both cases, the higher value is assumed to represent the experimental treatment ($A = 1$).
- The argument `stratcov` provides the name(s) of the stratification variable(s) to be used in a stratified analysis (e.g., a stratified PH model). This optional argument should be specified only when the initial estimator $\bar{\theta}$ involves stratification; otherwise it should be omitted. Note that `stratcov` pertains solely to the analysis that produces $\bar{\theta}$ and is unrelated to other types of stratification such as stratified randomization.
- The argument `basecov.cont` provides the names of continuous baseline covariates to be included in covariate adjustment.
- The argument `basecov.cat` provides the names of categorical baseline covariates to be included in covariate adjustment. These are represented using dummy variables (one set of dummy variables for each categorical covariate). If `stratcov` is specified, its components may or may not be included in `basecov.cat`; we recommend including all stratification variables in `basecov.cat`.
- The argument `pi` is the probability $\pi$ of receiving the experimental treatment, which should be known in a randomized clinical trial. If unspecified, `pi` is set equal to the actual proportion of study subjects who were assigned to the experimental treatment.
- The argument `bounds` gives known lower and upper bounds, if any, for the treatment effect measure to be estimated. Specifying `bounds` forces point estimates of $\theta$ to stay within the natural range by truncating them if necessary. For example, if the effect measure is a difference in survival probability, the natural range for $\theta$ is $[-1, 1]$, and one may specify `bounds=c(-1,1)`. If $\theta$ is the difference in RMST at time $\tau$, the natural range is $[-\tau, \tau]$.
- The argument `method` specifies the initial estimator $\bar{\theta}$ and the associated influence function estimate $\hat{\psi}$, which may be computed analytically or empirically. Operationally, the `method` object is a list with two mandatory components and one optional component. The two mandatory components are `pt.est`, a function for computing $\bar{\theta}$, and `inf.fct.avail`, a logical indicator of whether a function to compute $\hat{\psi}$ analytically is available. If `inf.fct.avail` is `TRUE`, one must also supply a function named `inf.fct` to compute $\hat{\psi}$ analytically. If `inf.fct.avail` is `FALSE`, `inf.fct` is not needed and $\hat{\psi}$ is computed using a general-purpose empirical influence function. Currently, five built-in methods are available for survival endpoints:
  - `"log.HR"` for the log-HR in a PH model;
  - `"surv.diff"` for the difference in survival probability at a specified time point;
  - `"RMST.diff"` for the difference in RMST at a specified time point;
  - `"MW.cens"` for Mann–Whitney-type effect measures (such as the win-lose probability difference) for restricted survival times;
  - `"log.HR.strat"` for the log-HR in a stratified PH model.

  The first three methods are implemented with analytical influence functions, and the last two without (i.e., `inf.fct.avail=FALSE`). In addition to the built-in methods, users can define their own methods in the same format. For illustration, the current definition of the `"log.HR.strat"` method is provided below:

  ```r
  # point estimate
  # columns of mData: observed time, event indicator, treatment, stratum
  pt.est.log.HR.strat = function(mData) {
    x = mData[,1]; delta = mData[,2]; a = mData[,3]; s = mData[,4]
    dat = Surv(x, delta)
    mod = coxph(dat~a+strata(s))
    coef(mod)
  }
  # the method
  log.HR.strat = list(pt.est=pt.est.log.HR.strat, inf.fct.avail=FALSE)
  ```

  In the above code, `mData` is not the original `data` frame but a restructured numerical data matrix. As the column indexing in `pt.est.log.HR.strat` shows, its columns hold, in order, the observed time $X$, the event indicator $\Delta$, the treatment indicator $A$ and the stratification variable.
- Following the `method` argument, optional arguments may be provided in place of the dots (`...`). If specified, such optional arguments are passed to the specified method. One example of an optional argument is the time point (named `tau`) required by the `"surv.diff"`, `"RMST.diff"` and `"MW.cens"` methods. Another example is the kernel function `h` in the `"MW.cens"` method, whose default yields the win-lose probability difference.
- The argument `SL.method`, specific to the `mleete()` function, specifies the prediction algorithm used to estimate the optimal augmentation. Its value must be the name of one of the wrapper functions in the `SuperLearner` package. A full list of wrapper functions available in `SuperLearner` can be obtained by typing `listWrappers()`. The default, `SL.lm`, corresponds to linear regression.
- The argument `SL.library`, specific to the `sleete()` function, specifies the super learner library (i.e., the collection of candidate algorithms to be combined). Its value is a character vector of the names of the wrapper functions for the candidate algorithms. The default library consists of the following prediction algorithms: linear model, additive model, regression tree, and random forest. The corresponding wrapper functions are `SL.lm`, `SL.gam`, `SL.rpart`, and `SL.randomForest`, respectively.
- The argument `n.folds.cv`, also specific to the `sleete()` function, specifies the number of folds in the cross-validation procedure within the super learner. This is distinct from the procedure for obtaining a cross-validated variance estimate, which is controlled by the `sample.splitting` and `n.folds.cf` arguments.
- The argument `sample.splitting` is a logical indicator for the use of sample splitting or cross-fitting. Sample splitting is not essential for the linear regression method when the sample size is moderate (e.g., 250) or large, but is highly recommended when nonparametric machine learning methods are used to estimate the optimal augmentation.
- The argument `n.folds.cf` specifies the number of folds in the sample-splitting or cross-fitting procedure. This procedure is external to the cross-validation procedure in the super learner; that is, in the `sleete()` function with `sample.splitting=TRUE`, super learning is performed separately within each partition produced by the sample-splitting procedure.
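As described under the `method` argument, a user-defined method is a list holding a `pt.est` function of the restructured matrix `mData` and an `inf.fct.avail` flag. The sketch below defines a hypothetical method, `log.HR.breslow`, that estimates the log-HR in an unstratified PH model using Breslow's approximation for ties instead of the `coxph()` default (Efron). It assumes, following `pt.est.log.HR.strat`, that the first three columns of `mData` are the observed time, the event indicator and the treatment; how such a list is handed to `mleete()` or `sleete()` is not covered here.

```r
library(survival)

# Hypothetical user-defined method (not a built-in sleete method):
# log-HR from an unstratified PH model with Breslow handling of ties.
# Assumed mData layout, as in pt.est.log.HR.strat: time, event, treatment.
pt.est.log.HR.breslow = function(mData) {
  x = mData[,1]; delta = mData[,2]; a = mData[,3]
  mod = coxph(Surv(x, delta) ~ a, ties="breslow")
  coef(mod)
}

# No analytical influence function is supplied, so inf.fct is omitted
log.HR.breslow = list(pt.est=pt.est.log.HR.breslow, inf.fct.avail=FALSE)
```

With continuous event times and no ties, this method gives the same estimate as an Efron-based fit; the choice only matters when event times are tied.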
Output
- The `mleete()` function returns a $2 \times 2$ numerical matrix with point estimates in the first column and standard errors in the second column. The first row of the matrix is for the unadjusted method specified in the `method` argument. The second row is for the augmentation method based on the algorithm specified by `SL.method`. The standard error for the unadjusted estimate is based on the (analytical or empirical) influence function. The standard error for the augmented estimate is based on either $\hat{\sigma}^2(\cdot)$ (without sample splitting) or $\tilde{\sigma}^2(\cdot)$ (with sample splitting).
- The `sleete()` function returns an $(L + 2) \times 2$ numerical matrix, where $L$ is the size of the super learner library (i.e., the length of `SL.library`). The first column contains point estimates of the treatment effect, and the second column provides standard errors. The first row of the matrix is for the unadjusted method specified in the `method` argument. The next $L + 1$ rows report augmented estimates and standard errors based on the individual algorithms in the super learner library (in the original order), followed by the super learner itself. The standard errors are obtained in the same manner as in the `mleete()` function.
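The distinction between $\hat{\sigma}^2(\cdot)$ and $\tilde{\sigma}^2(\cdot)$ rests on sample splitting. A generic sketch of $K$-fold cross-fitting, with `K` playing the role of `n.folds.cf`, is shown below: every observation receives a prediction from a model fitted without that observation's fold. The helper `cross_fit_predict()` and the use of `lm()` as the learner are illustrative assumptions; what `sleete` actually predicts internally, and how it then forms the variance, is not reproduced here.

```r
# Hypothetical helper illustrating K-fold cross-fitting (not sleete's internals).
# Returns out-of-fold predictions: each row is predicted by a model that
# was fitted on the other K - 1 folds only.
cross_fit_predict <- function(y, X, K = 5) {
  n <- length(y)
  folds <- sample(rep(seq_len(K), length.out = n))  # random fold labels
  dat <- data.frame(y = y, X)
  pred <- numeric(n)
  for (k in seq_len(K)) {
    held_out <- folds == k
    # fit on the training part, predict on the held-out fold
    fit <- lm(y ~ ., data = dat[!held_out, , drop = FALSE])
    pred[held_out] <- predict(fit, newdata = dat[held_out, , drop = FALSE])
  }
  pred
}
```

Because no observation is used to fit the model that predicts it, overfitting by a flexible learner does not leak into the out-of-fold predictions, which is the motivation for recommending sample splitting with nonparametric learners.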
The `sleete` package will be illustrated using the `colon` dataset from the `survival` package. The dataset contains data from a randomized clinical trial of adjuvant therapy regimens for preventing cancer recurrence and death after resection of stage III colon carcinoma. The trial enrolled 929 eligible patients who had undergone curative-intent resection of stage III colon cancer in the previous one to five weeks, and randomly assigned them to observation only (control), levamisole alone, or levamisole plus fluorouracil (L+F), with equal allocation among the three groups.
The original `colon` data frame has 1858 observations and 16 variables (two records per patient, one for each event type).

```r
library(survival)
data(cancer)
dim(colon)
#> [1] 1858 16
```
The first six entries are:

```r
head(colon)
#> id study rx sex age obstruct perfor adhere nodes status differ extent
#> 1 1 1 Lev+5FU 1 43 0 0 0 5 1 2 3
#> 2 1 1 Lev+5FU 1 43 0 0 0 5 1 2 3
#> 3 2 1 Lev+5FU 1 63 0 0 0 1 0 2 3
#> 4 2 1 Lev+5FU 1 63 0 0 0 1 0 2 3
#> 5 3 1 Obs 0 71 0 0 1 7 1 2 2
#> 6 3 1 Obs 0 71 0 0 1 7 1 2 2
#> surg node4 time etype
#> 1 0 1 1521 2
#> 2 0 1 968 1
#> 3 0 0 3087 2
#> 4 0 0 3087 1
#> 5 0 1 963 2
#> 6 0 1 542 1
```
- `id`: patient id
- `study`: 1 for all patients
- `rx`: treatment: Obs(ervation), Lev(amisole), Lev(amisole)+5-FU
- `sex`: 0=female, 1=male
- `age`: in years
- `obstruct`: obstruction of colon by tumor (0=no, 1=yes)
- `perfor`: perforation of colon (0=no, 1=yes)
- `adhere`: adherence to nearby organs (0=no, 1=yes)
- `nodes`: number of lymph nodes with detectable cancer
- `status`: censoring status
- `differ`: differentiation of tumor (1=well, 2=moderate, 3=poor)
- `extent`: extent of local spread (1=submucosa, 2=muscle, 3=serosa, 4=contiguous structures)
- `surg`: time from surgery to registration (0=short, 1=long)
- `node4`: more than 4 positive lymph nodes (0=no, 1=yes)
- `time`: days until event or censoring
- `etype`: event type (1=recurrence, 2=death)
We are interested in comparing the L+F regimen with control with respect to overall survival. To create an appropriate dataset for our analysis, we exclude the records for cancer recurrence (`etype==1`), keeping one overall-survival record per patient, as well as subjects in the levamisole-only group. The available baseline covariates are `sex`, `age`, `obstruct`, `perfor`, `adhere`, `nodes`, `differ`, `extent`, `surg`, and `node4`.
```r
data = subset(colon, subset=((etype==2)&(rx!="Lev")))
dim(data)
#> [1] 619 16
```
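Before fitting anything, a few quick checks can confirm that the subset has the intended structure: one record per patient and only the two arms being compared. The snippet repeats the subsetting step so that it runs on its own; the checks themselves are a suggestion and not part of the `sleete` workflow.

```r
library(survival)
data = subset(colon, subset=((etype==2)&(rx!="Lev")))

# each patient should now contribute exactly one row (the etype==2 record)
stopifnot(!anyDuplicated(data$id))

# only the observation and L+F arms should remain
table(droplevels(data$rx))

# event indicator used as `event` below: 1 = death observed, 0 = censored
table(data$status)
```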
We first illustrate the `mleete()` function using three built-in methods for survival endpoints: `"log.HR"` (log-HR), `"surv.diff"` (difference in survival probability), and `"RMST.diff"` (difference in RMST). In the last two methods, the time argument `tau` is set to five years, or rather, `5*365` days, since `time` is recorded in days. We use the default settings of the `mleete()` function, including the use of linear regression with sample splitting.
```r
library(sleete)
#> Loading required package: SuperLearner
#> Warning: package 'SuperLearner' was built under R version 4.4.1
#> Loading required package: nnls
#> Loading required package: gam
#> Loading required package: splines
#> Loading required package: foreach
#> Loaded gam 1.22-5
#> Super Learner
#> Version: 2.0-29
#> Package created on 2024-02-06
data$trt = as.numeric(data$rx=="Lev+5FU")
pi = 0.5; tau = 5*365
set.seed(12345)
```
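The explicit 0/1 recoding of `rx` matters because `trt` treats the higher value as the experimental arm. Assuming "higher" follows R's usual ordering of character strings, leaving the arms as `"Lev+5FU"` and `"Obs"` would make observation the experimental arm, since `"Obs"` sorts after `"Lev+5FU"`. The sketch also computes the observed proportion assigned to L+F, which is what `pi` would default to if it were left unspecified instead of being set to the design value 0.5. It rebuilds `data` so that it runs on its own.

```r
library(survival)
data = subset(colon, subset=((etype==2)&(rx!="Lev")))

# Alphabetically, "Obs" sorts after "Lev+5FU", so it would be the higher value
max(c("Lev+5FU", "Obs"))

# Explicit 0/1 coding: 1 = L+F (experimental), 0 = observation (control)
data$trt = as.numeric(data$rx=="Lev+5FU")

# Observed proportion on L+F, i.e. the value pi would default to
mean(data$trt)
```

Passing `pi = 0.5` uses the known randomization probability from the trial design rather than this empirical proportion.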
Log HR
```r
round(mleete(data, "time", event="status", trt="trt",
             basecov.cont=c("age", "nodes", "differ", "extent"),
             basecov.cat=c("sex", "obstruct", "perfor", "adhere", "surg", "node4"),
             pi=pi, method="log.HR"), digits=3)
#> Pt. Est. Std. Err.
#> Unadjusted -0.385 0.121
#> Augmented -0.320 0.113
```
Difference in survival probability
```r
round(mleete(data, "time", event="status", trt="trt",
             basecov.cont=c("age", "nodes", "differ", "extent"),
             basecov.cat=c("sex", "obstruct", "perfor", "adhere", "surg", "node4"),
             pi=pi, method="surv.diff", tau=tau), digits=3)
#> Pt. Est. Std. Err.
#> Unadjusted 0.116 0.040
#> Augmented 0.094 0.038
```
Difference in RMST
```r
round(mleete(data, "time", event="status", trt="trt",
             basecov.cont=c("age", "nodes", "differ", "extent"),
             basecov.cat=c("sex", "obstruct", "perfor", "adhere", "surg", "node4"),
             pi=pi, method="RMST.diff", tau=tau), digits=1)
#> Pt. Est. Std. Err.
#> Unadjusted 119 47.6
#> Augmented 93 43.8
```
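Because the output matrix holds a point estimate and a standard error in each row, approximate confidence intervals follow directly. The sketch below, assuming approximate normality of the estimators, turns the rounded log-HR results printed above into 95% Wald intervals on the hazard ratio scale and computes the variance ratio of the unadjusted to the augmented estimator. The numbers are copied by hand from the output, so the results are only as precise as that rounding.

```r
# Rounded estimates and standard errors copied from the log-HR output above
est <- c(Unadjusted = -0.385, Augmented = -0.320)
se  <- c(Unadjusted =  0.121, Augmented =  0.113)

# 95% Wald intervals on the log-HR scale, then exponentiated to the HR scale
z  <- qnorm(0.975)
ci <- cbind(lower = est - z * se, upper = est + z * se)
round(exp(cbind(HR = est, ci)), 3)

# Variance ratio (unadjusted / augmented): values above 1 mean the
# augmentation reduced the variance of the log-HR estimate
(se["Unadjusted"] / se["Augmented"])^2
```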
Next, we apply the `sleete()` function with the default settings, including the use of sample splitting and a library of candidate algorithms consisting of a linear model, an analogous additive model, a regression tree, and a random forest.
Log HR
```r
round(sleete(data, "time", event="status", trt="trt",
             basecov.cont=c("age", "nodes", "differ", "extent"),
             basecov.cat=c("sex", "obstruct", "perfor", "adhere", "surg", "node4"),
             pi=pi, method="log.HR"), digits=3)
#> Loading required namespace: randomForest
#> Loading required namespace: rpart
#> Pt. Est. Std. Err.
#> Unadjusted -0.385 0.121
#> SL.lm -0.333 0.116
#> SL.gam -0.336 0.116
#> SL.rpart -0.315 0.119
#> SL.randomForest -0.317 0.117
#> SL -0.326 0.116
```
Difference in survival probability
```r
round(sleete(data, "time", event="status", trt="trt",
             basecov.cont=c("age", "nodes", "differ", "extent"),
             basecov.cat=c("sex", "obstruct", "perfor", "adhere", "surg", "node4"),
             pi=pi, method="surv.diff", tau=tau), digits=3)
#> Pt. Est. Std. Err.
#> Unadjusted 0.116 0.040
#> SL.lm 0.098 0.038
#> SL.gam 0.099 0.038
#> SL.rpart 0.108 0.039
#> SL.randomForest 0.109 0.039
#> SL 0.102 0.038
```
Difference in RMST
```r
round(sleete(data, "time", event="status", trt="trt",
             basecov.cont=c("age", "nodes", "differ", "extent"),
             basecov.cat=c("sex", "obstruct", "perfor", "adhere", "surg", "node4"),
             pi=pi, method="RMST.diff", tau=tau), digits=1)
#> Pt. Est. Std. Err.
#> Unadjusted 119.0 47.6
#> SL.lm 91.0 44.9
#> SL.gam 90.6 45.0
#> SL.rpart 99.3 45.6
#> SL.randomForest 103.7 46.2
#> SL 98.4 44.9
```
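One way to compare the rows of a `sleete()` result is the variance of the unadjusted estimator relative to each augmented estimator, i.e. the squared ratio of standard errors. The sketch below applies this to the rounded RMST standard errors printed above; values above 1 indicate a variance reduction from covariate adjustment. The numbers are copied by hand, so the ratios are approximate.

```r
# Rounded standard errors copied from the RMST output above
se <- c(Unadjusted = 47.6, SL.lm = 44.9, SL.gam = 45.0,
        SL.rpart = 45.6, SL.randomForest = 46.2, SL = 44.9)

# Variance of the unadjusted estimator relative to each augmented estimator
re <- (unname(se["Unadjusted"]) / se[-1])^2
round(re, 3)
```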