forked from bghira/SimpleTuner
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
263 lines (233 loc) · 10.6 KB
/
predict.py
File metadata and controls
263 lines (233 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
"""Cog predictor entrypoint using the SimpleTuner trainer directly."""
import json
import os
import pathlib
from typing import Any, Dict, List, Optional, Tuple
from cog import BasePredictor, Input, Path, Secret
# Pin both Hub endpoint variables to the public Hugging Face host at import time,
# before any Hub client reads them.
os.environ.update(
    HF_ENDPOINT="https://huggingface.co",
    HUGGINGFACE_HUB_ENDPOINT="https://huggingface.co",
)
class Predictor(BasePredictor):
    """Cog predictor that launches a SimpleTuner training run and reports where its output went."""

    def setup(self) -> None:
        """Initialise reusable runner state for Cog."""
        # Lazy import to avoid colored output during Cog introspection
        from simpletuner.cog import SimpleTunerCogRunner

        self.runner = SimpleTunerCogRunner()

    def _parse_json_or_path(
        self, value: str, param_name: str
    ) -> Tuple[Optional[pathlib.Path], Optional[Any]]:
        """Parse a string as either inline JSON or a file path.

        Args:
            value: Raw user input; surrounding whitespace is ignored.
            param_name: Input name used in error messages.

        Returns:
            ``(path, None)`` when ``value`` names an existing path,
            ``(None, parsed)`` when ``value`` is inline JSON (an object or an array),
            or ``(None, None)`` when ``value`` is blank, i.e. nothing was supplied.

        Raises:
            ValueError: ``value`` starts like JSON (``{`` or ``[``) but fails to parse.
            FileNotFoundError: ``value`` is not JSON and no such path exists.
        """
        value = value.strip()
        if not value:
            # pathlib.Path("") is "." which always exists; without this guard a
            # whitespace-only input would be handed on as if it were a config file.
            return None, None
        # Anything that starts like a JSON object/array is treated as inline JSON.
        if value.startswith(("{", "[")):
            try:
                return None, json.loads(value)
            except json.JSONDecodeError as e:
                raise ValueError(f"{param_name} looks like JSON but failed to parse: {e}") from e
        # Otherwise treat it as a file path.
        path = pathlib.Path(value)
        if not path.exists():
            raise FileNotFoundError(f"{param_name} file not found: {value}")
        return path, None

    def _normalize_hub_model_id(self, hub_model_id: str) -> str:
        """Reduce a Hugging Face Hub reference to a bare repo ID.

        Accepts a plain repo ID or a URL (including one where a proxy has put
        ``huggingface.co/`` into the path) and removes the scheme and host, query
        string, fragment, any ``huggingface.co`` prefix, and leading/trailing
        slashes. Blank input yields an empty string.
        """
        from urllib.parse import urlparse

        cleaned = hub_model_id.strip()
        if not cleaned:
            return cleaned
        parsed = urlparse(cleaned)
        if parsed.scheme and parsed.netloc:
            # Full URL: keep only the path component.
            cleaned = parsed.path
        cleaned = cleaned.split("?", 1)[0].split("#", 1)[0].lstrip("/")
        if "huggingface.co/" in cleaned:
            cleaned = cleaned.split("huggingface.co/", 1)[1]
        if cleaned.startswith("huggingface.co"):
            cleaned = cleaned[len("huggingface.co") :].lstrip("/")
        # A trailing slash (e.g. copied from a browser address bar) is not valid in a repo ID.
        return cleaned.strip("/")

    def _resolve_lycoris_config(self, lycoris_config: Optional[str]) -> Optional[pathlib.Path]:
        """Return a path to the LyCORIS config, writing inline JSON to disk when needed.

        Returns None when no (or only a blank) config was supplied.
        """
        if not lycoris_config:
            return None
        lycoris_path, lycoris_dict = self._parse_json_or_path(lycoris_config, "lycoris_config")
        if lycoris_dict is None:
            return lycoris_path
        # Inline JSON: persist it so it can be passed on as a --lycoris_config file path.
        lycoris_config_path = pathlib.Path("config") / "cog" / "lycoris_config.json"
        lycoris_config_path.parent.mkdir(parents=True, exist_ok=True)
        with lycoris_config_path.open("w", encoding="utf-8") as handle:
            json.dump(lycoris_dict, handle, indent=2)
        return lycoris_config_path

    def predict(
        self,
        images: Path = Input(
            description="Zip or tar archive of training images. Not required if dataloader_json points to external data.",
            default=None,
        ),
        config_json: str = Input(
            description="Training config: either a JSON string or path to config.json. Defaults to config/config.json if present.",
            default=None,
        ),
        dataloader_json: str = Input(
            description="Multidatabackend config: either a JSON string or path to file. If not provided, auto-generated from images.",
            default=None,
        ),
        max_train_steps: int = Input(
            description="Override --max_train_steps for quicker Cog runs.",
            default=None,
        ),
        # S3 publishing options (override config)
        s3_bucket: Optional[str] = Input(
            description="S3-compatible bucket for publishing checkpoints (overrides config).",
            default=None,
        ),
        s3_region: Optional[str] = Input(
            description="S3 region (optional).",
            default=None,
        ),
        s3_endpoint_url: Optional[str] = Input(
            description="Custom S3 endpoint URL (for non-AWS providers like Backblaze B2, Cloudflare R2).",
            default=None,
        ),
        s3_base_path: Optional[str] = Input(
            description="Prefix inside the bucket (defaults to simpletuner/{job_id}).",
            default=None,
        ),
        s3_public_base_url: Optional[str] = Input(
            description="Public base URL to build shareable links (optional).",
            default=None,
        ),
        s3_access_key: Optional[Secret] = Input(
            description="S3 access key (leave blank to use IAM/instance roles).",
            default=None,
        ),
        s3_secret_key: Optional[Secret] = Input(
            description="S3 secret key (leave blank to use IAM/instance roles).",
            default=None,
        ),
        # HuggingFace Hub publishing options (override config)
        hub_model_id: Optional[str] = Input(
            description="HuggingFace Hub repo ID (e.g., 'username/my-lora') - overrides config.",
            default=None,
        ),
        hf_token: Secret = Input(
            description="Hugging Face token for model downloads and Hub publishing.",
            default=None,
        ),
        lycoris_config: str = Input(
            description="LyCORIS config: either a JSON string or path to lycoris_config.json. Required when lora_type is 'lycoris'.",
            default=None,
        ),
        return_logs: bool = Input(
            description="Print the tail of debug.log to Cog output.",
            default=True,
        ),
    ) -> str:
        """Launch a SimpleTuner training job and return the output location.

        Returns:
            An S3 location when ``s3_bucket`` is given (built from the public base
            URL, the custom endpoint, or ``s3://``), a Hugging Face Hub URL when
            ``hub_model_id`` is given, otherwise a message naming the run's
            ``output_dir``.

        Raises:
            ValueError: ``hub_model_id`` is given without ``hf_token``, it reduces
                to an empty repo ID, or an inline JSON input fails to parse.
            FileNotFoundError: a config input is neither JSON nor an existing path.
        """
        token_value = hf_token.get_secret_value() if hf_token else None
        dataset_archive = pathlib.Path(images) if images else None

        # config_json / dataloader_json may each be inline JSON or a file path.
        config_path, config_dict = (
            self._parse_json_or_path(config_json, "config_json") if config_json else (None, None)
        )
        dataloader_path, dataloader_dict = (
            self._parse_json_or_path(dataloader_json, "dataloader_json")
            if dataloader_json
            else (None, None)
        )
        lycoris_config_path = self._resolve_lycoris_config(lycoris_config)

        # Overrides passed to the runner on top of the base config.
        config_overrides: Dict[str, Any] = {}
        if lycoris_config_path:
            config_overrides["--lycoris_config"] = str(lycoris_config_path)

        # S3 publishing config
        publishing_config: Optional[List[Dict[str, Any]]] = None
        if s3_bucket:
            s3_access_value = s3_access_key.get_secret_value() if s3_access_key else None
            s3_secret_value = s3_secret_key.get_secret_value() if s3_secret_key else None
            publishing_config = self._build_s3_publishing_config(
                bucket=s3_bucket,
                base_path=s3_base_path,
                region=s3_region,
                endpoint_url=s3_endpoint_url,
                access_key=s3_access_value,
                secret_key=s3_secret_value,
                public_base_url=s3_public_base_url,
            )
            config_overrides["publishing_config"] = publishing_config

        # HuggingFace Hub publishing config
        clean_hub_model_id: Optional[str] = None
        if hub_model_id:
            if not token_value:
                raise ValueError("hf_token is required when using hub_model_id for HuggingFace Hub publishing.")
            # Strip any URL prefix from hub_model_id (Replicate's proxy can add prefixes)
            clean_hub_model_id = self._normalize_hub_model_id(hub_model_id)
            if not clean_hub_model_id:
                # Otherwise --push_to_hub would be enabled with an empty repo ID.
                raise ValueError(f"hub_model_id {hub_model_id!r} does not contain a usable repo ID.")
            # Override HF_ENDPOINT to bypass Replicate's proxy for Hub uploads
            os.environ["HF_ENDPOINT"] = "https://huggingface.co"
            config_overrides["--push_to_hub"] = True
            config_overrides["--hub_model_id"] = clean_hub_model_id
            config_overrides["--push_checkpoints_to_hub"] = True

        # Start the webhook receiver to capture training events in Cog logs
        from simpletuner.cog import CogWebhookReceiver

        with CogWebhookReceiver() as webhook_receiver:
            webhook_config = [CogWebhookReceiver.build_webhook_config(webhook_receiver.url)]
            run_result = self.runner.run(
                dataset_archive=dataset_archive,
                hf_token=token_value,
                base_config_path=config_path,
                base_config_dict=config_dict,
                dataloader_config_path=dataloader_path,
                dataloader_config_dict=dataloader_dict,
                config_overrides=config_overrides,
                max_train_steps=max_train_steps,
                webhook_config=webhook_config,
            )

        if return_logs:
            log_tail = self.runner.read_debug_log()
            if log_tail:
                print("\n=== debug.log tail ===")
                print(log_tail[-5000:])

        # Build output URL based on publishing destination
        if s3_bucket and publishing_config:
            path_prefix = publishing_config[0].get("base_path", "").lstrip("/")
            if s3_public_base_url:
                output_url = f"{s3_public_base_url.rstrip('/')}/{path_prefix}"
            elif s3_endpoint_url:
                output_url = f"{s3_endpoint_url.rstrip('/')}/{s3_bucket}/{path_prefix}"
            else:
                output_url = f"s3://{s3_bucket}/{path_prefix}"
            print(f"\nCheckpoints published to: {output_url}")
        elif clean_hub_model_id:
            output_url = f"https://huggingface.co/{clean_hub_model_id}"
            print(f"\nModel published to: {output_url}")
        else:
            # No publishing destination was supplied as an input; report the local output directory.
            output_url = f"Training complete. Output: {run_result['output_dir']}"
            print(f"\n{output_url}")
        return output_url

    @staticmethod
    def _build_s3_publishing_config(
        *,
        bucket: str,
        base_path: Optional[str] = None,
        region: Optional[str] = None,
        endpoint_url: Optional[str] = None,
        access_key: Optional[str] = None,
        secret_key: Optional[str] = None,
        public_base_url: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Build S3 publishing config for SimpleTuner.

        Returns a one-element list holding a ``provider="s3"`` entry. ``base_path``
        falls back to ``"simpletuner"``; every other optional key is included only
        when a truthy value was supplied.
        """
        entry: Dict[str, Any] = {
            "provider": "s3",
            "bucket": bucket,
            "base_path": base_path or "simpletuner",
        }
        if region:
            entry["region"] = region
        if endpoint_url:
            entry["endpoint_url"] = endpoint_url
        if access_key:
            entry["access_key"] = access_key
        if secret_key:
            entry["secret_key"] = secret_key
        if public_base_url:
            entry["public_base_url"] = public_base_url
        return [entry]