Skip to content

Commit e5c5306

Browse files
authored
multi-outcome tabnet-fit and predict (#118)
* add multi-output and multilabel test * switch to `expect_no_error()` for readability * `momentum` consistently default to 0.02 switch target `y` from vector to array * turn `output_dim` into vector when multi_output manage loss cases for multi_output lint and refactor code * pass `is_multi_outcome` to predict encode output_dim for multi-outcome improve multi-outcome classification loss split predict based on `is_multi_outcome` * working predict_impl_class and predict_numeric switch to hardhat v1.3.0 * refactor predict_impl_ for a clearer case_when() call * improve `check_type` to manage multi-outcome fix tests vqlues add mixed-outcome and multi-outcome with valid test * add consistency checks for outcome types * add multi-output description in `tabnet-fit` move multi-output tests is a dedicated file * improve multi-outcome tests fix multi-outcome classification
1 parent 162134c commit e5c5306

33 files changed

+841
-512
lines changed

.Rbuildignore

+2
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@
99
^cran-comments\.md$
1010
^CRAN-RELEASE$
1111
^.V8*
12+
^doc$
13+
^Meta$

.github/workflows/R-CMD-check.yaml

+24-84
Original file line numberDiff line numberDiff line change
@@ -35,66 +35,33 @@ jobs:
3535
TORCH_TEST: 1
3636

3737
steps:
38-
- uses: actions/checkout@v2
38+
- uses: actions/checkout@v3
3939

40-
- uses: r-lib/actions/setup-r@v1
40+
- uses: r-lib/actions/setup-r@v2
4141
with:
4242
r-version: ${{ matrix.config.r }}
4343

44-
- uses: r-lib/actions/setup-pandoc@v1
45-
46-
- name: Query dependencies
47-
run: |
48-
install.packages('remotes')
49-
saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2)
50-
writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version")
51-
shell: Rscript {0}
44+
- uses: r-lib/actions/setup-pandoc@v2
5245

53-
- name: Cache R packages
54-
if: runner.os != 'Windows'
55-
uses: actions/cache@v2
46+
- uses: r-lib/actions/setup-r-dependencies@v2
5647
with:
57-
path: ${{ env.R_LIBS_USER }}
58-
key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }}
59-
restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-
60-
61-
- name: Install system dependencies
62-
if: runner.os == 'Linux'
63-
run: |
64-
while read -r cmd
65-
do
66-
eval sudo $cmd
67-
done < <(Rscript -e 'writeLines(remotes::system_requirements("ubuntu", "20.04"))')
68-
69-
- name: Install macOS dependencies
70-
if: runner.os == 'macOS'
71-
run: brew install --cask xquartz
72-
73-
- name: Install dependencies
74-
run: |
75-
remotes::install_deps(dependencies = TRUE)
76-
remotes::install_cran("rcmdcheck")
77-
shell: Rscript {0}
78-
79-
- name: Check
80-
env:
81-
_R_CHECK_CRAN_INCOMING_REMOTE_: false
82-
run: rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran"), error_on = "warning", check_dir = "check")
83-
shell: Rscript {0}
84-
85-
- name: Upload check results
86-
if: failure()
87-
uses: actions/upload-artifact@main
48+
extra-packages: any::rcmdcheck
49+
needs: check
50+
51+
- uses: r-lib/actions/check-r-package@v2
8852
with:
89-
name: ${{ runner.os }}-r${{ matrix.config.r }}-results
90-
path: check
53+
error-on: '"error"'
54+
args: 'c("--no-multiarch", "--no-manual", "--as-cran")'
55+
9156
GPU:
9257
runs-on: ['self-hosted', 'gce', 'gpu']
9358
name: 'gpu'
9459

9560
container:
96-
image: nvidia/cuda:11.6.0-cudnn8-devel-ubuntu18.04
97-
options: --gpus all
61+
image: 'nvidia/cuda:11.6.0-cudnn8-devel-ubuntu18.04'
62+
options: '--gpus all --runtime=nvidia'
63+
64+
timeout-minutes: 120
9865

9966
env:
10067
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
@@ -105,49 +72,22 @@ jobs:
10572
DEBIAN_FRONTEND: 'noninteractive'
10673

10774
steps:
108-
- uses: actions/checkout@v2
75+
- uses: actions/checkout@v3
10976

11077
- run: |
11178
apt-get update -y
112-
apt-get install -y sudo software-properties-common dialog apt-utils tzdata
79+
apt-get install -y sudo software-properties-common dialog apt-utils tzdata libpng-dev
11380
11481
- uses: r-lib/actions/setup-r@v2
115-
with:
116-
r-version: 'release'
11782

11883
- uses: r-lib/actions/setup-pandoc@v2
11984

120-
- name: Query dependencies
121-
run: |
122-
install.packages('remotes')
123-
saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2)
124-
writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version")
125-
shell: Rscript {0}
126-
127-
- name: Cache R packages
128-
if: runner.os != 'Windows'
129-
uses: actions/cache@v2
85+
- uses: r-lib/actions/setup-r-dependencies@v2
13086
with:
131-
path: ${{ env.R_LIBS_USER }}
132-
key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }}
133-
restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-
134-
135-
- name: Install dependencies
136-
run: |
137-
remotes::install_deps(dependencies = TRUE)
138-
remotes::install_cran("rcmdcheck")
139-
shell: Rscript {0}
140-
141-
- name: Check
142-
env:
143-
_R_CHECK_CRAN_INCOMING_REMOTE_: false
144-
run: rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran"), error_on = "error", check_dir = "check")
145-
shell: Rscript {0}
146-
147-
- name: Upload check results
148-
if: failure()
149-
uses: actions/upload-artifact@main
150-
with:
151-
name: ${{ runner.os }}-r${{ matrix.config.r }}-results
152-
path: check
87+
extra-packages: any::rcmdcheck
88+
needs: check
15389

90+
- uses: r-lib/actions/check-r-package@v2
91+
with:
92+
error-on: '"error"'
93+
args: 'c("--no-multiarch", "--no-manual", "--as-cran")'

.github/workflows/test-coverage.yaml

+27-28
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ jobs:
1313
runs-on: ['self-hosted', 'gce', 'gpu']
1414

1515
container:
16-
image: nvidia/cuda:11.6.0-cudnn8-devel-ubuntu18.04
17-
options: --gpus all
16+
image: 'nvidia/cuda:11.6.0-cudnn8-devel-ubuntu18.04'
17+
options: '--gpus all --runtime=nvidia'
18+
19+
timeout-minutes: 120
1820

1921
env:
2022
RSPM: https://packagemanager.rstudio.com/cran/__linux__/bionic/latest
@@ -24,42 +26,39 @@ jobs:
2426
DEBIAN_FRONTEND: 'noninteractive'
2527

2628
steps:
27-
- uses: actions/checkout@v2
29+
- uses: actions/checkout@v3
2830

2931
- run: |
3032
apt-get update -y
3133
apt-get install -y sudo software-properties-common dialog apt-utils tzdata
32-
- uses: r-lib/actions/setup-r@v1
33-
id: install-r
3434
35-
- name: Install pak and query dependencies
36-
run: |
37-
install.packages("pak", repos = "https://r-lib.github.io/p/pak/dev/")
38-
saveRDS(pak::pkg_deps("local::.", dependencies = TRUE), ".github/r-depends.rds")
39-
shell: Rscript {0}
35+
- uses: r-lib/actions/setup-r@v2
4036

41-
- name: Restore R package cache
42-
uses: actions/cache@v2
37+
- uses: r-lib/actions/setup-r-dependencies@v2
4338
with:
44-
path: |
45-
${{ env.R_LIBS_USER }}/*
46-
!${{ env.R_LIBS_USER }}/pak
47-
key: ubuntu-18.04-${{ steps.install-r.outputs.installed-r-version }}-1-${{ hashFiles('.github/r-depends.rds') }}
48-
restore-keys: ubuntu-18.04-${{ steps.install-r.outputs.installed-r-version }}-1-
39+
extra-packages: |
40+
any::covr
4941
50-
- name: Install system dependencies
51-
if: runner.os == 'Linux'
42+
- name: Test coverage
5243
run: |
53-
pak::local_system_requirements(execute = TRUE)
54-
pak::pkg_system_requirements("covr", execute = TRUE)
44+
covr::codecov(
45+
quiet = FALSE,
46+
clean = FALSE,
47+
install_path = file.path(Sys.getenv("RUNNER_TEMP"), "package")
48+
)
5549
shell: Rscript {0}
5650

57-
- name: Install dependencies
51+
- name: Show testthat output
52+
if: always()
5853
run: |
59-
pak::local_install_dev_deps(upgrade = TRUE)
60-
pak::pkg_install("covr")
61-
shell: Rscript {0}
54+
## --------------------------------------------------------------------
55+
find ${{ runner.temp }}/package -name 'testthat.Rout*' -exec cat '{}' \; || true
56+
shell: bash
57+
58+
- name: Upload test results
59+
if: failure()
60+
uses: actions/upload-artifact@v3
61+
with:
62+
name: coverage-test-failures
63+
path: ${{ runner.temp }}/package
6264

63-
- name: Test coverage
64-
run: covr::codecov()
65-
shell: Rscript {0}

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ inst/doc
55
.venv
66
activate
77
.V8history
8+
/doc/
9+
/Meta/
10+
revdep

DESCRIPTION

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: tabnet
22
Title: Fit 'TabNet' Models for Classification and Regression
3-
Version: 0.3.0.9000
3+
Version: 0.4.0
44
Authors@R: c(
55
person(given = "Daniel", family = "Falbel", role = c("aut"), email = "[email protected]"),
66
person(family = "RStudio", role = c("cph")),
@@ -19,7 +19,7 @@ URL: https://github.com/mlverse/tabnet
1919
BugReports: https://github.com/mlverse/tabnet/issues
2020
Imports:
2121
torch (>= 0.4.0),
22-
hardhat,
22+
hardhat (>= 1.3.0),
2323
magrittr,
2424
glue,
2525
progress,

NEWS.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# tabnet (development version)
2-
31
# tabnet 0.4.0
42

53
## New features
@@ -11,18 +9,20 @@
119
* Allow missing-values values in predictor for unsupervised training. (#68)
1210
* Improve performance of `random_obfuscator()` torch_nn module. (#68)
1311
* Add support for early stopping (#69)
14-
* `tabnet_fit()` and `predict()` now allow missing values in predictors. (#76)
12+
* `tabnet_fit()` and `predict()` now allow **missing values** in predictors. (#76)
1513
* `tabnet_config()` now supports a `num_workers=` parameters to control parallel dataloading (#83)
14+
* Add a vignette on missing data (#83)
1615
* `tabnet_config()` now has a flag `skip_importance` to skip calculating feature importance (@egillax, #91)
1716
* Export and document `tabnet_nn`
1817
* Added `min_grid.tabnet` method for `tune` (@cphaarmeyer, #107)
1918
* Added `tabnet_explain()` method for parsnip models (@cphaarmeyer, #108)
19+
* `tabnet_fit()` and `predict()` now allow **multi-outcome**, all numeric or all factors but not mixed. (#118)
2020

2121
## Bugfixes
2222

2323
* `tabnet_explain()` is now correctly handling missing values in predictors. (#77)
2424
* `dataloader` can now use `num_workers>0` (#83)
25-
* new default values for `batch_size` and `virtual_batch_size` do not limit performance on mid-range devices.
25+
* new default values for `batch_size` and `virtual_batch_size` improves performance on mid-range devices.
2626
* add default `engine="torch"` to tabnet parsnip model (#114)
2727
* fix `autoplot()` warnings turned into errors with {ggplot2} v3.4 (#113)
2828

0 commit comments

Comments
 (0)