Commit 7409f09

fix #140
1 parent fb431c5 commit 7409f09

5 files changed: +55 -15 lines

R/activation.R

+10 -8
@@ -5,7 +5,7 @@
 #' @param alpha (float) the weight of ELU activation component.
 #' @param beta (float) the weight of PReLU activation component.
 #' @param gamma (float) the weight of SiLU activation component.
-#' @param init (float): the initial value of \eqn{a} of PReLU. Default: 0.25.
+#' @param weight (torch_tensor): the initial value of \eqn{weight} of PReLU. Default: torch_tensor(0.25).
 #'
 #' @return an activation function computing
 #' \eqn{\mathbf{MBwLU(input) = \alpha \times ELU(input) + \beta \times PReLU(input) + \gamma \times SiLU(input)}}
@@ -20,14 +20,15 @@
 #' @export
 nn_mb_wlu <- torch::nn_module(
   "multibranch Weighted Linear Unit",
-  initialize = function(alpha = 0.6, beta = 0.2, gamma = 0.2, init = 0.25) {
+  initialize = function(alpha = 0.6, beta = 0.2, gamma = 0.2, weight = torch::torch_tensor(0.25)) {
+    stopifnot("weight must be a torch_tensor()" = inherits(weight, "torch_tensor"))
     self$alpha <- alpha
     self$beta <- beta
     self$gamma <- gamma
-    self$init <- init
+    self$weight <- weight
   },
   forward = function(input) {
-    nnf_mb_wlu(input, self$alpha, self$beta, self$gamma, self$init)
+    nnf_mb_wlu(input, self$alpha, self$beta, self$gamma, self$weight)
   }
 )

@@ -36,9 +37,10 @@ nn_mb_wlu <- torch::nn_module(
 #' @seealso [nn_mb_wlu()].
 #' @export
 #' @rdname nn_mb_wlu
-nnf_mb_wlu <- function(input, alpha = 0.6, beta = 0.2, gamma = 0.2, init = 0.25) {
-  alpha * torch::nnf_elu(input) +
-    beta * torch::nnf_prelu(input, init) +
-    gamma * torch::nnf_silu(input)
+nnf_mb_wlu <- function(input, alpha = 0.6, beta = 0.2, gamma = 0.2, weight = torch::torch_tensor(0.25)) {
+  stopifnot("weight and input must reside on the same device" = weight$device == input$device)
+  alpha * torch::nnf_elu(input) +
+    beta * torch::nnf_prelu(input, weight) +
+    gamma * torch::nnf_silu(input)

 }
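
For reference, a minimal usage sketch of the changed API; the expected values come from the package's own test file, the rest is illustrative:

# build the activation with an explicit PReLU weight tensor
mb_wlu <- nn_mb_wlu(alpha = 0.6, beta = 0.2, gamma = 0.2,
                    weight = torch::torch_tensor(0.25))
input <- torch::torch_tensor(c(-1.0, 0.0, 1.0))
mb_wlu(input)   # approx. c(-0.48306063, 0.0, 0.94621176)

# a plain numeric is now rejected at construction time:
# nn_mb_wlu(weight = 0.25)   # Error: weight must be a torch_tensor()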

R/model.R

+8 -2
@@ -256,12 +256,18 @@ interpretabnet_config <- function(mask_type = "entmax",
                                  mlp_hidden_multiplier = c(4,2),
                                  mlp_activation = NULL,
                                  encoder_activation = nn_mb_wlu(), ...) {
-  tabnet_config(mask_type = mask_type,
+  interpretabnet_conf <- tabnet_config(mask_type = mask_type,
                 mlp_hidden_multiplier = mlp_hidden_multiplier,
                 mlp_activation = mlp_activation,
                 encoder_activation = encoder_activation,
                 ...)
-
+  # align nn_mb_wlu weight device with the config device
+  device <- get_device_from_config(interpretabnet_conf)
+  if (!grepl(device, interpretabnet_conf$encoder_activation$weight$device)) {
+    # move the weight to the config device
+    interpretabnet_conf$encoder_activation$weight <- interpretabnet_conf$encoder_activation$weight$to(device = device)
+  }
+  interpretabnet_conf
 }

 get_constr_output <- function(x, R) {
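
A sketch of why this alignment matters, assuming a config that resolves to a non-cpu device (device strings here are illustrative; get_device_from_config() is the package helper used above):

conf <- interpretabnet_config()      # encoder_activation carries a cpu weight by default
act <- conf$encoder_activation
act(torch::torch_randn(c(2, 2)))     # cpu input: devices match, forward succeeds

# without the alignment, a config resolved to "cuda" would leave the weight
# on "cpu", and nnf_mb_wlu() would stop with
# "weight and input must reside on the same device"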

man/nn_mb_wlu.Rd

+14 -3 (generated file; diff not rendered)

man/tabnet_config.Rd

+7 -1 (generated file; diff not rendered)

tests/testthat/test-activation.R

+16 -1
@@ -1,4 +1,4 @@
-test_that("multibranch_weighted_linear_unit nn_module works", {
+test_that("multibranch_weighted_linear_unit activation works", {
   mb_wlu <- nn_mb_wlu()
   input <- torch::torch_tensor(c(-1.0, 0.0, 1.0))
   expected_output <- torch::torch_tensor(c(-0.48306063, 0.0, 0.94621176))
@@ -7,3 +7,18 @@ test_that("multibranch_weighted_linear_unit nn_module works", {
 
 })
 
+test_that("multibranch_weighted_linear_unit correctly rejects a weight that is not a tensor", {
+  expect_error(mb_wlu <- nn_mb_wlu(weight = 0.25),
+               regexp = "must be a torch_tensor")
+})
+
+test_that("multibranch_weighted_linear_unit correctly rejects weight and input on different devices", {
+  skip_if_not(torch::backends_mps_is_available())
+  weight <- torch::torch_tensor(0.25)$to(device = "cpu")
+  z <- torch::torch_randn(c(2, 2))$to(device = "mps")
+
+  expect_no_error(mb_wlu <- nn_mb_wlu(weight = weight))
+  expect_error(mb_wlu(z),
+               regexp = "reside on the same device")
+})
+
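
As a usage note, under a standard devtools/testthat setup (an assumption; the tooling is not shown in this commit), these tests can be run in isolation with:

devtools::test(filter = "activation")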

0 commit comments