Skip to content

Commit 4be9653

Browse files
authored
fix(transformers): fix import error if env contains accelerate module (#1431)
1 parent da4a28a commit 4be9653

File tree

6 files changed

+1095
-13
lines changed

6 files changed

+1095
-13
lines changed

mindone/transformers/integrations/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@
1616
# limitations under the License.
1717
from .accelerate import *
1818
from .flash_attention import *
19+
from .integration_utils import get_reporting_integration_callbacks
1920
from .peft import PeftAdapterMixin
2021
from .sdpa_attention import *
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
# Copyright 2020 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Integrations with other Python libraries.
16+
"""
17+
18+
import importlib.metadata
19+
import importlib.util
20+
import os
21+
from typing import TYPE_CHECKING
22+
23+
import packaging.version
24+
25+
if os.getenv("WANDB_MODE") == "offline":
26+
print("⚙️ Running in WANDB offline mode")
27+
28+
from ..utils import logging
29+
30+
logger = logging.get_logger(__name__)
31+
32+
# comet_ml requires to be imported before any ML frameworks
33+
_MIN_COMET_VERSION = "3.43.2"
34+
try:
35+
_comet_version = importlib.metadata.version("comet_ml")
36+
_is_comet_installed = True
37+
38+
_is_comet_recent_enough = packaging.version.parse(_comet_version) >= packaging.version.parse(_MIN_COMET_VERSION)
39+
40+
# Check if the Comet API Key is set
41+
import comet_ml
42+
43+
if comet_ml.config.get_config("comet.api_key") is not None:
44+
_is_comet_configured = True
45+
else:
46+
_is_comet_configured = False
47+
except (importlib.metadata.PackageNotFoundError, ImportError, ValueError, TypeError, AttributeError, KeyError):
48+
_comet_version = None
49+
_is_comet_installed = False
50+
_is_comet_recent_enough = False
51+
_is_comet_configured = False
52+
53+
_has_neptune = importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None
54+
if TYPE_CHECKING and _has_neptune:
55+
try:
56+
_neptune_version = importlib.metadata.version("neptune")
57+
logger.info(f"Neptune version {_neptune_version} available.")
58+
except importlib.metadata.PackageNotFoundError:
59+
try:
60+
_neptune_version = importlib.metadata.version("neptune-client")
61+
logger.info(f"Neptune-client version {_neptune_version} available.")
62+
except importlib.metadata.PackageNotFoundError:
63+
_has_neptune = False
64+
65+
from transformers.utils import ENV_VARS_TRUE_VALUES # noqa: E402
66+
67+
68+
# Integration functions:
69+
def is_wandb_available():
70+
# any value of WANDB_DISABLED disables wandb
71+
if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
72+
logger.warning(
73+
"Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
74+
"--report_to flag to control the integrations used for logging result (for instance --report_to none)."
75+
)
76+
return False
77+
if importlib.util.find_spec("wandb") is not None:
78+
import wandb
79+
80+
# wandb might still be detected by find_spec after an uninstall (leftover files or metadata), but not actually
81+
# import correctly. To confirm it's fully installed and usable, we check for a key attribute like "run".
82+
return hasattr(wandb, "run")
83+
else:
84+
return False
85+
86+
87+
def is_trackio_available():
88+
return importlib.util.find_spec("trackio") is not None
89+
90+
91+
def is_clearml_available():
92+
return importlib.util.find_spec("clearml") is not None
93+
94+
95+
def is_comet_available():
96+
if os.getenv("COMET_MODE", "").upper() == "DISABLED":
97+
logger.warning(
98+
"Using the `COMET_MODE=DISABLED` environment variable is deprecated and will be removed in v5. Use the "
99+
"--report_to flag to control the integrations used for logging result (for instance --report_to none)."
100+
)
101+
return False
102+
103+
if _is_comet_installed is False:
104+
return False
105+
106+
if _is_comet_recent_enough is False:
107+
logger.warning(
108+
"comet_ml version %s is installed, but version %s or higher is required. "
109+
"Please update comet_ml to the latest version to enable Comet logging with pip install 'comet-ml>=%s'.",
110+
_comet_version,
111+
_MIN_COMET_VERSION,
112+
_MIN_COMET_VERSION,
113+
)
114+
return False
115+
116+
if _is_comet_configured is False:
117+
logger.warning(
118+
"comet_ml is installed but the Comet API Key is not configured. "
119+
"Please set the `COMET_API_KEY` environment variable to enable Comet logging. "
120+
"Check out the documentation for other ways of configuring it: "
121+
"https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#set-the-api-key"
122+
)
123+
return False
124+
125+
return True
126+
127+
128+
def is_tensorboard_available():
129+
return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None
130+
131+
132+
def is_optuna_available():
133+
return importlib.util.find_spec("optuna") is not None
134+
135+
136+
def is_ray_available():
137+
return importlib.util.find_spec("ray") is not None
138+
139+
140+
def is_ray_tune_available():
141+
if not is_ray_available():
142+
return False
143+
return importlib.util.find_spec("ray.tune") is not None
144+
145+
146+
def is_sigopt_available():
147+
return importlib.util.find_spec("sigopt") is not None
148+
149+
150+
def is_azureml_available():
151+
if importlib.util.find_spec("azureml") is None:
152+
return False
153+
if importlib.util.find_spec("azureml.core") is None:
154+
return False
155+
return importlib.util.find_spec("azureml.core.run") is not None
156+
157+
158+
def is_mlflow_available():
159+
if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
160+
return False
161+
return importlib.util.find_spec("mlflow") is not None
162+
163+
164+
def is_dagshub_available():
165+
return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]
166+
167+
168+
def is_neptune_available():
169+
return _has_neptune
170+
171+
172+
def is_codecarbon_available():
173+
return importlib.util.find_spec("codecarbon") is not None
174+
175+
176+
def is_flytekit_available():
177+
return importlib.util.find_spec("flytekit") is not None
178+
179+
180+
def is_flyte_deck_standard_available():
181+
if not is_flytekit_available():
182+
return False
183+
return importlib.util.find_spec("flytekitplugins.deck") is not None
184+
185+
186+
def is_dvclive_available():
187+
return importlib.util.find_spec("dvclive") is not None
188+
189+
190+
def is_swanlab_available():
191+
return importlib.util.find_spec("swanlab") is not None
192+
193+
194+
def get_available_reporting_integrations():
195+
integrations = []
196+
if is_azureml_available() and not is_mlflow_available():
197+
integrations.append("azure_ml")
198+
if is_comet_available():
199+
integrations.append("comet_ml")
200+
if is_dagshub_available():
201+
integrations.append("dagshub")
202+
if is_dvclive_available():
203+
integrations.append("dvclive")
204+
if is_mlflow_available():
205+
integrations.append("mlflow")
206+
if is_neptune_available():
207+
integrations.append("neptune")
208+
if is_tensorboard_available():
209+
integrations.append("tensorboard")
210+
if is_wandb_available():
211+
integrations.append("wandb")
212+
if is_codecarbon_available():
213+
integrations.append("codecarbon")
214+
if is_clearml_available():
215+
integrations.append("clearml")
216+
if is_swanlab_available():
217+
integrations.append("swanlab")
218+
if is_trackio_available():
219+
integrations.append("trackio")
220+
return integrations
221+
222+
223+
def rewrite_logs(d):
224+
new_d = {}
225+
eval_prefix = "eval_"
226+
eval_prefix_len = len(eval_prefix)
227+
test_prefix = "test_"
228+
test_prefix_len = len(test_prefix)
229+
for k, v in d.items():
230+
if k.startswith(eval_prefix):
231+
new_d["eval/" + k[eval_prefix_len:]] = v
232+
elif k.startswith(test_prefix):
233+
new_d["test/" + k[test_prefix_len:]] = v
234+
else:
235+
new_d["train/" + k] = v
236+
return new_d
237+
238+
239+
INTEGRATION_TO_CALLBACK = {}
240+
241+
242+
def get_reporting_integration_callbacks(report_to):
243+
if report_to is None:
244+
return []
245+
246+
if isinstance(report_to, str):
247+
if "none" == report_to:
248+
return []
249+
elif "all" == report_to:
250+
report_to = get_available_reporting_integrations()
251+
else:
252+
report_to = [report_to]
253+
254+
for integration in report_to:
255+
if integration not in INTEGRATION_TO_CALLBACK:
256+
raise ValueError(
257+
f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
258+
)
259+
260+
return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]

mindone/transformers/modelcard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from huggingface_hub import model_info
3030
from huggingface_hub.utils import HFValidationError
3131
from transformers import __version__
32-
from transformers.training_args import ParallelMode
3332
from transformers.utils import (
3433
MODEL_CARD_NAME,
3534
cached_file,
@@ -55,6 +54,7 @@
5554
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
5655
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
5756
)
57+
from .training_args import ParallelMode
5858
from .utils import is_mindspore_available
5959

6060
TASK_MAPPING = {

mindone/transformers/trainer.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,7 @@
3535
import numpy as np
3636
from packaging import version
3737
from transformers import PreTrainedTokenizerBase
38-
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
39-
from transformers.integrations import get_reporting_integration_callbacks
4038
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
41-
from transformers.trainer_callback import (
42-
CallbackHandler,
43-
DefaultFlowCallback,
44-
ExportableState,
45-
PrinterCallback,
46-
ProgressCallback,
47-
TrainerCallback,
48-
TrainerControl,
49-
TrainerState,
50-
)
5139
from transformers.trainer_utils import (
5240
EvalPrediction,
5341
get_last_checkpoint,
@@ -62,6 +50,19 @@
6250
from mindspore.communication import GlobalComm
6351
from mindspore.communication.management import get_group_size
6452

53+
from mindone.transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
54+
from mindone.transformers.integrations import get_reporting_integration_callbacks
55+
from mindone.transformers.trainer_callback import (
56+
CallbackHandler,
57+
DefaultFlowCallback,
58+
ExportableState,
59+
PrinterCallback,
60+
ProgressCallback,
61+
TrainerCallback,
62+
TrainerControl,
63+
TrainerState,
64+
)
65+
6566
from ..safetensors.mindspore import save_file
6667
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
6768
from .mindspore_adapter import RandomSampler, Sampler, TrainOneStepWrapper, auto_mixed_precision

0 commit comments

Comments
 (0)