diff --git a/grasp/core/base_task_executor.py b/grasp/core/base_task_executor.py index caacf94..7160888 100644 --- a/grasp/core/base_task_executor.py +++ b/grasp/core/base_task_executor.py @@ -597,18 +597,27 @@ def execute(self): # since the output file will also be big and its efficient to append to jsonl out_file_type = "jsonl" if num_records_total > 25000 else "json" run_name_prefix = f"{self.args.run_name}_" if self.args.run_name else "" + # Create a subdirectory for outputs with the name format + output_subdir_name = f"{run_name_prefix}output{ts_suffix}" + if self.output_dir: if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) - out_file = ( - self.output_dir - + f"/{run_name_prefix}output{ts_suffix}.{out_file_type}" - ) + + output_subdir_path = os.path.join(self.output_dir, output_subdir_name) + if not os.path.exists(output_subdir_path): + os.makedirs(output_subdir_path) + + out_file = os.path.join(output_subdir_path, f"output.{out_file_type}") else: - out_file = utils.get_file_in_task_dir( - self.args.task, - f"{run_name_prefix}output{ts_suffix}.{out_file_type}", - ) + # Get the task directory path + task_dir = utils.get_task_dir(self.task_name) + + output_subdir_path = os.path.join(task_dir, output_subdir_name) + if not os.path.exists(output_subdir_path): + os.makedirs(output_subdir_path) + + out_file = os.path.join(output_subdir_path, f"output.{out_file_type}") if not self.resumable and os.path.exists(out_file): logger.info( f"Deleting existing output file since resumable=False: {out_file}" diff --git a/grasp/tasks/data_quality/llm_based/README.md b/grasp/internal_tasks/data_quality/llm_based/README.md similarity index 100% rename from grasp/tasks/data_quality/llm_based/README.md rename to grasp/internal_tasks/data_quality/llm_based/README.md diff --git a/grasp/tasks/data_quality/llm_based/graph_config.yaml b/grasp/internal_tasks/data_quality/llm_based/graph_config.yaml similarity index 95% rename from grasp/tasks/data_quality/llm_based/graph_config.yaml rename to grasp/internal_tasks/data_quality/llm_based/graph_config.yaml index 243cf9f..7d08623 100644 --- a/grasp/tasks/data_quality/llm_based/graph_config.yaml +++ b/grasp/internal_tasks/data_quality/llm_based/graph_config.yaml @@ -1,7 +1,7 @@ data_config: source: type: "disk" - file_path: "grasp/tasks/data_quality/llm_based/sample.json" + file_path: "grasp/internal_tasks/data_quality/llm_based/sample.json" transformations: - transform: grasp.processors.data_transform.AddNewFieldTransform @@ -10,7 +10,7 @@ data_config: category: "Generic" scores: {} metadata: {} - - transform: grasp.tasks.data_quality.llm_based.task_executor.ConvertToQuestionAnswerTransform + - transform: grasp.internal_tasks.data_quality.llm_based.task_executor.ConvertToQuestionAnswerTransform @@ -50,7 +50,7 @@ graph_config: parameters: max_tokens: 1000 temperature: 0 - post_process: grasp.tasks.data_quality.llm_based.task_executor.DataQualityQuestionQualityPostProcessor + post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.DataQualityQuestionQualityPostProcessor generic_prompt: node_type: llm @@ -108,7 +108,7 @@ graph_config: parameters: max_tokens: 4096 temperature: 0 - post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor + post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor math_prompt: node_type: llm @@ -163,7 +163,7 @@ graph_config: parameters: max_tokens: 4096 temperature: 0 - post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor + post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor reasoning_prompt: node_type: llm @@ -224,7 +224,7 @@ graph_config: parameters: max_tokens: 4096 temperature: 0 - post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor + post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor coding_prompt: node_type: llm @@ -291,7 +291,7 @@ graph_config: parameters: max_tokens: 4096 temperature: 0 - post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor + post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor instruction_following_prompt: node_type: llm @@ -349,7 +349,7 @@ graph_config: parameters: max_tokens: 4096 temperature: 0 - post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor + post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor open_qa_prompt: node_type: llm @@ -404,7 +404,7 @@ graph_config: parameters: max_tokens: 4096 temperature: 0 - post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor + post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor closed_qa_prompt: node_type: llm @@ -459,13 +459,13 @@ graph_config: parameters: max_tokens: 4096 temperature: 0 - post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor + post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor edges: - from: START to: extract_question_quality - from: extract_question_quality - condition: grasp.tasks.data_quality.llm_based.task_executor.DataQualityCategoryCondition + condition: grasp.internal_tasks.data_quality.llm_based.task_executor.DataQualityCategoryCondition path_map: generic: generic_prompt math_solving: math_prompt @@ -490,7 +490,7 @@ graph_config: to: END output_config: - generator: grasp.tasks.data_quality.llm_based.task_executor.DataQualityOutputGenerator + generator: grasp.internal_tasks.data_quality.llm_based.task_executor.DataQualityOutputGenerator output_map: id: from: "id" diff --git a/grasp/tasks/data_quality/llm_based/prompt_config.yaml b/grasp/internal_tasks/data_quality/llm_based/prompt_config.yaml similarity index 100% rename from grasp/tasks/data_quality/llm_based/prompt_config.yaml rename to grasp/internal_tasks/data_quality/llm_based/prompt_config.yaml diff --git a/grasp/tasks/data_quality/llm_based/sample.json b/grasp/internal_tasks/data_quality/llm_based/sample.json similarity index 100% rename from grasp/tasks/data_quality/llm_based/sample.json rename to grasp/internal_tasks/data_quality/llm_based/sample.json diff --git a/grasp/tasks/data_quality/llm_based/task_executor.py b/grasp/internal_tasks/data_quality/llm_based/task_executor.py similarity index 98% rename from grasp/tasks/data_quality/llm_based/task_executor.py rename to grasp/internal_tasks/data_quality/llm_based/task_executor.py index 5c60651..3448cfe 100644 --- a/grasp/tasks/data_quality/llm_based/task_executor.py +++ b/grasp/internal_tasks/data_quality/llm_based/task_executor.py @@ -122,7 +122,7 @@ def apply(state: GraspState) -> str: # Retrieve the category from state and check against allowed categories category = state.get("category", "").lower().replace(" ", "_") prompt_config = load_prompt_config( - utils.get_file_in_dir("tasks.data_quality.llm_based", "prompt_config.yaml") + utils.get_file_in_dir("internal_tasks.data_quality.llm_based", "prompt_config.yaml") ) if category not in prompt_config: diff --git a/grasp/tasks/data_quality/metadata_tagging/filter_tags.py b/grasp/internal_tasks/data_quality/metadata_tagging/filter_tags.py similarity index 100% rename from grasp/tasks/data_quality/metadata_tagging/filter_tags.py rename to grasp/internal_tasks/data_quality/metadata_tagging/filter_tags.py diff --git a/grasp/tasks/data_quality/metadata_tagging/graph_config.yaml b/grasp/internal_tasks/data_quality/metadata_tagging/graph_config.yaml similarity index 90% rename from grasp/tasks/data_quality/metadata_tagging/graph_config.yaml rename to grasp/internal_tasks/data_quality/metadata_tagging/graph_config.yaml index e938778..d489310 100644 --- a/grasp/tasks/data_quality/metadata_tagging/graph_config.yaml +++ b/grasp/internal_tasks/data_quality/metadata_tagging/graph_config.yaml @@ -1,10 +1,10 @@ data_config: source: type: "disk" - file_path: "grasp/tasks/examples/glaive_code_assistant/test_output.json" + file_path: "tasks/examples/glaive_code_assistant/test_output.json" transformations: - - transform: grasp.tasks.data_quality.metadata_tagging.task_executor.MetadataTaggingDataTransform + - transform: grasp.internal_tasks.data_quality.metadata_tagging.task_executor.MetadataTaggingDataTransform params: data_type: "sft" @@ -12,8 +12,8 @@ graph_config: nodes: extract_category: node_type: llm - post_process: grasp.tasks.data_quality.metadata_tagging.task_executor.ExtractCategoryPostProcess - pre_process: grasp.tasks.data_quality.metadata_tagging.task_executor.ExtractCategoryPreProcess + post_process: grasp.internal_tasks.data_quality.metadata_tagging.task_executor.ExtractCategoryPostProcess + pre_process: grasp.internal_tasks.data_quality.metadata_tagging.task_executor.ExtractCategoryPreProcess prompt: - system: | You are a classification AI that strictly follows given instructions and categorizes data accurately. Your primary task is to select the most appropriate task category from a provided list based on User Question and Assistant Response. @@ -104,7 +104,7 @@ graph_config: name: gpt-4o parameters: temperature: 0.1 - post_process: grasp.tasks.data_quality.metadata_tagging.task_executor.ExtractTagsPostProcess + post_process: grasp.internal_tasks.data_quality.metadata_tagging.task_executor.ExtractTagsPostProcess edges: - from: START @@ -115,7 +115,7 @@ graph_config: to: END output_config: - generator: grasp.tasks.data_quality.metadata_tagging.task_executor.MetaTaggingOutputGenerator + generator: grasp.internal_tasks.data_quality.metadata_tagging.task_executor.MetaTaggingOutputGenerator output_map: id: from: "id" diff --git a/grasp/tasks/data_quality/metadata_tagging/task_executor.py b/grasp/internal_tasks/data_quality/metadata_tagging/task_executor.py similarity index 97% rename from grasp/tasks/data_quality/metadata_tagging/task_executor.py rename to grasp/internal_tasks/data_quality/metadata_tagging/task_executor.py index 70f0bf1..8be8102 100644 --- a/grasp/tasks/data_quality/metadata_tagging/task_executor.py +++ b/grasp/internal_tasks/data_quality/metadata_tagging/task_executor.py @@ -80,8 +80,8 @@ def transform( """ # Load taxonomy metadata and prepare category descriptions sub_tasks = self.get_data( - utils.get_file_in_task_dir( - "data_quality.metadata_tagging.taxonomy", "taxonomy.json" + utils.get_file_in_dir( + "internal_tasks.data_quality.metadata_tagging.taxonomy", "taxonomy.json" ) ) task_category = {cat: sub_tasks[cat]["Description"] for cat in sub_tasks} diff --git a/grasp/tasks/data_quality/metadata_tagging/taxonomy/taxonomy.json b/grasp/internal_tasks/data_quality/metadata_tagging/taxonomy/taxonomy.json similarity index 100% rename from grasp/tasks/data_quality/metadata_tagging/taxonomy/taxonomy.json rename to grasp/internal_tasks/data_quality/metadata_tagging/taxonomy/taxonomy.json diff --git a/grasp/tools/toolkits/data_quality/tasks/llm_based_quality.py b/grasp/tools/toolkits/data_quality/tasks/llm_based_quality.py index be0e4e5..c8d4a74 100644 --- a/grasp/tools/toolkits/data_quality/tasks/llm_based_quality.py +++ b/grasp/tools/toolkits/data_quality/tasks/llm_based_quality.py @@ -38,10 +38,10 @@ def execute(self) -> str: BaseTaskExecutor(args, graph_config_dict).execute() - output_file = os.path.join(self.output_dir, "llm_based_quality_output.jsonl") + output_file = os.path.join(self.output_dir, "llm_based_quality_output", "output.jsonl") if os.path.exists(output_file): return output_file - return os.path.join(self.output_dir, "llm_based_quality_output.json") + return os.path.join(self.output_dir, "llm_based_quality_output", "output.json") def _construct_args(self) -> Namespace: """ @@ -51,7 +51,7 @@ def _construct_args(self) -> Namespace: Namespace: A namespace object containing task arguments. """ args = { - "task": "data_quality.llm_based", + "task": "grasp.internal_tasks.data_quality.llm_based", "start_index": 0, "num_records": self.num_records, "run_name": "llm_based_quality", @@ -101,7 +101,7 @@ def _load_and_update_graph_config(self, data_config: dict) -> dict: """ graph_config = utils.load_yaml_file( filepath=utils.get_file_in_task_dir( - "data_quality.llm_based", "graph_config.yaml" + "grasp.internal_tasks.data_quality.llm_based", "graph_config.yaml" ) ) transformations = ( diff --git a/grasp/tools/toolkits/data_quality/tasks/metadata_tagging.py b/grasp/tools/toolkits/data_quality/tasks/metadata_tagging.py index 2261d55..58bb939 100644 --- a/grasp/tools/toolkits/data_quality/tasks/metadata_tagging.py +++ b/grasp/tools/toolkits/data_quality/tasks/metadata_tagging.py @@ -2,7 +2,7 @@ import logging from argparse import Namespace from grasp.utils import utils -from grasp.tasks.data_quality.metadata_tagging.filter_tags import ( +from grasp.internal_tasks.data_quality.metadata_tagging.filter_tags import ( PipelineConfig, extract_instag_stats, ) @@ -44,13 +44,15 @@ def execute(self) -> str: graph_config_dict = self._load_and_update_graph_config(data_config) BaseTaskExecutor(args, graph_config_dict).execute() - output_file = os.path.join(self.output_dir, "metadata_tagging_output.jsonl") + output_file = os.path.join(self.output_dir, "metadata_tagging_output", "output.jsonl") if not os.path.exists(output_file): - output_file = os.path.join(self.output_dir, "metadata_tagging_output.json") + output_file = os.path.join(self.output_dir, "metadata_tagging_output", "output.json") # Run filter_tags on the output file if it exists if os.path.exists(output_file): self._run_filter_tags(output_file) + else: + logger.warning(f"Output file {output_file} does not exist. Skipping tags normalization and filtering.") return output_file @@ -62,7 +64,7 @@ def _construct_args(self) -> Namespace: Namespace: A namespace object containing task arguments. """ args = { - "task": "data_quality.metadata_tagging", + "task": "grasp.internal_tasks.data_quality.metadata_tagging", "start_index": 0, "num_records": self.num_records, "run_name": "metadata_tagging", @@ -97,7 +99,7 @@ def _load_and_update_graph_config(self, data_config: dict) -> dict: """ graph_config = utils.load_yaml_file( filepath=utils.get_file_in_task_dir( - "data_quality.metadata_tagging", "graph_config.yaml" + "grasp.internal_tasks.data_quality.metadata_tagging", "graph_config.yaml" ) ) transformations = ( diff --git a/grasp/utils/utils.py b/grasp/utils/utils.py index 7b80d60..88ca676 100644 --- a/grasp/utils/utils.py +++ b/grasp/utils/utils.py @@ -232,12 +232,18 @@ def delete_file(filepath: str): if os.path.exists(filepath): os.remove(filepath) +def get_task_dir(task: str): + task_dir = "/".join(task.split(".")) + return task_dir or f"{task_dir}/" -@deprecated("Use get_file_in_task_dir instead") def get_file_in_task_dir(task: str, file: str): task_dir = "/".join(task.split(".")) return os.path.join(task_dir, file) or f"{task_dir}/{file}" +def is_valid_task_name(task: str): + task_dir = "/".join(task.split(".")) + return os.path.exists(f"{task_dir}/graph_config.yaml") + def get_file_in_dir(dot_walk_path: str, file: str): dir_path = "/".join(dot_walk_path.split(".")) diff --git a/main.py b/main.py index c2330d1..41b6d46 100644 --- a/main.py +++ b/main.py @@ -172,7 +172,14 @@ def check_model_availability(task_name): # check models are available and normalize task name if not task_name.startswith("tasks."): - full_task_name = f"tasks.{task_name}" + #check if task_name is valid or "tasks.{task_name}" is valid. Whichever is valid, use that as args.task and full_task_name + if utils.is_valid_task_name(task_name): + full_task_name = task_name + elif utils.is_valid_task_name(f"tasks.{task_name}"): + full_task_name = f"tasks.{task_name}" + else: + logger.error(f"Invalid task name: {task_name}. Exiting the process.") + sys.exit(1) check_model_availability(full_task_name) args.task = full_task_name utils.current_task = full_task_name # Set current_task to the full task name with prefix