diff --git a/ramalama/cli.py b/ramalama/cli.py index e551a5a3..aae46a8e 100644 --- a/ramalama/cli.py +++ b/ramalama/cli.py @@ -967,7 +967,7 @@ def rm_cli(args): def New(model, args): if model.startswith("huggingface://") or model.startswith("hf://") or model.startswith("hf.co/"): return Huggingface(model) - if model.startswith("ollama"): + if model.startswith("ollama") or "ollama.com/library/" in model: return Ollama(model) if model.startswith("oci://") or model.startswith("docker://"): return OCI(model, args.engine) diff --git a/ramalama/huggingface.py b/ramalama/huggingface.py index 675c76f4..b8eff02b 100644 --- a/ramalama/huggingface.py +++ b/ramalama/huggingface.py @@ -2,7 +2,7 @@ import pathlib import urllib.request from ramalama.common import available, run_cmd, exec_cmd, download_file, verify_checksum, perror -from ramalama.model import Model +from ramalama.model import Model, rm_until_substring missing_huggingface = """ Optional: Huggingface models require the huggingface-cli module. @@ -33,9 +33,8 @@ def fetch_checksum_from_api(url): class Huggingface(Model): def __init__(self, model): - model = model.removeprefix("huggingface://") - model = model.removeprefix("hf://") - model = model.removeprefix("hf.co/") + model = rm_until_substring(model, "hf.co/") + model = rm_until_substring(model, "://") super().__init__(model) self.type = "huggingface" split = self.model.rsplit("/", 1) diff --git a/ramalama/model.py b/ramalama/model.py index 94889238..cd714b4b 100644 --- a/ramalama/model.py +++ b/ramalama/model.py @@ -486,3 +486,12 @@ def distinfo_volume(): return "" return f"-v{path}:/usr/share/ramalama/{dist_info}:ro" + + +def rm_until_substring(model, substring): + pos = model.find(substring) + if pos == -1: + return model + + # Create a new string starting after the found substring + return ''.join(model[i] for i in range(pos + len(substring), len(model))) diff --git a/ramalama/oci.py b/ramalama/oci.py index 096b1062..1280c306 100644 --- a/ramalama/oci.py +++ b/ramalama/oci.py @@ -4,14 +4,8 @@ import tempfile import ramalama.annotations as annotations -from ramalama.model import Model, MODEL_TYPES -from ramalama.common import ( - engine_version, - exec_cmd, - MNT_FILE, - perror, - run_cmd, -) +from ramalama.model import Model, rm_until_substring +from ramalama.common import engine_version, exec_cmd, MNT_FILE, perror, run_cmd prefix = "oci://" @@ -103,10 +97,8 @@ def list_models(args): class OCI(Model): def __init__(self, model, conman): - super().__init__(model.removeprefix(prefix).removeprefix("docker://")) - for t in MODEL_TYPES: - if self.model.startswith(t + "://"): - raise ValueError(f"{model} invalid: Only OCI Model types supported") + model = rm_until_substring(model, "://") + super().__init__(model) self.type = "OCI" self.conman = conman diff --git a/ramalama/ollama.py b/ramalama/ollama.py index 4432cb4d..e2435124 100644 --- a/ramalama/ollama.py +++ b/ramalama/ollama.py @@ -2,7 +2,7 @@ import urllib.request import json from ramalama.common import run_cmd, verify_checksum, download_file -from ramalama.model import Model +from ramalama.model import Model, rm_until_substring def fetch_manifest_data(registry_head, model_tag, accept): @@ -60,7 +60,9 @@ def init_pull(repos, accept, registry_head, model_name, model_tag, models, model class Ollama(Model): def __init__(self, model): - super().__init__(model.removeprefix("ollama://")) + model = rm_until_substring(model, "ollama.com/library/") + model = rm_until_substring(model, "://") + super().__init__(model) self.type = "Ollama" def _local(self, args): diff --git a/ramalama/url.py b/ramalama/url.py index 8e11dfe4..119c0825 100644 --- a/ramalama/url.py +++ b/ramalama/url.py @@ -1,17 +1,13 @@ import os from ramalama.common import download_file -from ramalama.model import Model +from ramalama.model import Model, rm_until_substring +from urllib.parse import urlparse class URL(Model): def __init__(self, model): - self.type = "" - for prefix in ["file", "http", "https"]: - if model.startswith(f"{prefix}://"): - self.type = prefix - model = model.removeprefix(f"{prefix}://") - break - + self.type = urlparse(model).scheme + model = rm_until_substring(model, "://") super().__init__(model) split = self.model.rsplit("/", 1) self.directory = split[0].removeprefix("/") if len(split) > 1 else ""