@@ -564,6 +564,7 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
564564
565565
566566def format_eula_message_from_specs (model_id : str , region : str , hosting_eula_key : str ):
567+ """Returns a formatted EULA message."""
567568 return (
568569 f"Model '{ model_id } ' requires accepting end-user license agreement (EULA). "
569570 f"See https://{ get_jumpstart_content_bucket (region = region )} .s3.{ region } ."
@@ -1552,21 +1553,25 @@ def _add_model_access_configs_to_model_data_sources(
15521553 hosting_eula_key = model_data_source .get ("HostingEulaKey" )
15531554 if hosting_eula_key :
15541555 if not model_access_configs or not model_access_configs .get (model_id ):
1555- eula_message_template = "{model_source}{base_eula_message}{model_access_configs_message}"
1556+ eula_message_template = (
1557+ "{model_source}{base_eula_message}{model_access_configs_message}"
1558+ )
15561559 model_access_config_entry = (
1557- " \" {model_id}\ " :ModelAccessConfig(accept_eula=True)" .format (model_id = model_id )
1560+ '" {model_id}":ModelAccessConfig(accept_eula=True)' .format (model_id = model_id )
15581561 )
1559- raise ValueError (eula_message_template .format (
1560- model_source = "Draft " if model_data_source .get ("ChannelName" ) else "" ,
1561- base_eula_message = format_eula_message_from_specs (
1562- model_id = model_id , region = region , hosting_eula_key = hosting_eula_key
1563- ),
1564- model_access_configs_message = (
1565- " Please add a ModelAccessConfig entry:"
1566- f" { model_access_config_entry } "
1567- "to model_access_configs to acknowledge the EULA."
1562+ raise ValueError (
1563+ eula_message_template .format (
1564+ model_source = "Draft " if model_data_source .get ("ChannelName" ) else "" ,
1565+ base_eula_message = format_eula_message_from_specs (
1566+ model_id = model_id , region = region , hosting_eula_key = hosting_eula_key
1567+ ),
1568+ model_access_configs_message = (
1569+ " Please add a ModelAccessConfig entry:"
1570+ f" { model_access_config_entry } "
1571+ "to model_access_configs to acknowledge the EULA."
1572+ ),
15681573 )
1569- ))
1574+ )
15701575 acked_model_data_source = model_data_source .copy ()
15711576 acked_model_data_source .pop ("HostingEulaKey" )
15721577 acked_model_data_source ["S3DataSource" ]["ModelAccessConfig" ] = (
@@ -1576,3 +1581,17 @@ def _add_model_access_configs_to_model_data_sources(
15761581 else :
15771582 acked_model_data_sources .append (model_data_source )
15781583 return acked_model_data_sources
1584+
1585+
1586+ def get_draft_model_content_bucket (provider : Dict , region : str ) -> str :
1587+ """Returns the correct content bucket for a 1p draft model."""
1588+ neo_bucket = get_neo_content_bucket (region = region )
1589+ if not provider :
1590+ return neo_bucket
1591+ provider_name = provider .get ("name" , "" )
1592+ if provider_name == "JumpStart" :
1593+ classification = provider .get ("classification" , "ungated" )
1594+ if classification == "gated" :
1595+ return get_jumpstart_gated_content_bucket (region = region )
1596+ return get_jumpstart_content_bucket (region = region )
1597+ return neo_bucket
0 commit comments