@@ -96,14 +96,27 @@ def _get_weight_files(
9696 is_local = os .path .isdir (model_name_or_path )
9797
9898 if is_local :
99- for pattern in allowed_patterns :
99+ patterns = list (allowed_patterns )
100+ # Prefer subfolder patterns if common subfolder exists locally.
101+ if os .path .isdir (os .path .join (model_name_or_path , "llm" )):
102+ patterns = [f"llm/{ p } " for p in allowed_patterns ] + patterns
103+ for pattern in patterns :
100104 weight_files = glob .glob (os .path .join (model_name_or_path , pattern ))
101105 if weight_files :
102106 return model_name_or_path , weight_files , pattern
103107 else :
104108 hf_api = HfApi ()
105109 repo_files = hf_api .list_repo_files (repo_id = model_name_or_path )
106- for pattern in allowed_patterns :
110+ search_patterns = list (allowed_patterns )
111+ # Prefer 'llm/' weights when present in the repo.
112+ if any (
113+ f .startswith ("llm/" ) and f .endswith ((".safetensors" , ".bin" , ".pt" ))
114+ for f in repo_files
115+ ):
116+ search_patterns = [
117+ f"llm/{ p } " for p in allowed_patterns
118+ ] + search_patterns
119+ for pattern in search_patterns :
107120 matching_files = fnmatch .filter (repo_files , pattern )
108121 if matching_files :
109122 hf_folder = download_weights_from_hf (
@@ -128,26 +141,35 @@ def _prepare_weights(
128141
129142 allowed_patterns = ["*.safetensors" , "*.bin" , "*.pt" ]
130143
144+ if getattr (self , "allow_patterns_overrides" , None ):
145+ allowed_patterns = list (self .allow_patterns_overrides )
146+
131147 hf_folder , hf_weights_files , matched_pattern = self ._get_weight_files (
132148 model_name_or_path , allowed_patterns , revision
133149 )
134150
135- use_safetensors = matched_pattern == "*.safetensors"
151+ # Detect safetensors robustly (pattern may include subfolder)
152+ use_safetensors = matched_pattern .endswith (".safetensors" )
153+ # Additionally guard by checking actual files
154+ if not use_safetensors :
155+ use_safetensors = any (f .endswith (".safetensors" ) for f in hf_weights_files )
136156 is_local = os .path .isdir (model_name_or_path )
137- index_file = SAFE_WEIGHTS_INDEX_NAME
157+ # If weights live under a subfolder (e.g., 'llm/*.safetensors'),
158+ # the index file will also live there.
159+ if "/" in matched_pattern :
160+ folder_prefix = matched_pattern .rsplit ("/" , 1 )[0 ] + "/"
161+ else :
162+ folder_prefix = ""
163+ index_file = folder_prefix + SAFE_WEIGHTS_INDEX_NAME
164+ if use_safetensors and not is_local :
165+ # Download index for safetensors to select correct shards.
166+ download_safetensors_index_file_from_hf (
167+ model_name_or_path ,
168+ index_file ,
169+ self .load_config .download_dir ,
170+ revision ,
171+ )
138172 if use_safetensors :
139- # For models like Mistral-7B-Instruct-v0.3
140- # there are both sharded safetensors files and a consolidated
141- # safetensors file. Using both breaks.
142- # Here, we download the `model.safetensors.index.json` and filter
143- # any files not found in the index.
144- if not is_local :
145- download_safetensors_index_file_from_hf (
146- model_name_or_path ,
147- index_file ,
148- self .load_config .download_dir ,
149- revision ,
150- )
151173 hf_weights_files = filter_duplicate_safetensors_files (
152174 hf_weights_files , hf_folder , index_file
153175 )
@@ -587,6 +609,8 @@ def _initialize_loader_state(
587609 self ._get_bnb_target_modules (model )
588610 self ._classify_module_sharding (model )
589611
612+ self .allow_patterns_overrides = getattr (model , "allow_patterns_overrides" , None )
613+
590614 def _dequantize_dq (self , quant_states : Any ):
591615 """
592616 When BNB employs Double Quantization, we perform the dequantization of
0 commit comments