diff --git a/torchchat/generate.py b/torchchat/generate.py
index ed1b27fa6..5ae7ecfad 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -869,13 +869,6 @@ def _gen_model_input(
     max_new_tokens: Optional[int] = None,
     max_seq_len: Optional[int] = 2048,
 ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
-    # torchtune model definition dependencies
-    from torchtune.data import Message, padded_collate_tiled_images_and_mask
-    from torchtune.models.llama3_2_vision._model_builders import (
-        llama3_2_vision_transform,
-    )
-    from torchtune.training import set_default_dtype
-
     """
     Convert prompt and image prompts into consumable model input args.
 
@@ -911,6 +904,14 @@
         return encoded, None
 
     # Llama 3.2 11B
+
+    # torchtune model definition dependencies
+    from torchtune.data import Message, padded_collate_tiled_images_and_mask
+    from torchtune.models.llama3_2_vision._model_builders import (
+        llama3_2_vision_transform,
+    )
+    from torchtune.training import set_default_dtype
+
     assert (
         image_prompts is None or len(image_prompts) == 1
     ), "At most one image is supported at the moment"
diff --git a/torchchat/model.py b/torchchat/model.py
index 4605aea33..da3cc1dd7 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -1054,13 +1054,13 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 try:
-    # For llama::sdpa_with_kv_cache.out, preprocess ops
-    from executorch.extension.llm.custom_ops import custom_ops  # no-qa
     from executorch.extension.pybindings import portable_lib as exec_lib
 
     # ET changed the way it's loading the custom ops so it's not included in portable_lib but has to be loaded separately.
     # For quantized_decomposed ops
     from executorch.kernels import quantized  # no-qa
+    # For llama::sdpa_with_kv_cache.out, preprocess ops
+    from executorch.extension.llm.custom_ops import custom_ops  # no-qa
 
     class PTEModel(nn.Module):
         def __init__(self, config, path) -> None:
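
Note on the generate.py hunk: the torchtune imports are moved from the top of `_gen_model_input` to below the text-only early return, so the Llama 3.2 vision dependencies are only imported when an image prompt is actually supplied, and a plain text run no longer requires torchtune to be installed. A minimal sketch of this deferred-import pattern follows; the function name and the tokenization stand-in are illustrative, not torchchat's real API:

```python
from typing import Any, List, Optional


def gen_model_input(prompt: str, image_prompts: Optional[List[Any]] = None):
    # Text-only fast path: return before any optional dependency is touched,
    # so an install without torchtune can still serve text models.
    if image_prompts is None:
        return list(prompt.encode("utf-8"))  # stand-in for real tokenization

    # Multimodal path: import the heavy optional dependency only on the
    # branch that needs it, mirroring the patch above.
    from torchtune.data import Message

    return [Message(role="user", content=prompt)], image_prompts
```

The import cost, and the hard dependency, is paid only when the multimodal branch executes, which is why the imports must sit below the text-only `return encoded, None`.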
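
Note on the model.py hunk: it only reorders the guarded ExecuTorch imports so that `custom_ops` loads after `portable_lib` and the quantized kernels, consistent with the in-code comment that the custom ops are no longer bundled into `portable_lib`. A sketch of the guarded optional-backend pattern this block follows; the `ET_AVAILABLE` flag and the except clause are illustrative stand-ins, not torchchat's actual code:

```python
try:
    # Order as in the patch: pybindings runtime first, then the
    # quantized_decomposed kernels, and only then the LLM custom ops
    # (llama::sdpa_with_kv_cache.out, preprocess ops) -- presumably so the
    # runtime is in place before the separately-loaded ops register.
    from executorch.extension.pybindings import portable_lib as exec_lib
    from executorch.kernels import quantized  # no-qa
    from executorch.extension.llm.custom_ops import custom_ops  # no-qa

    ET_AVAILABLE = True
except ImportError:
    # ExecuTorch is an optional backend; without it, the PTE model path
    # is simply unavailable and the rest of the package keeps working.
    ET_AVAILABLE = False
```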