|
11 | 11 | -------------------------------------------------------------------------------- |
12 | 12 | """ |
13 | 13 |
|
| 14 | +import os |
14 | 15 | import unittest as ut |
15 | 16 | import numpy as np |
16 | 17 | import torch |
17 | 18 | import pandas as pd |
18 | 19 | import tempfile |
19 | 20 |
|
20 | | -import graphium |
21 | 21 | from graphium.utils.fs import rm, exists, get_size |
22 | 22 | from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule |
23 | 23 |
|
24 | 24 | import graphium_cpp |
25 | 25 |
|
26 | 26 | TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000" |
27 | 27 |
|
28 | | - |
29 | 28 | class test_DataModule(ut.TestCase): |
| 29 | + |
30 | 30 | def test_ogb_datamodule(self): |
31 | 31 | # other datasets are too large to be tested |
32 | 32 | dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"] |
@@ -380,7 +380,7 @@ def test_datamodule_multiple_data_files(self): |
380 | 380 |
|
381 | 381 | self.assertEqual(len(ds.train_ds), 20) |
382 | 382 |
|
383 | | - def test_splits_file(self, tmp_path): |
| 383 | + def test_splits_file(self): |
384 | 384 | # Test single CSV files |
385 | 385 | csv_file = "tests/data/micro_ZINC_shard_1.csv" |
386 | 386 | df = pd.read_csv(csv_file) |
@@ -423,15 +423,17 @@ def test_splits_file(self, tmp_path): |
423 | 423 | self.assertEqual(len(ds.val_ds), len(split_val)) |
424 | 424 | self.assertEqual(len(ds.test_ds), len(split_test)) |
425 | 425 |
|
426 | | - # Create a TemporaryFile to save the splits, and test the datamodule |
427 | | - with tempfile.NamedTemporaryFile(suffix=".pt", dir=tmp_path) as temp: |
| 426 | + try: |
| 427 | + # Create a TemporaryFile to save the splits, and test the datamodule |
| 428 | + temp_file = tempfile.NamedTemporaryFile(suffix=".pt", delete=False) |
| 429 | + |
428 | 430 | # Save the splits |
429 | | - torch.save(splits, temp) |
| 431 | + torch.save(splits, temp_file) |
430 | 432 |
|
431 | 433 | # Test the datamodule |
432 | 434 | task_kwargs = { |
433 | 435 | "df_path": csv_file, |
434 | | - "splits_path": temp.name, |
| 436 | + "splits_path": temp_file.name, |
435 | 437 | "split_val": 0.0, |
436 | 438 | "split_test": 0.0, |
437 | 439 | } |
@@ -468,6 +470,10 @@ def test_splits_file(self, tmp_path): |
468 | 470 | ) |
469 | 471 | np.testing.assert_array_equal(ds.val_ds.smiles_offsets_tensor, ds2.val_ds.smiles_offsets_tensor) |
470 | 472 | np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor) |
| 473 | + |
| 474 | + finally: |
| 475 | + temp_file.close() |
| 476 | + os.unlink(temp_file.name) |
471 | 477 |
|
472 | 478 |
|
473 | 479 | if __name__ == "__main__": |
|
0 commit comments