diff --git a/mlx_lm/utils.py b/mlx_lm/utils.py index ef3d266b9..da498567e 100644 --- a/mlx_lm/utils.py +++ b/mlx_lm/utils.py @@ -458,6 +458,7 @@ def load( lazy: bool = False, return_config: bool = False, revision: Optional[str] = None, + strict: bool = True, ) -> Union[ Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]], @@ -478,6 +479,7 @@ def load( when needed. Default: ``False`` return_config (bool: If ``True`` return the model config as the last item.. revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash. + strict (bool): Whether to enforce strict weight checking. Default: ``True``. Returns: Union[Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]]]: A tuple containing the loaded model, tokenizer and, if requested, the model config. @@ -488,7 +490,7 @@ def load( """ model_path = _download(path_or_hf_repo, revision=revision) - model, config = load_model(model_path, lazy, model_config=model_config) + model, config = load_model(model_path, lazy, strict=strict, model_config=model_config) if adapter_path is not None: model = load_adapters(model, adapter_path) model.eval()