Skip to content

Commit 70f56ff

Browse files
authored
Support DISABLE_KERNEL_MAPPING env var for completely disabling kernel mappings (#70)
* Disable kernel mappings with `DISABLE_KERNEL_MAPPING=1` * Rename HF_KERNELS_CACHE to KERNELS_CACHE But still recognize the old variant for compatibility. * Add documentation for environment variables
1 parent 7178b0b commit 70f56ff

File tree

4 files changed

+32
-1
lines changed

4 files changed

+32
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ the Hub.
4747

4848
- [Using layers](docs/layers.md)
4949
- [Locking kernel versions](docs/locking.md)
50+
- [Environment variables](docs/env.md)
5051
- [Using kernels in a Docker container](docs/docker.md)
5152
- [Kernel requirements](docs/kernel-requirements.md)
5253
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)

docs/env.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Environment variables
2+
3+
## `KERNELS_CACHE`
4+
5+
The directory to use as the local kernel cache. If not set, the cache
6+
of the `huggingface_hub` package is used.
7+
8+
## `DISABLE_KERNEL_MAPPING`
9+
10+
Disables kernel mappings for [`layers`](layers.md).

src/kernels/layer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
import os
23
import warnings
34
from contextvars import ContextVar
45
from copy import deepcopy
@@ -10,6 +11,8 @@
1011
if TYPE_CHECKING:
1112
from torch import nn
1213

14+
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
15+
1316

1417
@dataclass(frozen=True)
1518
class Device:
@@ -131,6 +134,9 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
131134
cached_forward: Dict[LayerRepository, Callable] = {}
132135

133136
def forward(self, x, *args, **kwargs):
137+
if _DISABLE_KERNEL_MAPPING:
138+
return fallback_forward(self, x, *args, **kwargs)
139+
134140
kernel = _KERNEL_MAPPING.get().get(layer_name)
135141
if kernel is None:
136142
warnings.warn(

src/kernels/utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import importlib.metadata
55
import inspect
66
import json
7+
import logging
78
import os
89
import platform
910
import sys
@@ -17,7 +18,20 @@
1718

1819
from kernels.lockfile import KernelLock, VariantLock
1920

20-
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
21+
22+
def _get_cache_dir() -> Optional[str]:
23+
"""Returns the kernels cache directory."""
24+
cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
25+
if cache_dir is not None:
26+
logging.warning(
27+
"HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
28+
)
29+
return cache_dir
30+
31+
return os.environ.get("KERNELS_CACHE", None)
32+
33+
34+
CACHE_DIR: Optional[str] = _get_cache_dir()
2135

2236

2337
def build_variant() -> str:

0 commit comments

Comments
 (0)