Skip to content

Commit

Permalink
Merge pull request #782 from kush-gupt/main
Browse files Browse the repository at this point in the history
Reuse Ollama cached image when available
  • Loading branch information
ericcurtin authored Feb 11, 2025
2 parents 56ecf60 + 1428cd8 commit 6f0783c
Showing 1 changed file with 33 additions and 8 deletions.
41 changes: 33 additions & 8 deletions ramalama/ollama.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import urllib.request
import json
from ramalama.common import run_cmd, verify_checksum, download_file
from ramalama.common import run_cmd, verify_checksum, download_file, available
from ramalama.model import Model, rm_until_substring


Expand Down Expand Up @@ -30,15 +30,18 @@ def pull_blob(repos, layer_digest, accept, registry_head, models, model_name, mo
layer_blob_path = os.path.join(repos, "blobs", layer_digest)
url = f"{registry_head}/blobs/{layer_digest}"
headers = {"Accept": accept}
download_file(url, layer_blob_path, headers=headers, show_progress=True)

# Verify checksum after downloading the blob
if not verify_checksum(layer_blob_path):
print(f"Checksum mismatch for blob {layer_blob_path}, retrying download...")
os.remove(layer_blob_path)
local_blob = in_existing_cache(model_name, model_tag)
if local_blob is not None:
run_cmd(["ln", "-sf", local_blob, layer_blob_path])
else:
download_file(url, layer_blob_path, headers=headers, show_progress=True)
# Verify checksum after downloading the blob
if not verify_checksum(layer_blob_path):
raise ValueError(f"Checksum verification failed for blob {layer_blob_path}")
print(f"Checksum mismatch for blob {layer_blob_path}, retrying download...")
os.remove(layer_blob_path)
download_file(url, layer_blob_path, headers=headers, show_progress=True)
if not verify_checksum(layer_blob_path):
raise ValueError(f"Checksum verification failed for blob {layer_blob_path}")

os.makedirs(models, exist_ok=True)
relative_target_path = os.path.relpath(layer_blob_path, start=os.path.dirname(model_path))
Expand All @@ -58,6 +61,28 @@ def init_pull(repos, accept, registry_head, model_name, model_tag, models, model
return model_path


def in_existing_cache(model_name, model_tag):
    """Return the path to a locally cached Ollama model blob, or None.

    Searches the default Ollama cache locations for a manifest matching
    model_name:model_tag and, when one is found, returns the on-disk path
    of the corresponding model-layer blob (Ollama stores blob files with
    ':' in the digest replaced by '-'). Returns None when the ollama
    binary is not available, no manifest is found, or the blob file does
    not exist.
    """
    if not available("ollama"):
        return None

    # expanduser('~') instead of os.environ['HOME']: HOME may be unset
    # (e.g. on Windows or in minimal service environments) and indexing
    # os.environ would raise KeyError instead of returning None.
    default_ollama_caches = [
        os.path.join(os.path.expanduser('~'), '.ollama/models'),
        '/usr/share/ollama/.ollama/models',
    ]
    # os.getlogin() raises OSError when there is no controlling terminal
    # (cron jobs, containers); skip the Windows default path in that case
    # rather than crash.
    try:
        default_ollama_caches.append(f'C:\\Users\\{os.getlogin()}\\.ollama\\models')
    except OSError:
        pass

    for cache_dir in default_ollama_caches:
        manifest_path = os.path.join(cache_dir, 'manifests', 'registry.ollama.ai', model_name, model_tag)
        if not os.path.exists(manifest_path):
            continue
        with open(manifest_path, 'r') as file:
            manifest_data = json.load(file)
        for layer in manifest_data["layers"]:
            if layer["mediaType"] != "application/vnd.ollama.image.model":
                continue
            # Replace ':' only in the digest component; replacing it in
            # the full path would corrupt a Windows drive letter ("C:").
            blob_file = layer["digest"].replace(':', '-')
            ollama_digest_path = os.path.join(cache_dir, 'blobs', blob_file)
            if os.path.exists(ollama_digest_path):
                return ollama_digest_path
    return None

class Ollama(Model):
def __init__(self, model):
model = rm_until_substring(model, "ollama.com/library/")
Expand Down

0 comments on commit 6f0783c

Please sign in to comment.