@@ -86,15 +86,18 @@ def _parse_gpu_ids(
8686 # We know the user requested GPUs therefore if some of the
8787 # requested GPUs are not available an exception is thrown.
8888 gpus = _normalize_parse_gpu_string_input (gpus )
89- gpus = _normalize_parse_gpu_input_to_list (gpus , include_cuda = include_cuda , include_mps = include_mps , include_musa = include_musa )
89+ gpus = _normalize_parse_gpu_input_to_list (
90+ gpus , include_cuda = include_cuda , include_mps = include_mps , include_musa = include_musa
91+ )
9092 if not gpus :
9193 raise MisconfigurationException ("GPUs requested but none are available." )
9294
9395 if (
9496 torch .distributed .is_available ()
9597 and torch .distributed .is_torchelastic_launched ()
9698 and len (gpus ) != 1
97- and len (_get_all_available_gpus (include_cuda = include_cuda , include_mps = include_mps , include_musa = include_musa )) == 1
99+ and len (_get_all_available_gpus (include_cuda = include_cuda , include_mps = include_mps , include_musa = include_musa ))
100+ == 1
98101 ):
99102 # Omit sanity check on torchelastic because by default it shows one visible GPU per process
100103 return gpus
@@ -115,7 +118,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, list[int]]) -> Union[in
115118 return int (s .strip ())
116119
117120
118- def _sanitize_gpu_ids (gpus : list [int ], include_cuda : bool = False , include_mps : bool = False , include_musa : bool = False ) -> list [int ]:
121+ def _sanitize_gpu_ids (
122+ gpus : list [int ], include_cuda : bool = False , include_mps : bool = False , include_musa : bool = False
123+ ) -> list [int ]:
119124 """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the
120125 GPUs is not available.
121126
@@ -132,7 +137,9 @@ def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps:
132137 """
133138 if sum ((include_cuda , include_mps , include_musa )) == 0 :
134139 raise ValueError ("At least one gpu type should be specified!" )
135- all_available_gpus = _get_all_available_gpus (include_cuda = include_cuda , include_mps = include_mps , include_musa = include_musa )
140+ all_available_gpus = _get_all_available_gpus (
141+ include_cuda = include_cuda , include_mps = include_mps , include_musa = include_musa
142+ )
136143 for gpu in gpus :
137144 if gpu not in all_available_gpus :
138145 raise MisconfigurationException (
@@ -157,7 +164,9 @@ def _normalize_parse_gpu_input_to_list(
157164 return list (range (gpus ))
158165
159166
160- def _get_all_available_gpus (include_cuda : bool = False , include_mps : bool = False , include_musa : bool = False ) -> list [int ]:
167+ def _get_all_available_gpus (
168+ include_cuda : bool = False , include_mps : bool = False , include_musa : bool = False
169+ ) -> list [int ]:
161170 """
162171 Returns:
163172 A list of all available GPUs
@@ -214,8 +223,8 @@ def _select_auto_accelerator() -> str:
214223 """Choose the accelerator type (str) based on availability."""
215224 from lightning .fabric .accelerators .cuda import CUDAAccelerator
216225 from lightning .fabric .accelerators .mps import MPSAccelerator
217- from lightning .fabric .accelerators .xla import XLAAccelerator
218226 from lightning .fabric .accelerators .musa import MUSAAccelerator
227+ from lightning .fabric .accelerators .xla import XLAAccelerator
219228
220229 if XLAAccelerator .is_available ():
221230 return "tpu"