1414"""This module contains utilities related to SageMaker JumpStart Hub.""" 
1515from  __future__ import  absolute_import 
1616import  re 
17- from  typing  import  Optional 
17+ from  typing  import  Optional ,  List ,  Any 
1818from  sagemaker .jumpstart .hub .types  import  S3ObjectLocation 
1919from  sagemaker .s3_utils  import  parse_s3_url 
2020from  sagemaker .session  import  Session 
2323from  sagemaker .jumpstart  import  constants 
2424from  packaging .specifiers  import  SpecifierSet , InvalidSpecifier 
2525
26+ PROPRIETARY_VERSION_KEYWORD  =  "@marketplace-version:" 
27+ 
28+ 
29+ def  _convert_str_to_optional (string : str ) ->  Optional [str ]:
30+     if  string  ==  "None" :
31+         string  =  None 
32+     return  string 
33+ 
2634
2735def  get_info_from_hub_resource_arn (
2836    arn : str ,
@@ -37,7 +45,7 @@ def get_info_from_hub_resource_arn(
3745        hub_name  =  match .group (4 )
3846        hub_content_type  =  match .group (5 )
3947        hub_content_name  =  match .group (6 )
40-         hub_content_version  =  match .group (7 )
48+         hub_content_version  =  _convert_str_to_optional ( match .group (7 ) )
4149
4250        return  HubArnExtractedInfo (
4351            partition = partition ,
@@ -194,10 +202,14 @@ def get_hub_model_version(
194202    hub_model_version : Optional [str ] =  None ,
195203    sagemaker_session : Session  =  constants .DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
196204) ->  str :
197-     """Returns available Jumpstart hub model version 
205+     """Returns available Jumpstart hub model version. 
206+ 
207+     It will attempt both a semantic HubContent version search and Marketplace version search. 
208+     If the Marketplace version is also semantic, this function will default to HubContent version. 
198209
199210    Raises: 
200211        ClientError: If the specified model is not found in the hub. 
212+         KeyError: If the specified model version is not found. 
201213    """ 
202214
203215    try :
@@ -207,6 +219,22 @@ def get_hub_model_version(
207219    except  Exception  as  ex :
208220        raise  Exception (f"Failed calling list_hub_content_versions: { str (ex )}  " )
209221
222+     try :
223+         return  _get_hub_model_version_for_open_weight_version (
224+             hub_content_summaries , hub_model_version 
225+         )
226+     except  KeyError :
227+         marketplace_hub_content_version  =  _get_hub_model_version_for_marketplace_version (
228+             hub_content_summaries , hub_model_version 
229+         )
230+         if  marketplace_hub_content_version :
231+             return  marketplace_hub_content_version 
232+         raise 
233+ 
234+ 
235+ def  _get_hub_model_version_for_open_weight_version (
236+     hub_content_summaries : List [Any ], hub_model_version : Optional [str ] =  None 
237+ ) ->  str :
210238    available_model_versions  =  [model .get ("HubContentVersion" ) for  model  in  hub_content_summaries ]
211239
212240    if  hub_model_version  ==  "*"  or  hub_model_version  is  None :
@@ -222,3 +250,37 @@ def get_hub_model_version(
222250    hub_model_version  =  str (max (available_versions_filtered ))
223251
224252    return  hub_model_version 
253+ 
254+ 
255+ def  _get_hub_model_version_for_marketplace_version (
256+     hub_content_summaries : List [Any ], marketplace_version : str 
257+ ) ->  Optional [str ]:
258+     """Returns the HubContent version associated with the Marketplace version. 
259+ 
260+     This function will check within the HubContentSearchKeywords for the proprietary version. 
261+     """ 
262+     for  model  in  hub_content_summaries :
263+         model_search_keywords  =  model .get ("HubContentSearchKeywords" , [])
264+         if  _hub_search_keywords_contains_marketplace_version (
265+             model_search_keywords , marketplace_version 
266+         ):
267+             return  model .get ("HubContentVersion" )
268+ 
269+     return  None 
270+ 
271+ 
272+ def  _hub_search_keywords_contains_marketplace_version (
273+     model_search_keywords : List [str ], marketplace_version : str 
274+ ) ->  bool :
275+     proprietary_version_keyword  =  next (
276+         filter (lambda  s : s .startswith (PROPRIETARY_VERSION_KEYWORD ), model_search_keywords ), None 
277+     )
278+ 
279+     if  not  proprietary_version_keyword :
280+         return  False 
281+ 
282+     proprietary_version  =  proprietary_version_keyword .lstrip (PROPRIETARY_VERSION_KEYWORD )
283+     if  proprietary_version  ==  marketplace_version :
284+         return  True 
285+ 
286+     return  False 
0 commit comments