Commit efedbfc

add omnivinci

Signed-off-by: 0xrushi <[email protected]>
1 parent e39dc46 commit efedbfc

10 files changed: +329 −18 lines changed
Lines changed: 47 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.bitsandbytes_loader import (
    BitsAndBytesModelLoader,
)


class _DummyBitsAndBytesLoader(BitsAndBytesModelLoader):
    """Test helper that bypasses any real HF interactions."""

    def __init__(
        self, load_config: LoadConfig, mock_result: tuple[str, list[str], str]
    ):
        super().__init__(load_config)
        self._mock_result = mock_result

    def _get_weight_files(  # type: ignore[override]
        self,
        model_name_or_path: str,
        allowed_patterns: list[str],
        revision: Optional[str] = None,
    ) -> tuple[str, list[str], str]:
        return self._mock_result


def test_bitsandbytes_loader_detects_safetensors_from_files(tmp_path):
    """Even if the allow-pattern looks like *.bin, safetensors files are detected."""

    llm_dir = tmp_path / "llm"
    llm_dir.mkdir()
    safetensor = llm_dir / "model-00001-of-00002.safetensors"
    safetensor.write_bytes(b"test")

    load_config = LoadConfig()
    loader = _DummyBitsAndBytesLoader(
        load_config,
        mock_result=(str(tmp_path), [str(safetensor)], "*.bin"),
    )

    files, use_safetensors = loader._prepare_weights(str(tmp_path), revision=None)

    assert use_safetensors is True
    assert files == [str(safetensor)]
```
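The regression this test guards against: `_get_weight_files` can report a matched pattern of `*.bin` even when every file it found is a safetensors shard, and the old `matched_pattern == "*.safetensors"` comparison in `_prepare_weights` would then route those shards to the pickle loader. A minimal sketch of the old versus new detection (plain Python; the variable names are local to this snippet):

```python
matched_pattern = "*.bin"
hf_weights_files = ["/tmp/llm/model-00001-of-00002.safetensors"]

# Old check: exact string comparison, misses subfolder and mixed patterns.
old_detection = matched_pattern == "*.safetensors"  # False

# New check: suffix match on the pattern, then a guard over the actual files.
new_detection = matched_pattern.endswith(".safetensors") or any(
    f.endswith(".safetensors") for f in hf_weights_files
)  # True

assert (old_detection, new_detection) == (False, True)
```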

tests/model_executor/test_weight_utils.py

Lines changed: 24 additions & 0 deletions

```diff
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import json
 import os
 import tempfile
 
@@ -11,6 +12,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     download_weights_from_hf,
     enable_hf_transfer,
+    filter_duplicate_safetensors_files,
 )
 
 
@@ -61,6 +63,28 @@ def test_download_weights_from_hf():
     )
 
 
+def test_filter_duplicate_safetensors_files_with_subfolder(tmp_path):
+    llm_dir = tmp_path / "llm"
+    llm_dir.mkdir()
+    kept_file = llm_dir / "model-00001-of-00002.safetensors"
+    kept_file.write_bytes(b"0")
+    dropped_file = tmp_path / "other.safetensors"
+    dropped_file.write_bytes(b"0")
+
+    index_path = llm_dir / "model.safetensors.index.json"
+    index_path.write_text(
+        json.dumps({"weight_map": {"w": "model-00001-of-00002.safetensors"}})
+    )
+
+    filtered = filter_duplicate_safetensors_files(
+        [str(kept_file), str(dropped_file)],
+        str(tmp_path),
+        "llm/model.safetensors.index.json",
+    )
+
+    assert filtered == [str(kept_file)]
+
+
 if __name__ == "__main__":
     test_hf_transfer_auto_activation()
     test_download_weights_from_hf()
```
Lines changed: 53 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

from vllm.transformers_utils import tokenizer as tokenizer_module
from vllm.transformers_utils.tokenizer import get_tokenizer


class _DummyTokenizer:
    def __init__(self):
        self.all_special_ids: list[int] = []
        self.all_special_tokens: list[str] = []
        self.all_special_tokens_extended: list[str] = []
        self.special_tokens_map: dict[str, str] = {}
        self.vocab_size = 1

    def get_vocab(self) -> dict[str, int]:
        return {"a": 0}

    def __len__(self) -> int:  # pragma: no cover - trivial
        return 1

    def decode(self, *args: Any, **kwargs: Any) -> str:
        return ""

    def encode(self, *args: Any, **kwargs: Any) -> list[int]:
        return []


def test_tokenizer_prefers_llm_subfolder(monkeypatch):
    captured = {}

    def fake_file_exists(repo_id: str, file_name: str, **kwargs: Any) -> bool:
        return file_name == "llm/tokenizer.json"

    def fake_auto_from_pretrained(*args: Any, **kwargs: Any):
        captured["subfolder"] = kwargs.get("subfolder")
        return _DummyTokenizer()

    monkeypatch.setattr(tokenizer_module, "file_exists", fake_file_exists)
    monkeypatch.setattr(
        tokenizer_module.AutoTokenizer,
        "from_pretrained",
        classmethod(
            lambda cls, *args, **kwargs: fake_auto_from_pretrained(*args, **kwargs)
        ),
    )

    tokenizer = get_tokenizer("fake/model")

    assert tokenizer is not None
    assert captured["subfolder"] == "llm"
```
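The tokenizer-side change this test exercises is one of the files not shown in this excerpt. Based on the seams the test monkeypatches (`file_exists` in `vllm.transformers_utils.tokenizer` and `AutoTokenizer.from_pretrained`), the behavior it pins down is roughly the following sketch; `get_tokenizer_sketch` is a hypothetical stand-in, not the shipped code:

```python
from typing import Any

from transformers import AutoTokenizer


def get_tokenizer_sketch(model: str, file_exists, **kwargs: Any):
    # If the repo keeps its tokenizer under llm/ (as nvidia/omnivinci does),
    # forward subfolder="llm" so HF loads llm/tokenizer.json.
    if "subfolder" not in kwargs and file_exists(model, "llm/tokenizer.json"):
        kwargs["subfolder"] = "llm"
    return AutoTokenizer.from_pretrained(model, **kwargs)
```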
Lines changed: 74 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union

from transformers import GenerationConfig, PretrainedConfig

from vllm.transformers_utils import config as config_module
from vllm.transformers_utils.config import HFConfigParser, try_get_generation_config


def test_hf_config_parser_uses_llm_subfolder(monkeypatch):
    parser = HFConfigParser()
    base_config = PretrainedConfig()
    subfolder_config = PretrainedConfig()

    def fake_get_config_dict(
        cls,
        model: Union[str, bytes],
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        **kwargs,
    ):
        return {"llm_cfg": {}}, base_config

    def fake_file_exists(
        model: Union[str, bytes], config_name: str, revision: Optional[str]
    ):
        return config_name == "llm/config.json"

    auto_called = {}

    def fake_auto_from_pretrained(cls, *args, **kwargs):
        auto_called["subfolder"] = kwargs.get("subfolder")
        return subfolder_config

    monkeypatch.setattr(
        PretrainedConfig,
        "get_config_dict",
        classmethod(fake_get_config_dict),
    )
    monkeypatch.setattr(config_module, "file_or_path_exists", fake_file_exists)
    monkeypatch.setattr(
        config_module.AutoConfig,
        "from_pretrained",
        classmethod(fake_auto_from_pretrained),
    )

    returned_dict, returned_config = parser.parse("fake/model", trust_remote_code=False)

    assert returned_dict == {"llm_cfg": {}}
    assert returned_config is subfolder_config
    assert auto_called["subfolder"] == "llm"


def test_try_get_generation_config_llm_subfolder(monkeypatch):
    calls = []

    def fake_from_pretrained(cls, model: str, **kwargs):
        calls.append(kwargs.get("subfolder"))
        if len(calls) == 1:
            raise OSError("missing")
        return GenerationConfig()

    monkeypatch.setattr(
        config_module.GenerationConfig,
        "from_pretrained",
        classmethod(fake_from_pretrained),
    )

    result = try_get_generation_config("fake/model", trust_remote_code=False)

    assert isinstance(result, GenerationConfig)
    assert calls == [None, "llm"]
```
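As with the tokenizer test, the shipped `config.py` diff is not part of this excerpt. The second test encodes a retry contract for `try_get_generation_config`: first call `GenerationConfig.from_pretrained` at the repo root (recorded subfolder `None`), and on `OSError` retry with `subfolder="llm"`. A sketch of that fallback, hedged as an assumed shape:

```python
from typing import Optional

from transformers import GenerationConfig


def try_get_generation_config_sketch(model: str) -> Optional[GenerationConfig]:
    # Try the repo root first, then fall back to the llm/ subfolder.
    for subfolder in (None, "llm"):
        try:
            kwargs = {} if subfolder is None else {"subfolder": subfolder}
            return GenerationConfig.from_pretrained(model, **kwargs)
        except OSError:
            continue
    return None
```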

vllm/model_executor/model_loader/bitsandbytes_loader.py

Lines changed: 40 additions & 16 deletions

```diff
@@ -96,14 +96,27 @@ def _get_weight_files(
         is_local = os.path.isdir(model_name_or_path)
 
         if is_local:
-            for pattern in allowed_patterns:
+            patterns = list(allowed_patterns)
+            # Prefer subfolder patterns if common subfolder exists locally.
+            if os.path.isdir(os.path.join(model_name_or_path, "llm")):
+                patterns = [f"llm/{p}" for p in allowed_patterns] + patterns
+            for pattern in patterns:
                 weight_files = glob.glob(os.path.join(model_name_or_path, pattern))
                 if weight_files:
                     return model_name_or_path, weight_files, pattern
         else:
             hf_api = HfApi()
             repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
-            for pattern in allowed_patterns:
+            search_patterns = list(allowed_patterns)
+            # Prefer 'llm/' weights when present in the repo.
+            if any(
+                f.startswith("llm/") and f.endswith((".safetensors", ".bin", ".pt"))
+                for f in repo_files
+            ):
+                search_patterns = [
+                    f"llm/{p}" for p in allowed_patterns
+                ] + search_patterns
+            for pattern in search_patterns:
                 matching_files = fnmatch.filter(repo_files, pattern)
                 if matching_files:
                     hf_folder = download_weights_from_hf(
@@ -128,26 +141,35 @@ def _prepare_weights(
 
         allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
 
+        if getattr(self, "allow_patterns_overrides", None):
+            allowed_patterns = list(self.allow_patterns_overrides)
+
         hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
             model_name_or_path, allowed_patterns, revision
         )
 
-        use_safetensors = matched_pattern == "*.safetensors"
+        # Detect safetensors robustly (pattern may include subfolder)
+        use_safetensors = matched_pattern.endswith(".safetensors")
+        # Additionally guard by checking actual files
+        if not use_safetensors:
+            use_safetensors = any(f.endswith(".safetensors") for f in hf_weights_files)
         is_local = os.path.isdir(model_name_or_path)
-        index_file = SAFE_WEIGHTS_INDEX_NAME
+        # If weights live under a subfolder (e.g., 'llm/*.safetensors'),
+        # the index file will also live there.
+        if "/" in matched_pattern:
+            folder_prefix = matched_pattern.rsplit("/", 1)[0] + "/"
+        else:
+            folder_prefix = ""
+        index_file = folder_prefix + SAFE_WEIGHTS_INDEX_NAME
+        if use_safetensors and not is_local:
+            # Download index for safetensors to select correct shards.
+            download_safetensors_index_file_from_hf(
+                model_name_or_path,
+                index_file,
+                self.load_config.download_dir,
+                revision,
+            )
         if use_safetensors:
-            # For models like Mistral-7B-Instruct-v0.3
-            # there are both sharded safetensors files and a consolidated
-            # safetensors file. Using both breaks.
-            # Here, we download the `model.safetensors.index.json` and filter
-            # any files not found in the index.
-            if not is_local:
-                download_safetensors_index_file_from_hf(
-                    model_name_or_path,
-                    index_file,
-                    self.load_config.download_dir,
-                    revision,
-                )
             hf_weights_files = filter_duplicate_safetensors_files(
                 hf_weights_files, hf_folder, index_file
             )
@@ -587,6 +609,8 @@ def _initialize_loader_state(
         self._get_bnb_target_modules(model)
         self._classify_module_sharding(model)
 
+        self.allow_patterns_overrides = getattr(model, "allow_patterns_overrides", None)
+
     def _dequantize_dq(self, quant_states: Any):
         """
         When BNB employs Double Quantization, we perform the dequantization of
```
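Taken together, the hunks above implement one precedence rule: if a repo (or local checkout) ships weights under `llm/`, try the subfolder patterns first, and derive the safetensors index path from whichever pattern actually matched. A standalone sketch of that logic (no vLLM imports; the index-file constant is inlined here for illustration):

```python
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"  # inlined for the sketch


def order_patterns(repo_files: list[str], allowed_patterns: list[str]) -> list[str]:
    """Put 'llm/' variants first when the repo ships weights in that subfolder."""
    has_llm_weights = any(
        f.startswith("llm/") and f.endswith((".safetensors", ".bin", ".pt"))
        for f in repo_files
    )
    if has_llm_weights:
        return [f"llm/{p}" for p in allowed_patterns] + list(allowed_patterns)
    return list(allowed_patterns)


def index_file_for(matched_pattern: str) -> str:
    """The index file lives alongside the shards the pattern matched."""
    prefix = matched_pattern.rsplit("/", 1)[0] + "/" if "/" in matched_pattern else ""
    return prefix + SAFE_WEIGHTS_INDEX_NAME


repo = ["llm/model-00001-of-00002.safetensors", "config.json"]
patterns = order_patterns(repo, ["*.safetensors", "*.bin", "*.pt"])
assert patterns[0] == "llm/*.safetensors"
assert index_file_for(patterns[0]) == "llm/model.safetensors.index.json"
```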

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -499,8 +499,12 @@ def filter_duplicate_safetensors_files(
     with open(index_file_name) as f:
         weight_map = json.load(f)["weight_map"]
     weight_files_in_index = set()
+    # If the index file is inside a subfolder (e.g., 'llm/model.safetensors.index.json'),
+    # the shard paths in `weight_map` are relative to that subfolder. Use the
+    # index file's directory as the base for joining shard filenames.
+    base_dir = os.path.dirname(index_file_name)
     for weight_name in weight_map:
-        weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name]))
+        weight_files_in_index.add(os.path.join(base_dir, weight_map[weight_name]))
     # Filter out any fields that are not found in the index file.
     hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index]
     return hf_weights_files
```
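The practical effect of switching the join base from `hf_folder` to the index file's own directory, shown with illustrative paths (the cache directory here is an assumption, not a real layout):

```python
import os

hf_folder = "/cache/nvidia--omnivinci"  # illustrative download dir
index_file_name = os.path.join(hf_folder, "llm", "model.safetensors.index.json")
shard = "model-00001-of-00002.safetensors"  # as it appears in weight_map

# Before: joining against the repo root misses shards stored under llm/.
assert (
    os.path.join(hf_folder, shard)
    == "/cache/nvidia--omnivinci/model-00001-of-00002.safetensors"
)

# After: joining against the index file's directory yields the real path,
# so the downloaded llm/ shards survive the filtering step.
base_dir = os.path.dirname(index_file_name)
assert (
    os.path.join(base_dir, shard)
    == "/cache/nvidia--omnivinci/llm/model-00001-of-00002.safetensors"
)
```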
Lines changed: 23 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Thin wrapper to support nvidia/omnivinci LLM weights stored under llm/.

This model maps the root architecture (VILAForCausalLM) to the text-only
Qwen2 architecture by reusing vLLM's Qwen2ForCausalLM and ensures the weight
loader searches in the `llm/` subfolder of the repository.
"""

from vllm.config import VllmConfig
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM


class OmniVinciForCausalLM(Qwen2ForCausalLM):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # direct the default loader to read weights from the llm/ subfolder
        self.allow_patterns_overrides = [
            "llm/*.safetensors",
            "llm/consolidated*.safetensors",
            "llm/*.pt",
        ]
```
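With the wrapper in place and the registry entry below, loading should need nothing special from callers; a hedged usage sketch (untested here, and assuming the `nvidia/omnivinci` checkpoint layout the docstring describes):

```python
from vllm import LLM, SamplingParams

# The registry maps the repo's advertised VILAForCausalLM architecture to
# OmniVinciForCausalLM, which points the loader at the llm/ shards.
llm = LLM(model="nvidia/omnivinci")
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```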

vllm/model_executor/models/registry.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -166,6 +166,10 @@
     "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
     "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
     "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
+    # nvidia/omnivinci root config advertises VILAForCausalLM but the LLM
+    # component is Qwen2 with weights/config under the llm/ subfolder.
+    # Map it to a thin wrapper that reuses Qwen2 implementation.
+    "VILAForCausalLM": ("omnivinci", "OmniVinciForCausalLM"),
 }
 
 _EMBEDDING_MODELS = {
```
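For out-of-tree experiments that cannot edit this registry, vLLM's `ModelRegistry.register_model` accepts the same mapping as a `module:Class` string; a sketch, assuming the wrapper module above is importable as `vllm.model_executor.models.omnivinci`:

```python
from vllm import ModelRegistry

# Runtime equivalent of the in-tree registry entry above.
ModelRegistry.register_model(
    "VILAForCausalLM",
    "vllm.model_executor.models.omnivinci:OmniVinciForCausalLM",
)
```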
