diff --git a/ramalama/ollama.py b/ramalama/ollama.py index e2435124..b80a8922 100644 --- a/ramalama/ollama.py +++ b/ramalama/ollama.py @@ -1,7 +1,7 @@ import os import urllib.request import json -from ramalama.common import run_cmd, verify_checksum, download_file +from ramalama.common import run_cmd, verify_checksum, download_file, available from ramalama.model import Model, rm_until_substring @@ -30,15 +30,18 @@ def pull_blob(repos, layer_digest, accept, registry_head, models, model_name, mo layer_blob_path = os.path.join(repos, "blobs", layer_digest) url = f"{registry_head}/blobs/{layer_digest}" headers = {"Accept": accept} - download_file(url, layer_blob_path, headers=headers, show_progress=True) - - # Verify checksum after downloading the blob - if not verify_checksum(layer_blob_path): - print(f"Checksum mismatch for blob {layer_blob_path}, retrying download...") - os.remove(layer_blob_path) + local_blob = in_existing_cache(model_name, model_tag) + if local_blob is not None: + run_cmd(["ln", "-sf", local_blob, layer_blob_path]) + else: download_file(url, layer_blob_path, headers=headers, show_progress=True) + # Verify checksum after downloading the blob if not verify_checksum(layer_blob_path): - raise ValueError(f"Checksum verification failed for blob {layer_blob_path}") + print(f"Checksum mismatch for blob {layer_blob_path}, retrying download...") + os.remove(layer_blob_path) + download_file(url, layer_blob_path, headers=headers, show_progress=True) + if not verify_checksum(layer_blob_path): + raise ValueError(f"Checksum verification failed for blob {layer_blob_path}") os.makedirs(models, exist_ok=True) relative_target_path = os.path.relpath(layer_blob_path, start=os.path.dirname(model_path)) @@ -58,6 +61,28 @@ def init_pull(repos, accept, registry_head, model_name, model_tag, models, model return model_path +def in_existing_cache(model_name, model_tag): + if not available("ollama"): + return None + default_ollama_caches=[ + os.path.join(os.environ['HOME'], '.ollama/models'), + '/usr/share/ollama/.ollama/models', + f'C:\\Users\\{os.getlogin()}\\.ollama\\models' + ] + + for cache_dir in default_ollama_caches: + manifest_path = os.path.join(cache_dir, 'manifests', 'registry.ollama.ai', model_name, model_tag) + if os.path.exists(manifest_path): + with open(manifest_path, 'r') as file: + manifest_data = json.load(file) + for layer in manifest_data["layers"]: + if layer["mediaType"] == "application/vnd.ollama.image.model": + layer_digest = layer["digest"] + ollama_digest_path = os.path.join(cache_dir, 'blobs', layer_digest) + if os.path.exists(str(ollama_digest_path).replace(':','-')): + return str(ollama_digest_path).replace(':','-') + return None + class Ollama(Model): def __init__(self, model): model = rm_until_substring(model, "ollama.com/library/")