Skip to content

Commit c636658

Browse files
authored
Fixed evaluation validation for model group deployment. (#1268)
2 parents b9752df + f11a3a3 commit c636658

File tree

1 file changed

+29
-21
lines changed

1 file changed

+29
-21
lines changed

ads/aqua/evaluation/evaluation.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
from ads.jobs.builders.runtimes.base import Runtime
100100
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
101101
from ads.model.datascience_model import DataScienceModel
102+
from ads.model.datascience_model_group import DataScienceModelGroup
102103
from ads.model.deployment import ModelDeploymentContainerRuntime
103104
from ads.model.deployment.model_deployment import ModelDeployment
104105
from ads.model.generic_model import ModelDeploymentRuntimeType
@@ -254,7 +255,11 @@ def create(
254255
f"Make sure the {Tags.AQUA_MODEL_ID_TAG} tag is added to the deployment."
255256
)
256257

257-
aqua_model = DataScienceModel.from_id(multi_model_id)
258+
aqua_model = (
259+
DataScienceModelGroup.from_id(multi_model_id)
260+
if "datasciencemodelgroup" in multi_model_id
261+
else DataScienceModel.from_id(multi_model_id)
262+
)
258263
AquaEvaluationApp.validate_model_name(
259264
aqua_model, create_aqua_evaluation_details
260265
)
@@ -630,23 +635,23 @@ def create(
630635

631636
@staticmethod
632637
def validate_model_name(
633-
evaluation_source: DataScienceModel,
638+
evaluation_source: Union[DataScienceModel, DataScienceModelGroup],
634639
create_aqua_evaluation_details: CreateAquaEvaluationDetails,
635640
) -> None:
636641
"""
637642
Validates the user input for the model name when creating an Aqua evaluation.
638643
639644
This function verifies that:
640645
- The model group is not empty.
641-
- The model multi metadata is present in the DataScienceModel metadata.
646+
- The model multi metadata is present in the DataScienceModel or DataScienceModelGroup metadata.
642647
- The user provided a non-empty model name.
643-
- The provided model name exists in the DataScienceModel metadata.
648+
- The provided model name exists in the DataScienceModel or DataScienceModelGroup metadata.
644649
- The deployment configuration contains core metadata required for validation.
645650
646651
Parameters
647652
----------
648-
evaluation_source : DataScienceModel
649-
The DataScienceModel object containing metadata about each model in the deployment.
653+
evaluation_source : Union[DataScienceModel, DataScienceModelGroup]
654+
The DataScienceModel or DataScienceModelGroup object containing metadata about each model in the deployment.
650655
create_aqua_evaluation_details : CreateAquaEvaluationDetails
651656
Contains required and optional fields for creating the Aqua evaluation.
652657
@@ -711,27 +716,30 @@ def validate_model_name(
711716
logger.debug(error_message)
712717
raise AquaRuntimeError(error_message)
713718

714-
try:
715-
multi_model_metadata = json.loads(
716-
evaluation_source.dsc_model.get_custom_metadata_artifact(
717-
metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA
718-
).decode("utf-8")
719-
)
720-
except Exception as ex:
721-
error_message = (
722-
f"Error fetching {ModelCustomMetadataFields.MULTIMODEL_METADATA} "
723-
f"from custom metadata for evaluation source ID '{evaluation_source.id}'. "
724-
f"Details: {ex}"
725-
)
726-
logger.error(error_message)
727-
raise AquaRuntimeError(error_message) from ex
719+
if isinstance(evaluation_source, DataScienceModel):
720+
try:
721+
multi_model_metadata = json.loads(
722+
evaluation_source.dsc_model.get_custom_metadata_artifact(
723+
metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA
724+
).decode("utf-8")
725+
)
726+
except Exception as ex:
727+
error_message = (
728+
f"Error fetching {ModelCustomMetadataFields.MULTIMODEL_METADATA} "
729+
f"from custom metadata for evaluation source ID '{evaluation_source.id}'. "
730+
f"Details: {ex}"
731+
)
732+
logger.error(error_message)
733+
raise AquaRuntimeError(error_message) from ex
728734

729735
# Build the list of valid model names from custom metadata.
730736
model_names = []
731737
for metadata in multi_model_metadata:
732738
model = AquaMultiModelRef(**metadata)
733739
model_names.append(model.model_name)
734-
model_names.extend(ft.model_name for ft in (model.fine_tune_weights or []) if ft.model_name)
740+
model_names.extend(
741+
ft.model_name for ft in (model.fine_tune_weights or []) if ft.model_name
742+
)
735743

736744
# Check if the provided model name is among the valid names.
737745
if user_model_name not in model_names:

0 commit comments

Comments
 (0)