1414"""This module contains utilities related to SageMaker JumpStart Hub."""
1515from __future__ import absolute_import
1616import re
17- from typing import Optional
17+ from typing import Optional , List , Any
1818from sagemaker .jumpstart .hub .types import S3ObjectLocation
1919from sagemaker .s3_utils import parse_s3_url
2020from sagemaker .session import Session
2323from sagemaker .jumpstart import constants
2424from packaging .specifiers import SpecifierSet , InvalidSpecifier
2525
26+ PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
27+
28+
29+ def _convert_str_to_optional (string : str ) -> Optional [str ]:
30+ if string == "None" :
31+ string = None
32+ return string
33+
2634
2735def get_info_from_hub_resource_arn (
2836 arn : str ,
@@ -37,7 +45,7 @@ def get_info_from_hub_resource_arn(
3745 hub_name = match .group (4 )
3846 hub_content_type = match .group (5 )
3947 hub_content_name = match .group (6 )
40- hub_content_version = match .group (7 )
48+ hub_content_version = _convert_str_to_optional ( match .group (7 ) )
4149
4250 return HubArnExtractedInfo (
4351 partition = partition ,
@@ -194,10 +202,14 @@ def get_hub_model_version(
194202 hub_model_version : Optional [str ] = None ,
195203 sagemaker_session : Session = constants .DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
196204) -> str :
197- """Returns available Jumpstart hub model version
205+ """Returns available Jumpstart hub model version.
206+
207+ It will attempt both a semantic HubContent version search and Marketplace version search.
208+ If the Marketplace version is also semantic, this function will default to HubContent version.
198209
199210 Raises:
200211 ClientError: If the specified model is not found in the hub.
212+ KeyError: If the specified model version is not found.
201213 """
202214
203215 try :
@@ -207,6 +219,23 @@ def get_hub_model_version(
207219 except Exception as ex :
208220 raise Exception (f"Failed calling list_hub_content_versions: { str (ex )} " )
209221
222+ marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version (
223+ hub_content_summaries , hub_model_version
224+ )
225+
226+ try :
227+ return _get_hub_model_version_for_open_weight_version (
228+ hub_content_summaries , hub_model_version
229+ )
230+ except KeyError as e :
231+ if marketplace_hub_content_version :
232+ return marketplace_hub_content_version
233+ raise e
234+
235+
236+ def _get_hub_model_version_for_open_weight_version (
237+ hub_content_summaries : List [Any ], hub_model_version : Optional [str ] = None
238+ ) -> str :
210239 available_model_versions = [model .get ("HubContentVersion" ) for model in hub_content_summaries ]
211240
212241 if hub_model_version == "*" or hub_model_version is None :
@@ -222,3 +251,37 @@ def get_hub_model_version(
222251 hub_model_version = str (max (available_versions_filtered ))
223252
224253 return hub_model_version
254+
255+
256+ def _get_hub_model_version_for_marketplace_version (
257+ hub_content_summaries : List [Any ], marketplace_version : str
258+ ) -> Optional [str ]:
259+ """Returns the HubContent version associated with the Marketplace version.
260+
261+ This function will check within the HubContentSearchKeywords for the proprietary version.
262+ """
263+ for model in hub_content_summaries :
264+ model_search_keywords = model .get ("HubContentSearchKeywords" , [])
265+ if _hub_search_keywords_contains_marketplace_version (
266+ model_search_keywords , marketplace_version
267+ ):
268+ return model .get ("HubContentVersion" )
269+
270+ return None
271+
272+
273+ def _hub_search_keywords_contains_marketplace_version (
274+ model_search_keywords : List [str ], marketplace_version : str
275+ ) -> bool :
276+ proprietary_version_keyword = next (
277+ filter (lambda s : s .startswith (PROPRIETARY_VERSION_KEYWORD ), model_search_keywords ), None
278+ )
279+
280+ if not proprietary_version_keyword :
281+ return False
282+
283+ proprietary_version = proprietary_version_keyword .lstrip (PROPRIETARY_VERSION_KEYWORD )
284+ if proprietary_version == marketplace_version :
285+ return True
286+
287+ return False
0 commit comments