11#!/usr/bin/env python
2- # -*- coding: utf-8 -*-
32# Copyright (c) 2024 Oracle and/or its affiliates.
43# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54"""AQUA utils and constants."""
5+
66import asyncio
77import base64
88import json
1919import oci
2020from oci .data_science .models import JobRun , Model
2121
22- from ads .aqua .common .enums import RqsAdditionalDetails
22+ from ads .aqua .common .enums import (
23+ InferenceContainerParamType ,
24+ InferenceContainerType ,
25+ RqsAdditionalDetails ,
26+ )
2327from ads .aqua .common .errors import (
2428 AquaFileNotFoundError ,
2529 AquaRuntimeError ,
2630 AquaValueError ,
2731)
28- from ads .aqua .constants import *
32+ from ads .aqua .constants import (
33+ AQUA_GA_LIST ,
34+ COMPARTMENT_MAPPING_KEY ,
35+ CONSOLE_LINK_RESOURCE_TYPE_MAPPING ,
36+ CONTAINER_INDEX ,
37+ MAXIMUM_ALLOWED_DATASET_IN_BYTE ,
38+ MODEL_BY_REFERENCE_OSS_PATH_KEY ,
39+ SERVICE_MANAGED_CONTAINER_URI_SCHEME ,
40+ SUPPORTED_FILE_FORMATS ,
41+ TGI_INFERENCE_RESTRICTED_PARAMS ,
42+ UNKNOWN ,
43+ UNKNOWN_JSON_STR ,
44+ VLLM_INFERENCE_RESTRICTED_PARAMS ,
45+ )
2946from ads .aqua .data import AquaResourceIdentifier
3047from ads .common .auth import default_signer
3148from ads .common .decorator .threaded import threaded
@@ -74,15 +91,15 @@ def get_status(evaluation_status: str, job_run_status: str = None):
7491
7592 status = LifecycleStatus .UNKNOWN
7693 if evaluation_status == Model .LIFECYCLE_STATE_ACTIVE :
77- if (
78- job_run_status == JobRun .LIFECYCLE_STATE_IN_PROGRESS
79- or job_run_status == JobRun .LIFECYCLE_STATE_ACCEPTED
80- ) :
94+ if job_run_status in {
95+ JobRun .LIFECYCLE_STATE_IN_PROGRESS ,
96+ JobRun .LIFECYCLE_STATE_ACCEPTED ,
97+ } :
8198 status = JobRun .LIFECYCLE_STATE_IN_PROGRESS
82- elif (
83- job_run_status == JobRun .LIFECYCLE_STATE_FAILED
84- or job_run_status == JobRun .LIFECYCLE_STATE_NEEDS_ATTENTION
85- ) :
99+ elif job_run_status in {
100+ JobRun .LIFECYCLE_STATE_FAILED ,
101+ JobRun .LIFECYCLE_STATE_NEEDS_ATTENTION ,
102+ } :
86103 status = JobRun .LIFECYCLE_STATE_FAILED
87104 else :
88105 status = job_run_status
@@ -199,10 +216,7 @@ def read_file(file_path: str, **kwargs) -> str:
199216@threaded ()
200217def load_config (file_path : str , config_file_name : str , ** kwargs ) -> dict :
201218 artifact_path = f"{ file_path .rstrip ('/' )} /{ config_file_name } "
202- if artifact_path .startswith ("oci://" ):
203- signer = default_signer ()
204- else :
205- signer = {}
219+ signer = default_signer () if artifact_path .startswith ("oci://" ) else {}
206220 config = json .loads (
207221 read_file (file_path = artifact_path , auth = signer , ** kwargs ) or UNKNOWN_JSON_STR
208222 )
@@ -448,7 +462,7 @@ def _build_resource_identifier(
448462
449463
450464def _get_experiment_info (
451- model : Union [oci .resource_search .models .ResourceSummary , DataScienceModel ]
465+ model : Union [oci .resource_search .models .ResourceSummary , DataScienceModel ],
452466) -> tuple :
453467 """Returns ocid and name of the experiment."""
454468 return (
@@ -609,7 +623,7 @@ def extract_id_and_name_from_tag(tag: str):
609623 base_model_name = UNKNOWN
610624 try :
611625 base_model_ocid , base_model_name = tag .split ("#" )
612- except :
626+ except Exception :
613627 pass
614628
615629 if not (is_valid_ocid (base_model_ocid ) and base_model_name ):
@@ -646,7 +660,7 @@ def get_resource_name(ocid: str) -> str:
646660 try :
647661 resource = query_resource (ocid , return_all = False )
648662 name = resource .display_name if resource else UNKNOWN
649- except :
663+ except Exception :
650664 name = UNKNOWN
651665 return name
652666
@@ -670,8 +684,8 @@ def get_model_by_reference_paths(model_file_description: dict):
670684
671685 if not models :
672686 raise AquaValueError (
673- f "Model path is not available in the model json artifact. "
674- f "Please check if the model created by reference has the correct artifact."
687+ "Model path is not available in the model json artifact. "
688+ "Please check if the model created by reference has the correct artifact."
675689 )
676690
677691 if len (models ) > 0 :
@@ -848,3 +862,46 @@ def copy_model_config(artifact_path: str, os_path: str, auth: dict = None):
848862 except Exception as ex :
849863 logger .debug (ex )
850864 logger .debug (f"Failed to copy config folder from { artifact_path } to { os_path } ." )
865+
866+
867+ def get_container_params_type (container_type_name : str ) -> str :
868+ """The utility function accepts the deployment container type name and returns the corresponding params name.
869+ Parameters
870+ ----------
871+ container_type_name: str
872+ type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.
873+
874+ Returns
875+ -------
876+ InferenceContainerParamType value
877+
878+ """
879+ # check substring instead of direct match in case container_type_name changes in the future
880+ if InferenceContainerType .CONTAINER_TYPE_VLLM in container_type_name .lower ():
881+ return InferenceContainerParamType .PARAM_TYPE_VLLM
882+ elif InferenceContainerType .CONTAINER_TYPE_TGI in container_type_name .lower ():
883+ return InferenceContainerParamType .PARAM_TYPE_TGI
884+ else :
885+ return UNKNOWN
886+
887+
888+ def get_restricted_params_by_container (container_type_name : str ) -> set :
889+ """The utility function accepts the deployment container type name and returns a set of restricted params
890+ for that container.
891+ Parameters
892+ ----------
893+ container_type_name: str
894+ type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.
895+
896+ Returns
897+ -------
898+ Set of restricted params based on container type
899+
900+ """
901+ # check substring instead of direct match in case container_type_name changes in the future
902+ if InferenceContainerType .CONTAINER_TYPE_VLLM in container_type_name .lower ():
903+ return VLLM_INFERENCE_RESTRICTED_PARAMS
904+ elif InferenceContainerType .CONTAINER_TYPE_TGI in container_type_name .lower ():
905+ return TGI_INFERENCE_RESTRICTED_PARAMS
906+ else :
907+ return set ()
0 commit comments