|
10 | 10 | import oci |
11 | 11 | from cachetools import TTLCache |
12 | 12 | from huggingface_hub import snapshot_download |
13 | | -from oci.data_science.models import JobRun, Model |
| 13 | +from oci.data_science.models import JobRun, Metadata, Model, UpdateModelDetails |
14 | 14 |
|
15 | 15 | from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger |
16 | 16 | from ads.aqua.app import AquaApp |
17 | | -from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags |
| 17 | +from ads.aqua.common.enums import ( |
| 18 | + FineTuningContainerTypeFamily, |
| 19 | + InferenceContainerTypeFamily, |
| 20 | + Tags, |
| 21 | +) |
18 | 22 | from ads.aqua.common.errors import AquaRuntimeError, AquaValueError |
19 | 23 | from ads.aqua.common.utils import ( |
20 | 24 | LifecycleStatus, |
|
23 | 27 | create_word_icon, |
24 | 28 | generate_tei_cmd_var, |
25 | 29 | get_artifact_path, |
| 30 | + get_container_config, |
26 | 31 | get_hf_model_info, |
27 | 32 | list_os_files_with_extension, |
28 | 33 | load_config, |
|
78 | 83 | TENANCY_OCID, |
79 | 84 | ) |
80 | 85 | from ads.model import DataScienceModel |
81 | | -from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem |
| 86 | +from ads.model.model_metadata import ( |
| 87 | + MetadataCustomCategory, |
| 88 | + ModelCustomMetadata, |
| 89 | + ModelCustomMetadataItem, |
| 90 | +) |
82 | 91 | from ads.telemetry import telemetry |
83 | 92 |
|
84 | 93 |
|
@@ -333,6 +342,96 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod |
333 | 342 |
|
334 | 343 | return model_details |
335 | 344 |
|
| 345 | + @telemetry(entry_point="plugin=model&action=delete", name="aqua") |
| 346 | + def delete_model(self, model_id): |
| 347 | + ds_model = DataScienceModel.from_id(model_id) |
| 348 | + is_registered_model = ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) |
| 349 | + is_fine_tuned_model = ds_model.freeform_tags.get( |
| 350 | + Tags.AQUA_FINE_TUNED_MODEL_TAG, None |
| 351 | + ) |
| 352 | + if is_registered_model or is_fine_tuned_model: |
| 353 | + return ds_model.delete() |
| 354 | + else: |
| 355 | + raise AquaRuntimeError( |
| 356 | + f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted." |
| 357 | + ) |
| 358 | + |
| 359 | + @telemetry(entry_point="plugin=model&action=delete", name="aqua") |
| 360 | + def edit_registered_model(self, id, inference_container, enable_finetuning, task): |
| 361 | + """Edits the default config of unverified registered model. |
| 362 | +
|
| 363 | + Parameters |
| 364 | + ---------- |
| 365 | + id: str |
| 366 | + The model OCID. |
| 367 | + inference_container: str. |
| 368 | + The inference container family name |
| 369 | + enable_finetuning: str |
| 370 | + Flag to enable or disable finetuning over the model. Defaults to None |
| 371 | + task: |
| 372 | + The usecase type of the model. e.g , text-generation , text_embedding etc. |
| 373 | +
|
| 374 | + Returns |
| 375 | + ------- |
| 376 | + Model: |
| 377 | + The instance of oci.data_science.models.Model. |
| 378 | +
|
| 379 | + """ |
| 380 | + ds_model = DataScienceModel.from_id(id) |
| 381 | + if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None): |
| 382 | + if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None): |
| 383 | + raise AquaRuntimeError( |
| 384 | + f"Failed to edit model:{id}. Only registered unverified models can be edited." |
| 385 | + ) |
| 386 | + else: |
| 387 | + custom_metadata_list = ds_model.custom_metadata_list |
| 388 | + freeform_tags = ds_model.freeform_tags |
| 389 | + if inference_container: |
| 390 | + custom_metadata_list.add( |
| 391 | + key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, |
| 392 | + value=inference_container, |
| 393 | + category=MetadataCustomCategory.OTHER, |
| 394 | + description="Deployment container mapping for SMC", |
| 395 | + replace=True, |
| 396 | + ) |
| 397 | + if enable_finetuning is not None: |
| 398 | + if enable_finetuning.lower() == "true": |
| 399 | + custom_metadata_list.add( |
| 400 | + key=ModelCustomMetadataFields.FINETUNE_CONTAINER, |
| 401 | + value=FineTuningContainerTypeFamily.AQUA_FINETUNING_CONTAINER_FAMILY, |
| 402 | + category=MetadataCustomCategory.OTHER, |
| 403 | + description="Fine-tuning container mapping for SMC", |
| 404 | + replace=True, |
| 405 | + ) |
| 406 | + freeform_tags.update({Tags.READY_TO_FINE_TUNE: "true"}) |
| 407 | + elif enable_finetuning.lower() == "false": |
| 408 | + try: |
| 409 | + custom_metadata_list.remove( |
| 410 | + ModelCustomMetadataFields.FINETUNE_CONTAINER |
| 411 | + ) |
| 412 | + freeform_tags.pop(Tags.READY_TO_FINE_TUNE) |
| 413 | + except Exception as ex: |
| 414 | + raise AquaRuntimeError( |
| 415 | + f"The given model already doesn't support finetuning: {ex}" |
| 416 | + ) |
| 417 | + |
| 418 | + custom_metadata_list.remove("modelDescription") |
| 419 | + if task: |
| 420 | + freeform_tags.update({Tags.TASK: task}) |
| 421 | + updated_custom_metadata_list = [ |
| 422 | + Metadata(**metadata) |
| 423 | + for metadata in custom_metadata_list.to_dict()["data"] |
| 424 | + ] |
| 425 | + update_model_details = UpdateModelDetails( |
| 426 | + custom_metadata_list=updated_custom_metadata_list, |
| 427 | + freeform_tags=freeform_tags, |
| 428 | + ) |
| 429 | + AquaApp().update_model(id, update_model_details) |
| 430 | + else: |
| 431 | + raise AquaRuntimeError( |
| 432 | + f"Failed to edit model:{id}. Only registered unverified models can be edited." |
| 433 | + ) |
| 434 | + |
336 | 435 | def _fetch_metric_from_metadata( |
337 | 436 | self, |
338 | 437 | custom_metadata_list: ModelCustomMetadata, |
@@ -629,6 +728,32 @@ def clear_model_list_cache( |
629 | 728 | } |
630 | 729 | return res |
631 | 730 |
|
| 731 | + def clear_model_details_cache(self, model_id): |
| 732 | + """ |
| 733 | + Allows user to clear model details cache item |
| 734 | + Returns |
| 735 | + ------- |
| 736 | + dict with the key used, and True if cache has the key that needs to be deleted. |
| 737 | + """ |
| 738 | + res = {} |
| 739 | + logger.info(f"Clearing _service_model_details_cache for {model_id}") |
| 740 | + with self._cache_lock: |
| 741 | + if model_id in self._service_model_details_cache: |
| 742 | + self._service_model_details_cache.pop(key=model_id) |
| 743 | + res = {"key": {"model_id": model_id}, "cache_deleted": True} |
| 744 | + |
| 745 | + return res |
| 746 | + |
| 747 | + @staticmethod |
| 748 | + def list_valid_inference_containers(): |
| 749 | + containers = list( |
| 750 | + AquaContainerConfig.from_container_index_json( |
| 751 | + config=get_container_config(), enable_spec=True |
| 752 | + ).inference.values() |
| 753 | + ) |
| 754 | + family_values = [item.family for item in containers] |
| 755 | + return family_values |
| 756 | + |
632 | 757 | def _create_model_catalog_entry( |
633 | 758 | self, |
634 | 759 | os_path: str, |
|
0 commit comments