@@ -1119,6 +1119,7 @@ def optimize(
11191119 quantization_config : Optional [Dict ] = None ,
11201120 compilation_config : Optional [Dict ] = None ,
11211121 speculative_decoding_config : Optional [Dict ] = None ,
1122+ sharding_config : Optional [Dict ] = None ,
11221123 env_vars : Optional [Dict ] = None ,
11231124 vpc_config : Optional [Dict ] = None ,
11241125 kms_key : Optional [str ] = None ,
@@ -1142,6 +1143,8 @@ def optimize(
11421143 compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
11431144 speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
11441145 Defaults to ``None``
1146+ sharding_config (Optional[Dict]): Model sharding configuration.
1147+ Defaults to ``None``
11451148 env_vars (Optional[Dict]): Additional environment variables to run the optimization
11461149 container. Defaults to ``None``.
11471150 vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1170,6 +1173,7 @@ def optimize(
11701173 quantization_config = quantization_config ,
11711174 compilation_config = compilation_config ,
11721175 speculative_decoding_config = speculative_decoding_config ,
1176+ sharding_config = sharding_config ,
11731177 env_vars = env_vars ,
11741178 vpc_config = vpc_config ,
11751179 kms_key = kms_key ,
@@ -1189,6 +1193,7 @@ def _model_builder_optimize_wrapper(
11891193 quantization_config : Optional [Dict ] = None ,
11901194 compilation_config : Optional [Dict ] = None ,
11911195 speculative_decoding_config : Optional [Dict ] = None ,
1196+ sharding_config : Optional [Dict ] = None ,
11921197 env_vars : Optional [Dict ] = None ,
11931198 vpc_config : Optional [Dict ] = None ,
11941199 kms_key : Optional [str ] = None ,
@@ -1212,6 +1217,8 @@ def _model_builder_optimize_wrapper(
12121217 compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
12131218 speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
12141219 Defaults to ``None``
1220+ sharding_config (Optional[Dict]): Model sharding configuration.
1221+ Defaults to ``None``
12151222 env_vars (Optional[Dict]): Additional environment variables to run the optimization
12161223 container. Defaults to ``None``.
12171224 vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1238,6 +1245,12 @@ def _model_builder_optimize_wrapper(
12381245 if quantization_config and compilation_config :
12391246 raise ValueError ("Quantization config and compilation config are mutually exclusive." )
12401247
1248+ if sharding_config and (quantization_config or compilation_config or speculative_decoding_config ):
1249+ raise ValueError ("Sharding config is mutually exclusive and cannot be combined with any other optimization." )
1250+
1251+ if sharding_config and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars :
1252+ raise ValueError ("OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config." )
1253+
12411254 self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
12421255 self .instance_type = instance_type or self .instance_type
12431256 self .role_arn = role_arn or self .role_arn
@@ -1254,6 +1267,7 @@ def _model_builder_optimize_wrapper(
12541267 quantization_config = quantization_config ,
12551268 compilation_config = compilation_config ,
12561269 speculative_decoding_config = speculative_decoding_config ,
1270+ sharding_config = sharding_config ,
12571271 env_vars = env_vars ,
12581272 vpc_config = vpc_config ,
12591273 kms_key = kms_key ,
@@ -1272,6 +1286,7 @@ def _model_builder_optimize_wrapper(
12721286 quantization_config = quantization_config ,
12731287 compilation_config = compilation_config ,
12741288 speculative_decoding_config = speculative_decoding_config ,
1289+ sharding_config = sharding_config ,
12751290 env_vars = env_vars ,
12761291 vpc_config = vpc_config ,
12771292 kms_key = kms_key ,
@@ -1287,6 +1302,9 @@ def _model_builder_optimize_wrapper(
12871302 if not speculative_decoding_config :
12881303 self .pysdk_model .remove_tag_with_key (Tag .SPECULATIVE_DRAFT_MODEL_PROVIDER )
12891304
1305+ if sharding_config :
1306+ self .pysdk_model ._is_sharded_model = True
1307+
12901308 return self .pysdk_model
12911309
12921310 def _optimize_for_hf (
@@ -1297,6 +1315,7 @@ def _optimize_for_hf(
12971315 quantization_config : Optional [Dict ] = None ,
12981316 compilation_config : Optional [Dict ] = None ,
12991317 speculative_decoding_config : Optional [Dict ] = None ,
1318+ sharding_config : Optional [Dict ] = None ,
13001319 env_vars : Optional [Dict ] = None ,
13011320 vpc_config : Optional [Dict ] = None ,
13021321 kms_key : Optional [str ] = None ,
@@ -1312,6 +1331,8 @@ def _optimize_for_hf(
13121331 compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
13131332 speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
13141333 Defaults to ``None``
1334+ sharding_config (Optional[Dict]): Model sharding configuration.
1335+ Defaults to ``None``
13151336 env_vars (Optional[Dict]): Additional environment variables to run the optimization
13161337 container. Defaults to ``None``.
13171338 vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1327,7 +1348,7 @@ def _optimize_for_hf(
13271348 self .pysdk_model , speculative_decoding_config , False
13281349 )
13291350
1330- if quantization_config or compilation_config :
1351+ if quantization_config or compilation_config or sharding_config :
13311352 create_optimization_job_args = {
13321353 "OptimizationJobName" : job_name ,
13331354 "DeploymentInstanceType" : self .instance_type ,
0 commit comments