Skip to content

Commit 57b7aa6

Browse files
committed
Fixes issues with reference dataset inclusion when creating inference pipeline
1 parent b710f05 commit 57b7aa6

File tree

2 files changed

+29
-24
lines changed

2 files changed

+29
-24
lines changed

openlayer/__init__.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -835,8 +835,8 @@ def load_project_version(self, version_id: str) -> Project:
835835
>>> version.wait_for_completion()
836836
>>> version.print_goal_report()
837837
838-
With the :obj:`project_versions.ProjectVersion` object loaded, you are able to check progress and
839-
goal statuses.
838+
With the :obj:`project_versions.ProjectVersion` object loaded, you are able to
839+
check progress and goal statuses.
840840
"""
841841
endpoint = f"versions/{version_id}"
842842
version_data = self.api.get_request(endpoint)
@@ -896,11 +896,17 @@ def create_inference_pipeline(
896896
" creating it.",
897897
) from None
898898

899-
# Validate reference dataset and augment config
899+
# Load dataset config
900900
if reference_dataset_config_file_path is not None:
901+
reference_dataset_config = utils.read_yaml(
902+
reference_dataset_config_file_path
903+
)
904+
905+
if reference_dataset_config is not None:
906+
# Validate reference dataset and augment config
901907
dataset_validator = dataset_validators.get_validator(
902908
task_type=task_type,
903-
dataset_config_file_path=reference_dataset_config_file_path,
909+
dataset_config=reference_dataset_config,
904910
dataset_df=reference_df,
905911
)
906912
failed_validations = dataset_validator.validate()
@@ -912,40 +918,39 @@ def create_inference_pipeline(
912918
" upload.",
913919
) from None
914920

915-
# Load dataset config and augment with defaults
916-
reference_dataset_config = utils.read_yaml(
917-
reference_dataset_config_file_path
918-
)
919921
reference_dataset_data = DatasetSchema().load(
920922
{"task_type": task_type.value, **reference_dataset_config}
921923
)
922924

923-
with tempfile.TemporaryDirectory() as tmp_dir:
924925
# Copy relevant files to tmp dir if reference dataset is provided
925-
if reference_dataset_config_file_path is not None:
926+
with tempfile.TemporaryDirectory() as tmp_dir:
926927
utils.write_yaml(
927928
reference_dataset_data, f"{tmp_dir}/dataset_config.yaml"
928929
)
929930
if reference_df is not None:
930931
reference_df.to_csv(f"{tmp_dir}/dataset.csv", index=False)
931932
else:
932933
shutil.copy(
933-
reference_dataset_file_path,
934-
f"{tmp_dir}/dataset.csv",
934+
reference_dataset_file_path, f"{tmp_dir}/dataset.csv"
935935
)
936936

937-
tar_file_path = os.path.join(tmp_dir, "tarfile")
938-
with tarfile.open(tar_file_path, mode="w:gz") as tar:
939-
tar.add(tmp_dir, arcname=os.path.basename("reference_dataset"))
940-
937+
tar_file_path = os.path.join(tmp_dir, "tarfile")
938+
with tarfile.open(tar_file_path, mode="w:gz") as tar:
939+
tar.add(tmp_dir, arcname=os.path.basename("reference_dataset"))
940+
941+
endpoint = f"projects/{project_id}/inference-pipelines"
942+
inference_pipeline_data = self.api.upload(
943+
endpoint=endpoint,
944+
file_path=tar_file_path,
945+
object_name="tarfile",
946+
body=inference_pipeline_config,
947+
storage_uri_key="referenceDatasetUri",
948+
method="POST",
949+
)
950+
else:
941951
endpoint = f"projects/{project_id}/inference-pipelines"
942-
inference_pipeline_data = self.api.upload(
943-
endpoint=endpoint,
944-
file_path=tar_file_path,
945-
object_name="tarfile",
946-
body=inference_pipeline_config,
947-
storage_uri_key="referenceDatasetUri",
948-
method="POST",
952+
inference_pipeline_data = self.api.post_request(
953+
endpoint=endpoint, body=inference_pipeline_config
949954
)
950955
inference_pipeline = InferencePipeline(
951956
inference_pipeline_data, self.api.upload, self, task_type

openlayer/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@
2222
data=data,
2323
)
2424
"""
25-
__version__ = "0.1.0a12"
25+
__version__ = "0.1.0a13"

0 commit comments

Comments
 (0)