diff --git a/ramalama/cli.py b/ramalama/cli.py index 0b3577b4..e853a35a 100644 --- a/ramalama/cli.py +++ b/ramalama/cli.py @@ -857,14 +857,14 @@ def rm_cli(args): def New(model, args): - if model.startswith("huggingface://") or model.startswith("hf://") or model.startswith("hf.co/"): + if model.startswith("hf://") or model.startswith("huggingface://") or model.startswith("hf.co/"): return Huggingface(model) - if model.startswith("ollama"): - return Ollama(model) - if model.startswith("oci://") or model.startswith("docker://"): - return OCI(model, args.engine) - if model.startswith("http://") or model.startswith("https://") or model.startswith("file://"): + elif ( + model.startswith("https://") or model.startswith("http://") or model.startswith("file://") + ) and not model.startswith("https://ollama.com/library/"): return URL(model) + elif model.startswith("oci://") or model.startswith("docker://"): + return OCI(model, args.engine) transport = config.get("transport", "ollama") if transport == "huggingface": diff --git a/ramalama/common.py b/ramalama/common.py index 7f4f2e4e..88a002f1 100644 --- a/ramalama/common.py +++ b/ramalama/common.py @@ -243,3 +243,12 @@ def get_env_vars(): # env_vars[gpu_type] = str(gpu_num) return env_vars + + +def rm_until_substring(model, substring): + pos = model.find(substring) + if pos == -1: + return model + + # Create a new string starting after the found substring + return ''.join(model[i] for i in range(pos + len(substring), len(model))) diff --git a/ramalama/huggingface.py b/ramalama/huggingface.py index 184a107d..1212ccb4 100644 --- a/ramalama/huggingface.py +++ b/ramalama/huggingface.py @@ -1,7 +1,7 @@ import os import pathlib import urllib.request -from ramalama.common import available, run_cmd, exec_cmd, download_file, verify_checksum, perror +from ramalama.common import available, run_cmd, exec_cmd, download_file, verify_checksum, perror, rm_until_substring from ramalama.model import Model missing_huggingface = """ @@ -33,9 +33,8 @@ def fetch_checksum_from_api(url): class Huggingface(Model): def __init__(self, model): - model = model.removeprefix("huggingface://") - model = model.removeprefix("hf://") - model = model.removeprefix("hf.co/") + model = rm_until_substring(model, "hf.co/") + model = rm_until_substring(model, "://") super().__init__(model) self.type = "huggingface" split = self.model.rsplit("/", 1) diff --git a/ramalama/oci.py b/ramalama/oci.py index 6f5e50ff..f83cfeaf 100644 --- a/ramalama/oci.py +++ b/ramalama/oci.py @@ -4,14 +4,8 @@ import tempfile import ramalama.annotations as annotations -from ramalama.model import Model, MODEL_TYPES -from ramalama.common import ( - engine_version, - exec_cmd, - MNT_FILE, - perror, - run_cmd, -) +from ramalama.model import Model +from ramalama.common import engine_version, exec_cmd, MNT_FILE, perror, run_cmd, rm_until_substring prefix = "oci://" @@ -103,10 +97,8 @@ def list_models(args): class OCI(Model): def __init__(self, model, conman): - super().__init__(model.removeprefix(prefix).removeprefix("docker://")) - for t in MODEL_TYPES: - if self.model.startswith(t + "://"): - raise ValueError(f"{model} invalid: Only OCI Model types supported") + model = rm_until_substring(model, "://") + super().__init__(model) self.type = "OCI" self.conman = conman diff --git a/ramalama/ollama.py b/ramalama/ollama.py index 32aa4565..8713d1cd 100644 --- a/ramalama/ollama.py +++ b/ramalama/ollama.py @@ -1,7 +1,7 @@ import os import urllib.request import json -from ramalama.common import run_cmd, verify_checksum, download_file +from ramalama.common import run_cmd, verify_checksum, download_file, rm_until_substring from ramalama.model import Model @@ -60,7 +60,9 @@ def init_pull(repos, accept, registry_head, model_name, model_tag, models, model class Ollama(Model): def __init__(self, model): - super().__init__(model.removeprefix("ollama://")) + model = rm_until_substring(model, "ollama.com/library/") + model = rm_until_substring(model, "://") + super().__init__(model) self.type = "Ollama" def _local(self, args): diff --git a/ramalama/url.py b/ramalama/url.py index 8e11dfe4..b2769325 100644 --- a/ramalama/url.py +++ b/ramalama/url.py @@ -1,16 +1,12 @@ import os -from ramalama.common import download_file +from ramalama.common import download_file, rm_until_substring from ramalama.model import Model class URL(Model): def __init__(self, model): self.type = "" - for prefix in ["file", "http", "https"]: - if model.startswith(f"{prefix}://"): - self.type = prefix - model = model.removeprefix(f"{prefix}://") - break + model = rm_until_substring(model, "://") super().__init__(model) split = self.model.rsplit("/", 1)