-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtools.py
45 lines (38 loc) · 1.04 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
import logging
import sys
import gc
from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_torch_cuda_available,
is_torch_mps_available,
is_torch_npu_available,
is_torch_xpu_available,
)
def torch_gc() -> None:
    r"""
    Runs the Python garbage collector, then empties one accelerator cache.

    Backends are probed in the order XPU, NPU, MPS, CUDA. Only the first
    backend reported as available has ``empty_cache()`` called on it; if none
    is available, only the Python-level collection happens.
    """
    gc.collect()

    # Ordered (torch submodule name, availability probe) pairs. Probes are
    # invoked lazily, one at a time, so later ones are skipped after a hit.
    backend_probes = (
        ("xpu", is_torch_xpu_available),
        ("npu", is_torch_npu_available),
        ("mps", is_torch_mps_available),
        ("cuda", is_torch_cuda_available),
    )
    for backend_name, backend_available in backend_probes:
        if backend_available():
            getattr(torch, backend_name).empty_cache()
            return
def get_logger(name: str) -> logging.Logger:
    r"""
    Gets a standard logger with a stream handler to stdout.

    The logger's level is set to INFO and its records are formatted as
    ``<time> - <level> - <logger name> - <message>``.

    ``logging.getLogger`` returns the same object for the same name, so this
    function may be called repeatedly with one name. The stdout handler is
    attached only if the logger does not already have one, which prevents
    every message from being printed once per call.

    Args:
        name: Name of the logger to fetch or create.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # Skip installation when a handler bound to the current stdout exists;
    # otherwise repeated calls would stack handlers and duplicate output.
    has_stdout_handler = any(
        isinstance(existing, logging.StreamHandler) and getattr(existing, "stream", None) is sys.stdout
        for existing in logger.handlers
    )
    if not has_stdout_handler:
        formatter = logging.Formatter(
            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
        )
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger