Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions grasp/core/base_task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,18 +597,27 @@ def execute(self):
# since the output file will also be big and it's efficient to append to jsonl
out_file_type = "jsonl" if num_records_total > 25000 else "json"
run_name_prefix = f"{self.args.run_name}_" if self.args.run_name else ""
# Create a subdirectory for outputs, named "<run_name_prefix>output<ts_suffix>"
output_subdir_name = f"{run_name_prefix}output{ts_suffix}"

if self.output_dir:
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_file = (
self.output_dir
+ f"/{run_name_prefix}output{ts_suffix}.{out_file_type}"
)

output_subdir_path = os.path.join(self.output_dir, output_subdir_name)
if not os.path.exists(output_subdir_path):
os.makedirs(output_subdir_path)

out_file = os.path.join(output_subdir_path, f"output.{out_file_type}")
else:
out_file = utils.get_file_in_task_dir(
self.args.task,
f"{run_name_prefix}output{ts_suffix}.{out_file_type}",
)
# Get the task directory path
task_dir = utils.get_task_dir(self.task_name)

output_subdir_path = os.path.join(task_dir, output_subdir_name)
if not os.path.exists(output_subdir_path):
os.makedirs(output_subdir_path)

out_file = os.path.join(output_subdir_path, f"output.{out_file_type}")
if not self.resumable and os.path.exists(out_file):
logger.info(
f"Deleting existing output file since resumable=False: {out_file}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
data_config:
source:
type: "disk"
file_path: "grasp/tasks/data_quality/llm_based/sample.json"
file_path: "grasp/internal_tasks/data_quality/llm_based/sample.json"

transformations:
- transform: grasp.processors.data_transform.AddNewFieldTransform
Expand All @@ -10,7 +10,7 @@ data_config:
category: "Generic"
scores: {}
metadata: {}
- transform: grasp.tasks.data_quality.llm_based.task_executor.ConvertToQuestionAnswerTransform
- transform: grasp.internal_tasks.data_quality.llm_based.task_executor.ConvertToQuestionAnswerTransform



Expand Down Expand Up @@ -50,7 +50,7 @@ graph_config:
parameters:
max_tokens: 1000
temperature: 0
post_process: grasp.tasks.data_quality.llm_based.task_executor.DataQualityQuestionQualityPostProcessor
post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.DataQualityQuestionQualityPostProcessor

generic_prompt:
node_type: llm
Expand Down Expand Up @@ -108,7 +108,7 @@ graph_config:
parameters:
max_tokens: 4096
temperature: 0
post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor
post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor

math_prompt:
node_type: llm
Expand Down Expand Up @@ -163,7 +163,7 @@ graph_config:
parameters:
max_tokens: 4096
temperature: 0
post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor
post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor

reasoning_prompt:
node_type: llm
Expand Down Expand Up @@ -224,7 +224,7 @@ graph_config:
parameters:
max_tokens: 4096
temperature: 0
post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor
post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor

coding_prompt:
node_type: llm
Expand Down Expand Up @@ -291,7 +291,7 @@ graph_config:
parameters:
max_tokens: 4096
temperature: 0
post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor
post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor

instruction_following_prompt:
node_type: llm
Expand Down Expand Up @@ -349,7 +349,7 @@ graph_config:
parameters:
max_tokens: 4096
temperature: 0
post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor
post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor

open_qa_prompt:
node_type: llm
Expand Down Expand Up @@ -404,7 +404,7 @@ graph_config:
parameters:
max_tokens: 4096
temperature: 0
post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor
post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor

closed_qa_prompt:
node_type: llm
Expand Down Expand Up @@ -459,13 +459,13 @@ graph_config:
parameters:
max_tokens: 4096
temperature: 0
post_process: grasp.tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor
post_process: grasp.internal_tasks.data_quality.llm_based.task_executor.GenericPromptPostProcessor

edges:
- from: START
to: extract_question_quality
- from: extract_question_quality
condition: grasp.tasks.data_quality.llm_based.task_executor.DataQualityCategoryCondition
condition: grasp.internal_tasks.data_quality.llm_based.task_executor.DataQualityCategoryCondition
path_map:
generic: generic_prompt
math_solving: math_prompt
Expand All @@ -490,7 +490,7 @@ graph_config:
to: END

output_config:
generator: grasp.tasks.data_quality.llm_based.task_executor.DataQualityOutputGenerator
generator: grasp.internal_tasks.data_quality.llm_based.task_executor.DataQualityOutputGenerator
output_map:
id:
from: "id"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def apply(state: GraspState) -> str:
# Retrieve the category from state and check against allowed categories
category = state.get("category", "").lower().replace(" ", "_")
prompt_config = load_prompt_config(
utils.get_file_in_dir("tasks.data_quality.llm_based", "prompt_config.yaml")
utils.get_file_in_dir("internal_tasks.data_quality.llm_based", "prompt_config.yaml")
)

if category not in prompt_config:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
data_config:
source:
type: "disk"
file_path: "grasp/tasks/examples/glaive_code_assistant/test_output.json"
file_path: "tasks/examples/glaive_code_assistant/test_output.json"

transformations:
- transform: grasp.tasks.data_quality.metadata_tagging.task_executor.MetadataTaggingDataTransform
- transform: grasp.internal_tasks.data_quality.metadata_tagging.task_executor.MetadataTaggingDataTransform
params:
data_type: "sft"

graph_config:
nodes:
extract_category:
node_type: llm
post_process: grasp.tasks.data_quality.metadata_tagging.task_executor.ExtractCategoryPostProcess
pre_process: grasp.tasks.data_quality.metadata_tagging.task_executor.ExtractCategoryPreProcess
post_process: grasp.internal_tasks.data_quality.metadata_tagging.task_executor.ExtractCategoryPostProcess
pre_process: grasp.internal_tasks.data_quality.metadata_tagging.task_executor.ExtractCategoryPreProcess
prompt:
- system: |
You are a classification AI that strictly follows given instructions and categorizes data accurately. Your primary task is to select the most appropriate task category from a provided list based on User Question and Assistant Response.
Expand Down Expand Up @@ -104,7 +104,7 @@ graph_config:
name: gpt-4o
parameters:
temperature: 0.1
post_process: grasp.tasks.data_quality.metadata_tagging.task_executor.ExtractTagsPostProcess
post_process: grasp.internal_tasks.data_quality.metadata_tagging.task_executor.ExtractTagsPostProcess

edges:
- from: START
Expand All @@ -115,7 +115,7 @@ graph_config:
to: END

output_config:
generator: grasp.tasks.data_quality.metadata_tagging.task_executor.MetaTaggingOutputGenerator
generator: grasp.internal_tasks.data_quality.metadata_tagging.task_executor.MetaTaggingOutputGenerator
output_map:
id:
from: "id"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def transform(
"""
# Load taxonomy metadata and prepare category descriptions
sub_tasks = self.get_data(
utils.get_file_in_task_dir(
"data_quality.metadata_tagging.taxonomy", "taxonomy.json"
utils.get_file_in_dir(
"internal_tasks.data_quality.metadata_tagging.taxonomy", "taxonomy.json"
)
)
task_category = {cat: sub_tasks[cat]["Description"] for cat in sub_tasks}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ def execute(self) -> str:

BaseTaskExecutor(args, graph_config_dict).execute()

output_file = os.path.join(self.output_dir, "llm_based_quality_output.jsonl")
output_file = os.path.join(self.output_dir, "llm_based_quality_output", "output.jsonl")
if os.path.exists(output_file):
return output_file
return os.path.join(self.output_dir, "llm_based_quality_output.json")
return os.path.join(self.output_dir, "llm_based_quality_output", "output.json")

def _construct_args(self) -> Namespace:
"""
Expand All @@ -51,7 +51,7 @@ def _construct_args(self) -> Namespace:
Namespace: A namespace object containing task arguments.
"""
args = {
"task": "data_quality.llm_based",
"task": "grasp.internal_tasks.data_quality.llm_based",
"start_index": 0,
"num_records": self.num_records,
"run_name": "llm_based_quality",
Expand Down Expand Up @@ -101,7 +101,7 @@ def _load_and_update_graph_config(self, data_config: dict) -> dict:
"""
graph_config = utils.load_yaml_file(
filepath=utils.get_file_in_task_dir(
"data_quality.llm_based", "graph_config.yaml"
"grasp.internal_tasks.data_quality.llm_based", "graph_config.yaml"
)
)
transformations = (
Expand Down
12 changes: 7 additions & 5 deletions grasp/tools/toolkits/data_quality/tasks/metadata_tagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from argparse import Namespace
from grasp.utils import utils
from grasp.tasks.data_quality.metadata_tagging.filter_tags import (
from grasp.internal_tasks.data_quality.metadata_tagging.filter_tags import (
PipelineConfig,
extract_instag_stats,
)
Expand Down Expand Up @@ -44,13 +44,15 @@ def execute(self) -> str:
graph_config_dict = self._load_and_update_graph_config(data_config)
BaseTaskExecutor(args, graph_config_dict).execute()

output_file = os.path.join(self.output_dir, "metadata_tagging_output.jsonl")
output_file = os.path.join(self.output_dir, "metadata_tagging_output", "output.jsonl")
if not os.path.exists(output_file):
output_file = os.path.join(self.output_dir, "metadata_tagging_output.json")
output_file = os.path.join(self.output_dir, "metadata_tagging_output", "output.json")

# Run filter_tags on the output file if it exists
if os.path.exists(output_file):
self._run_filter_tags(output_file)
else:
logger.warning(f"Output file {output_file} does not exist. Skipping tags normalization and filtering.")

return output_file

Expand All @@ -62,7 +64,7 @@ def _construct_args(self) -> Namespace:
Namespace: A namespace object containing task arguments.
"""
args = {
"task": "data_quality.metadata_tagging",
"task": "grasp.internal_tasks.data_quality.metadata_tagging",
"start_index": 0,
"num_records": self.num_records,
"run_name": "metadata_tagging",
Expand Down Expand Up @@ -97,7 +99,7 @@ def _load_and_update_graph_config(self, data_config: dict) -> dict:
"""
graph_config = utils.load_yaml_file(
filepath=utils.get_file_in_task_dir(
"data_quality.metadata_tagging", "graph_config.yaml"
"grasp.internal_tasks.data_quality.metadata_tagging", "graph_config.yaml"
)
)
transformations = (
Expand Down
8 changes: 7 additions & 1 deletion grasp/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,18 @@ def delete_file(filepath: str):
if os.path.exists(filepath):
os.remove(filepath)

def get_task_dir(task: str):
    """Translate a dot-separated task name into a slash-separated directory path.

    Args:
        task: Dotted task identifier, e.g. ``"tasks.examples.demo"``.

    Returns:
        The task name with every ``.`` replaced by ``/``. An empty task
        name yields ``"/"``.
    """
    path_segments = task.split(".")
    joined = "/".join(path_segments)
    if joined:
        return joined
    # Only reached for an empty task name, where this evaluates to "/".
    return f"{joined}/"

# NOTE(review): the previous message pointed at this very function; confirm
# that get_file_in_dir is the intended replacement.
@deprecated("Use get_file_in_dir instead")
def get_file_in_task_dir(task: str, file: str) -> str:
    """Build the path of ``file`` inside the directory named by a dotted task.

    Args:
        task: Dot-separated task identifier; each ``.`` becomes a ``/``.
        file: File name to append to the task directory.

    Returns:
        The task directory joined with ``file``.
    """
    task_dir = "/".join(task.split("."))
    # os.path.join only returns an empty string when both parts are empty,
    # so the fallback on the right is reached solely in that case.
    return os.path.join(task_dir, file) or f"{task_dir}/{file}"

def is_valid_task_name(task: str):
    """Report whether a dotted task name points at a directory with a graph config.

    Args:
        task: Dot-separated task identifier; each ``.`` becomes a ``/``.

    Returns:
        True when ``<task path>/graph_config.yaml`` exists on disk,
        otherwise False.
    """
    task_path = task.replace(".", "/")
    graph_config_path = f"{task_path}/graph_config.yaml"
    return os.path.exists(graph_config_path)


def get_file_in_dir(dot_walk_path: str, file: str):
dir_path = "/".join(dot_walk_path.split("."))
Expand Down
9 changes: 8 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,14 @@ def check_model_availability(task_name):

# check models are available and normalize task name
if not task_name.startswith("tasks."):
full_task_name = f"tasks.{task_name}"
# Check whether task_name or "tasks.{task_name}" is a valid task; use whichever is valid as args.task and full_task_name.
if utils.is_valid_task_name(task_name):
full_task_name = task_name
elif utils.is_valid_task_name(f"tasks.{task_name}"):
full_task_name = f"tasks.{task_name}"
else:
logger.error(f"Invalid task name: {task_name}. Exiting the process.")
sys.exit(1)
check_model_availability(full_task_name)
args.task = full_task_name
utils.current_task = full_task_name # Set current_task to the full task name with prefix
Expand Down