|
| 1 | +import os |
| 2 | +import shutil |
| 3 | +import time |
| 4 | +import uuid |
| 5 | +from unittest import TestCase, skip |
| 6 | + |
| 7 | +from huggingface_hub import HfApi |
| 8 | +from huggingface_hub.fastai_utils import ( |
| 9 | + _save_pretrained_fastai, |
| 10 | + from_pretrained_fastai, |
| 11 | + push_to_hub_fastai, |
| 12 | +) |
| 13 | +from huggingface_hub.file_download import ( |
| 14 | + is_fastai_available, |
| 15 | + is_fastcore_available, |
| 16 | + is_torch_available, |
| 17 | +) |
| 18 | + |
| 19 | +from .testing_constants import ENDPOINT_STAGING, TOKEN, USER |
| 20 | +from .testing_utils import set_write_permission_and_retry |
| 21 | + |
| 22 | + |
| 23 | +def repo_name(id=uuid.uuid4().hex[:6]): |
| 24 | + return "fastai-repo-{0}-{1}".format(id, int(time.time() * 10e3)) |
| 25 | + |
| 26 | + |
| 27 | +WORKING_REPO_SUBDIR = f"fixtures/working_repo_{__name__.split('.')[-1]}" |
| 28 | +WORKING_REPO_DIR = os.path.join( |
| 29 | + os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR |
| 30 | +) |
| 31 | + |
| 32 | +if is_fastai_available(): |
| 33 | + from fastai.data.block import DataBlock |
| 34 | + from fastai.test_utils import synth_learner |
| 35 | + |
| 36 | +if is_torch_available(): |
| 37 | + import torch |
| 38 | + |
| 39 | + |
| 40 | +def require_fastai_fastcore(test_case): |
| 41 | + """ |
| 42 | + Decorator marking a test that requires fastai and fastcore. |
| 43 | + These tests are skipped when fastai and fastcore are not installed. |
| 44 | + """ |
| 45 | + if not is_fastai_available(): |
| 46 | + return skip("Test requires fastai")(test_case) |
| 47 | + elif not is_fastcore_available(): |
| 48 | + return skip("Test requires fastcore")(test_case) |
| 49 | + else: |
| 50 | + return test_case |
| 51 | + |
| 52 | + |
| 53 | +def fake_dataloaders(a=2, b=3, bs=16, n=10): |
| 54 | + def get_data(n): |
| 55 | + x = torch.randn(bs * n, 1) |
| 56 | + return torch.cat((x, a * x + b + 0.1 * torch.randn(bs * n, 1)), 1) |
| 57 | + |
| 58 | + ds = get_data(n) |
| 59 | + dblock = DataBlock() |
| 60 | + return dblock.dataloaders(ds) |
| 61 | + |
| 62 | + |
| 63 | +if is_fastai_available(): |
| 64 | + dummy_model = synth_learner(data=fake_dataloaders()) |
| 65 | + dummy_config = dict(test="test_0") |
| 66 | +else: |
| 67 | + dummy_model = None |
| 68 | + dummy_config = None |
| 69 | + |
| 70 | + |
| 71 | +@require_fastai_fastcore |
| 72 | +class TestFastaiUtils(TestCase): |
| 73 | + @classmethod |
| 74 | + def setUpClass(cls): |
| 75 | + """ |
| 76 | + Share this valid token in all tests below. |
| 77 | + """ |
| 78 | + cls._api = HfApi(endpoint=ENDPOINT_STAGING) |
| 79 | + cls._token = TOKEN |
| 80 | + cls._api.set_access_token(TOKEN) |
| 81 | + |
| 82 | + def tearDown(self) -> None: |
| 83 | + try: |
| 84 | + shutil.rmtree(WORKING_REPO_DIR, onerror=set_write_permission_and_retry) |
| 85 | + except FileNotFoundError: |
| 86 | + pass |
| 87 | + |
| 88 | + def test_save_pretrained_without_config(self): |
| 89 | + REPO_NAME = repo_name("save") |
| 90 | + _save_pretrained_fastai(dummy_model, f"{WORKING_REPO_DIR}/{REPO_NAME}") |
| 91 | + files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") |
| 92 | + self.assertTrue("model.pkl" in files) |
| 93 | + self.assertTrue("pyproject.toml" in files) |
| 94 | + self.assertTrue("README.md" in files) |
| 95 | + self.assertEqual(len(files), 3) |
| 96 | + |
| 97 | + def test_save_pretrained_with_config(self): |
| 98 | + REPO_NAME = repo_name("save") |
| 99 | + _save_pretrained_fastai( |
| 100 | + dummy_model, f"{WORKING_REPO_DIR}/{REPO_NAME}", config=dummy_config |
| 101 | + ) |
| 102 | + files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") |
| 103 | + self.assertTrue("config.json" in files) |
| 104 | + self.assertEqual(len(files), 4) |
| 105 | + |
| 106 | + def test_push_to_hub_and_from_pretrained_fastai(self): |
| 107 | + REPO_NAME = repo_name("push_to_hub") |
| 108 | + push_to_hub_fastai( |
| 109 | + learner=dummy_model, |
| 110 | + repo_id=f"{USER}/{REPO_NAME}", |
| 111 | + token=self._token, |
| 112 | + config=dummy_config, |
| 113 | + ) |
| 114 | + model_info = self._api.model_info( |
| 115 | + f"{USER}/{REPO_NAME}", |
| 116 | + ) |
| 117 | + self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}") |
| 118 | + |
| 119 | + loaded_model = from_pretrained_fastai(f"{USER}/{REPO_NAME}") |
| 120 | + self.assertEqual( |
| 121 | + dummy_model.show_training_loop(), loaded_model.show_training_loop() |
| 122 | + ) |
| 123 | + self._api.delete_repo(repo_id=f"{REPO_NAME}", token=self._token) |
0 commit comments