@@ -835,8 +835,8 @@ def load_project_version(self, version_id: str) -> Project:
         >>> version.wait_for_completion()
         >>> version.print_goal_report()

-        With the :obj:`project_versions.ProjectVersion` object loaded, you are able to check progress and
-        goal statuses.
+        With the :obj:`project_versions.ProjectVersion` object loaded, you are able to
+        check progress and goal statuses.
         """
         endpoint = f"versions/{version_id}"
         version_data = self.api.get_request(endpoint)
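The docstring paragraph above describes what to do once a version is loaded. A minimal usage sketch, assuming an already-authenticated client object that exposes load_project_version (the client variable is illustrative and not part of this diff):

    # Hedged sketch; `client` is an assumed, already-authenticated SDK client.
    version = client.load_project_version(version_id="<version-id>")
    version.wait_for_completion()   # block until the version finishes processing
    version.print_goal_report()     # then check progress and goal statuses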
@@ -896,11 +896,17 @@ def create_inference_pipeline(
                 " creating it.",
             ) from None

-        # Validate reference dataset and augment config
+        # Load dataset config
         if reference_dataset_config_file_path is not None:
+            reference_dataset_config = utils.read_yaml(
+                reference_dataset_config_file_path
+            )
+
+        if reference_dataset_config is not None:
+            # Validate reference dataset and augment config
             dataset_validator = dataset_validators.get_validator(
                 task_type=task_type,
-                dataset_config_file_path=reference_dataset_config_file_path,
+                dataset_config=reference_dataset_config,
                 dataset_df=reference_df,
             )
             failed_validations = dataset_validator.validate()
@@ -912,40 +918,39 @@ def create_inference_pipeline(
                     " upload.",
                 ) from None

-            # Load dataset config and augment with defaults
-            reference_dataset_config = utils.read_yaml(
-                reference_dataset_config_file_path
-            )
             reference_dataset_data = DatasetSchema().load(
                 {"task_type": task_type.value, **reference_dataset_config}
             )

-        with tempfile.TemporaryDirectory() as tmp_dir:
             # Copy relevant files to tmp dir if reference dataset is provided
-            if reference_dataset_config_file_path is not None:
+            with tempfile.TemporaryDirectory() as tmp_dir:
                 utils.write_yaml(
                     reference_dataset_data, f"{tmp_dir}/dataset_config.yaml"
                 )
                 if reference_df is not None:
                     reference_df.to_csv(f"{tmp_dir}/dataset.csv", index=False)
                 else:
                     shutil.copy(
-                        reference_dataset_file_path,
-                        f"{tmp_dir}/dataset.csv",
+                        reference_dataset_file_path, f"{tmp_dir}/dataset.csv"
                     )

-            tar_file_path = os.path.join(tmp_dir, "tarfile")
-            with tarfile.open(tar_file_path, mode="w:gz") as tar:
-                tar.add(tmp_dir, arcname=os.path.basename("reference_dataset"))
-
+                tar_file_path = os.path.join(tmp_dir, "tarfile")
+                with tarfile.open(tar_file_path, mode="w:gz") as tar:
+                    tar.add(tmp_dir, arcname=os.path.basename("reference_dataset"))
+
+                endpoint = f"projects/{project_id}/inference-pipelines"
+                inference_pipeline_data = self.api.upload(
+                    endpoint=endpoint,
+                    file_path=tar_file_path,
+                    object_name="tarfile",
+                    body=inference_pipeline_config,
+                    storage_uri_key="referenceDatasetUri",
+                    method="POST",
+                )
+        else:
             endpoint = f"projects/{project_id}/inference-pipelines"
-            inference_pipeline_data = self.api.upload(
-                endpoint=endpoint,
-                file_path=tar_file_path,
-                object_name="tarfile",
-                body=inference_pipeline_config,
-                storage_uri_key="referenceDatasetUri",
-                method="POST",
+            inference_pipeline_data = self.api.post_request(
+                endpoint=endpoint, body=inference_pipeline_config
             )
         inference_pipeline = InferencePipeline(
             inference_pipeline_data, self.api.upload, self, task_type
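Taken together, the two hunks above restructure create_inference_pipeline around whether a reference dataset was supplied: the YAML config is read up front, and only when it is present are the config and data staged in a temporary directory, tarred, and uploaded with the creation request; otherwise the pipeline is created with a plain POST. A condensed restatement of that flow, using only names that appear in the diff (the ellipses stand for the detailed bodies shown above):

    # Condensed sketch of the new control flow; not a drop-in implementation.
    if reference_dataset_config_file_path is not None:
        reference_dataset_config = utils.read_yaml(reference_dataset_config_file_path)

    if reference_dataset_config is not None:
        # validate, write dataset_config.yaml and dataset.csv into a temp dir,
        # tar them, and upload the tarball under "referenceDatasetUri"
        ...
    else:
        # no reference dataset: create the pipeline with a plain POST request
        ...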