From 9bc16f48ff27a17d8a60ecd99197dd458f8cca15 Mon Sep 17 00:00:00 2001 From: maoyixie Date: Wed, 20 Nov 2024 16:07:00 +0800 Subject: [PATCH 1/3] fix multiple threads share one logger --- logger.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/logger.py b/logger.py index ea747da865..076d5ffd6e 100644 --- a/logger.py +++ b/logger.py @@ -4,13 +4,14 @@ import json import logging import os +import threading from typing import Mapping from results import Result FINAL_RESULT_JSON = 'result.json' -_trial_logger = None +_thread_local = threading.local() class CustomLoggerAdapter(logging.LoggerAdapter): @@ -126,22 +127,17 @@ def error(msg: object, def get_trial_logger(name: str = __name__, trial: int = 0, level=logging.DEBUG) -> CustomLoggerAdapter: - """Sets up or retrieves the singleton instance of CustomLoggerAdapter.""" - global _trial_logger - if _trial_logger: - return _trial_logger - - logger = logging.getLogger(name) - if not logger.handlers: - formatter = logging.Formatter( - fmt=('%(asctime)s [Trial ID: %(trial)02d] %(levelname)s ' - '[%(module)s.%(funcName)s]: %(message)s'), - datefmt='%Y-%m-%d %H:%M:%S') - handler = logging.StreamHandler() - handler.setFormatter(formatter) - logger.addHandler(handler) - logger.setLevel(level) - logger.propagate = False - - _trial_logger = CustomLoggerAdapter(logger, {'trial': trial}) - return _trial_logger + """Sets up or retrieves a thread-local CustomLoggerAdapter for each thread.""" + if not hasattr(_thread_local, 'trial_logger'): + logger = logging.getLogger(f'{name}_trial_{trial}') + if not logger.handlers: + formatter = logging.Formatter( + fmt=('%(asctime)s [Trial ID: %(trial)02d] %(levelname)s ' + '[%(module)s.%(funcName)s]: %(message)s'), + datefmt='%Y-%m-%d %H:%M:%S') + handler = logging.StreamHandler() + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(level) + _thread_local.trial_logger = CustomLoggerAdapter(logger, {'trial': trial}) + return _thread_local.trial_logger From b3da4f62465d2992f70cca0b1df001b28487b8e4 Mon Sep 17 00:00:00 2001 From: maoyixie Date: Fri, 22 Nov 2024 15:02:04 +0800 Subject: [PATCH 2/3] refix and lint --- agent/base_agent.py | 10 ++--- agent/prototyper.py | 44 ++++++++++----------- logger.py | 89 ++++++++++++++++++++++--------------------- pipeline.py | 11 +++--- run_one_experiment.py | 2 +- stage/base_stage.py | 4 +- 6 files changed, 82 insertions(+), 78 deletions(-) diff --git a/agent/base_agent.py b/agent/base_agent.py index 070ece578d..e0fe83ea6b 100644 --- a/agent/base_agent.py +++ b/agent/base_agent.py @@ -42,13 +42,13 @@ def get_tool(self, tool_name: str) -> Optional[BaseTool]: return tool return None - def chat_llm(self, cur_round: int, client: Any, prompt: Prompt) -> str: + def chat_llm(self, cur_round: int, client: Any, prompt: Prompt, trial: int) -> str: """Chat with LLM.""" logger.info('%s', - cur_round, prompt.get(), cur_round) + cur_round, prompt.get(), cur_round, trial=trial) response = self.llm.chat_llm(client=client, prompt=prompt) logger.info('%s', - cur_round, response, cur_round) + cur_round, response, cur_round, trial=trial) return response def _parse_tag(self, response: str, tag: str) -> str: @@ -89,11 +89,11 @@ def _container_handle_invalid_tool_usage(self, tool: BaseTool) -> Prompt: f'interaction protocols:\n{tool.tutorial()}') return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([]) - def _sleep_random_duration(self, min_sec: int = 1, max_sec: int = 60) -> None: + def _sleep_random_duration(self, trial: int, min_sec: int = 1, max_sec: int = 60, ) -> None: """Sleeps for a random duration between min_sec and max_sec. Agents uses this to avoid exceeding quota limit (e.g., LLM query frequency).""" duration = random.randint(min_sec, max_sec) - logger.debug('Sleeping for %d before the next query', duration) + logger.debug('Sleeping for %d before the next query', duration, trial=trial) time.sleep(duration) @classmethod diff --git a/agent/prototyper.py b/agent/prototyper.py index eb852f6050..56f32a2d5a 100644 --- a/agent/prototyper.py +++ b/agent/prototyper.py @@ -49,10 +49,10 @@ def _update_fuzz_target_and_build_script(self, cur_round: int, response: str, build_result.fuzz_target_source = fuzz_target_source if fuzz_target_source: logger.debug('ROUND %02d Parsed fuzz target from LLM: %s', cur_round, - fuzz_target_source) + fuzz_target_source, trial=build_result.trial) else: logger.error('ROUND %02d No fuzz target source code in conclusion: %s', - cur_round, response) + cur_round, response, trial=build_result.trial) build_script_source = self._filter_code( self._parse_tag(response, 'build script')) @@ -61,10 +61,10 @@ def _update_fuzz_target_and_build_script(self, cur_round: int, response: str, 'source /src/chronos.sh', '') if build_script_source: logger.debug('ROUND %02d Parsed build script from LLM: %s', cur_round, - build_script_source) + build_script_source, trial=build_result.trial) else: logger.debug('ROUND %02d No build script in conclusion: %s', cur_round, - response) + response, trial=build_result.trial) def _update_build_result(self, build_result: BuildResult, compile_process: sp.CompletedProcess, status: bool, @@ -84,20 +84,20 @@ def _validate_fuzz_target_and_build_script(self, cur_round: int, # 2. Recompile with the modified build script, if any. build_script_source = build_result.build_script_source - logger.info('First compile fuzz target without modifying build script.') + logger.info('First compile fuzz target without modifying build script.', trial=build_result.trial) build_result.build_script_source = '' self._validate_fuzz_target_and_build_script_via_compile( cur_round, build_result) if not build_result.success and build_script_source: - logger.info('Then compile fuzz target with modified build script.') + logger.info('Then compile fuzz target with modified build script.', trial=build_result.trial) build_result.build_script_source = build_script_source self._validate_fuzz_target_and_build_script_via_compile( cur_round, build_result) def _validate_fuzz_target_references_function( self, compilation_tool: ProjectContainerTool, benchmark: Benchmark, - cur_round: int) -> bool: + cur_round: int, trial: int) -> bool: """Validates if the LLM generated fuzz target assembly code references function-under-test.""" disassemble_result = compilation_tool.execute( @@ -106,10 +106,10 @@ def _validate_fuzz_target_references_function( function_referenced = (disassemble_result.returncode == 0 and benchmark.function_name in disassemble_result.stdout) logger.debug('ROUND %02d Final fuzz target function referenced: %s', - cur_round, function_referenced) + cur_round, function_referenced, trial=trial) if not function_referenced: logger.debug('ROUND %02d Final fuzz target function not referenced', - cur_round) + cur_round, trial=trial) return function_referenced def _validate_fuzz_target_and_build_script_via_compile( @@ -133,25 +133,25 @@ def _validate_fuzz_target_and_build_script_via_compile( file_content=build_result.build_script_source)) # Recompile. - logger.info('===== ROUND %02d Recompile =====', cur_round) + logger.info('===== ROUND %02d Recompile =====', cur_round, trial=build_result.trial) start_time = time.time() compile_process = compilation_tool.compile() end_time = time.time() logger.debug('ROUND %02d compilation time: %s', cur_round, - timedelta(seconds=end_time - start_time)) + timedelta(seconds=end_time - start_time), trial=build_result.trial) compile_succeed = compile_process.returncode == 0 logger.debug('ROUND %02d Fuzz target compiles: %s', cur_round, - compile_succeed) + compile_succeed, trial=build_result.trial) # Double-check binary. ls_result = compilation_tool.execute(f'ls /out/{benchmark.target_name}') binary_exists = ls_result.returncode == 0 logger.debug('ROUND %02d Final fuzz target binary exists: %s', cur_round, - binary_exists) + binary_exists, trial=build_result.trial) # Validate if function-under-test is referenced by the fuzz target. function_referenced = self._validate_fuzz_target_references_function( - compilation_tool, benchmark, cur_round) + compilation_tool, benchmark, cur_round, build_result.trial) compilation_tool.terminate() self._update_build_result(build_result, @@ -164,18 +164,18 @@ def _container_handle_conclusion( build_result: BuildResult) -> Optional[Prompt]: """Runs a compilation tool to validate the new fuzz target and build script from LLM.""" - logger.info('----- ROUND %02d Received conclusion -----', cur_round) + logger.info('----- ROUND %02d Received conclusion -----', cur_round, trial=build_result.trial) self._update_fuzz_target_and_build_script(cur_round, response, build_result) self._validate_fuzz_target_and_build_script(cur_round, build_result) if build_result.success: - logger.info('***** Prototyper succeded in %02d rounds *****', cur_round) + logger.info('***** Prototyper succeded in %02d rounds *****', cur_round, trial=build_result.trial) return None if not build_result.compiles: compile_log = self.llm.truncate_prompt(build_result.compile_log) - logger.info('***** Failed to recompile in %02d rounds *****', cur_round) + logger.info('***** Failed to recompile in %02d rounds *****', cur_round, trial=build_result.trial) prompt_text = ( 'Failed to build fuzz target. Here is the fuzz target, build script, ' 'compliation command, and other compilation runtime output. Analyze ' @@ -205,7 +205,7 @@ def _container_handle_conclusion( elif not build_result.is_function_referenced: logger.info( '***** Fuzz target does not reference function-under-test in %02d ' - 'rounds *****', cur_round) + 'rounds *****', cur_round, trial=build_result.trial) prompt_text = ( 'The fuzz target builds successfully, but the target function ' f'`{build_result.benchmark.function_signature}` was not used by ' @@ -230,13 +230,13 @@ def _container_tool_reaction(self, cur_round: int, response: str, build_result) # Other responses are invalid. logger.warning('ROUND %02d Invalid response from LLM: %s', cur_round, - response) + response, trial=build_result.trial) return self._container_handle_invalid_tool_usage(self.inspect_tool) def execute(self, result_history: list[Result]) -> BuildResult: """Executes the agent based on previous result.""" - logger.info('Executing Prototyper') last_result = result_history[-1] + logger.info('Executing Prototyper', trial=last_result.trial) benchmark = last_result.benchmark self.inspect_tool = ProjectContainerTool(benchmark, name='inspect') self.inspect_tool.compile(extra_commands=' && rm -rf /out/* > /dev/null') @@ -250,13 +250,13 @@ def execute(self, result_history: list[Result]) -> BuildResult: try: client = self.llm.get_chat_client(model=self.llm.get_model()) while prompt and cur_round < MAX_ROUND: - response = self.chat_llm(cur_round, client=client, prompt=prompt) + response = self.chat_llm(cur_round, client=client, prompt=prompt, trial=last_result.trial) prompt = self._container_tool_reaction(cur_round, response, build_result) cur_round += 1 finally: # Cleanup: stop and remove the container logger.debug('Stopping and removing the inspect container %s', - self.inspect_tool.container_id) + self.inspect_tool.container_id, trial=last_result.trial) self.inspect_tool.terminate() return build_result diff --git a/logger.py b/logger.py index 076d5ffd6e..a660ec7a05 100644 --- a/logger.py +++ b/logger.py @@ -4,15 +4,12 @@ import json import logging import os -import threading from typing import Mapping from results import Result FINAL_RESULT_JSON = 'result.json' -_thread_local = threading.local() - class CustomLoggerAdapter(logging.LoggerAdapter): """A note-taker to log and record experiment status, key info, and final @@ -62,82 +59,86 @@ def write_chat_history(self, result: Result) -> None: def debug(msg: object, *args: object, + trial: int, exc_info=None, stack_info: bool = False, stacklevel: int = 1, extra: Mapping[str, object] | None = None, **kwargs: object) -> None: - return get_trial_logger().debug(msg, - *args, - exc_info=exc_info, - stack_info=stack_info, - stacklevel=stacklevel, - extra=extra, - **kwargs) + return get_trial_logger(trial=trial).debug(msg, + *args, + exc_info=exc_info, + stack_info=stack_info, + stacklevel=stacklevel, + extra=extra, + **kwargs) def info(msg: object, *args: object, + trial: int, exc_info=None, stack_info: bool = False, stacklevel: int = 1, extra: Mapping[str, object] | None = None, **kwargs: object) -> None: - return get_trial_logger().info(msg, - *args, - exc_info=exc_info, - stack_info=stack_info, - stacklevel=stacklevel, - extra=extra, - **kwargs) + return get_trial_logger(trial=trial).info(msg, + *args, + exc_info=exc_info, + stack_info=stack_info, + stacklevel=stacklevel, + extra=extra, + **kwargs) def warning(msg: object, *args: object, + trial: int, exc_info=None, stack_info: bool = False, stacklevel: int = 1, extra: Mapping[str, object] | None = None, **kwargs: object) -> None: - return get_trial_logger().warning(msg, - *args, - exc_info=exc_info, - stack_info=stack_info, - stacklevel=stacklevel, - extra=extra, - **kwargs) + return get_trial_logger(trial=trial).warning(msg, + *args, + exc_info=exc_info, + stack_info=stack_info, + stacklevel=stacklevel, + extra=extra, + **kwargs) def error(msg: object, *args: object, + trial: int, exc_info=None, stack_info: bool = False, stacklevel: int = 1, extra: Mapping[str, object] | None = None, **kwargs: object) -> None: - return get_trial_logger().error(msg, - *args, - exc_info=exc_info, - stack_info=stack_info, - stacklevel=stacklevel, - extra=extra, - **kwargs) + return get_trial_logger(trial=trial).error(msg, + *args, + exc_info=exc_info, + stack_info=stack_info, + stacklevel=stacklevel, + extra=extra, + **kwargs) def get_trial_logger(name: str = __name__, trial: int = 0, level=logging.DEBUG) -> CustomLoggerAdapter: """Sets up or retrieves a thread-local CustomLoggerAdapter for each thread.""" - if not hasattr(_thread_local, 'trial_logger'): - logger = logging.getLogger(f'{name}_trial_{trial}') - if not logger.handlers: - formatter = logging.Formatter( - fmt=('%(asctime)s [Trial ID: %(trial)02d] %(levelname)s ' - '[%(module)s.%(funcName)s]: %(message)s'), - datefmt='%Y-%m-%d %H:%M:%S') - handler = logging.StreamHandler() - handler.setFormatter(formatter) - logger.addHandler(handler) - logger.setLevel(level) - _thread_local.trial_logger = CustomLoggerAdapter(logger, {'trial': trial}) - return _thread_local.trial_logger + logger = logging.getLogger(name) + if not logger.handlers: + formatter = logging.Formatter( + fmt=('%(asctime)s [Trial ID: %(trial)02d] %(levelname)s ' + '[%(module)s.%(funcName)s]: %(message)s'), + datefmt='%Y-%m-%d %H:%M:%S') + handler = logging.StreamHandler() + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(level) + logger.propagate = False + + return CustomLoggerAdapter(logger, {'trial': trial}) diff --git a/pipeline.py b/pipeline.py index 10b083baab..2b441c5c67 100644 --- a/pipeline.py +++ b/pipeline.py @@ -24,17 +24,18 @@ class Pipeline(): def __init__(self, args: argparse.Namespace, + trial: int, writing_stage_agents: Optional[list[BaseAgent]] = None, evaluation_stage_agents: Optional[list[BaseAgent]] = None, analysis_stage_agents: Optional[list[BaseAgent]] = None): self.args = args - self.logger = logger.get_trial_logger() + self.trial = trial + self.logger = logger.get_trial_logger(trial=trial) self.logger.debug('Pipeline Initialized') - self.writing_stage: WritingStage = WritingStage(args, writing_stage_agents) + self.writing_stage: WritingStage = WritingStage(args, trial, writing_stage_agents) self.execution_stage: ExecutionStage = ExecutionStage( - args, evaluation_stage_agents) - self.analysis_stage: AnalysisStage = AnalysisStage(args, - analysis_stage_agents) + args, trial, evaluation_stage_agents) + self.analysis_stage: AnalysisStage = AnalysisStage(args, trial, analysis_stage_agents) def _terminate(self, result_history: list[Result]) -> bool: """Validates if the termination conditions have been satisfied.""" diff --git a/run_one_experiment.py b/run_one_experiment.py index 0f91daf746..47326b964e 100644 --- a/run_one_experiment.py +++ b/run_one_experiment.py @@ -325,7 +325,7 @@ def _fuzzing_pipeline(benchmark: Benchmark, model: models.LLM, trial_logger = logger.get_trial_logger(trial=trial, level=logging.DEBUG) trial_logger.info('Trial Starts') p = pipeline.Pipeline( - args=args, writing_stage_agents=[Prototyper(trial=trial, llm=model)]) + args=args, trial=trial, writing_stage_agents=[Prototyper(trial=trial, llm=model)]) results = p.execute(result_history=[ Result(benchmark=benchmark, trial=trial, work_dirs=work_dirs) ]) diff --git a/stage/base_stage.py b/stage/base_stage.py index 1f8f4dfdbd..dd8421ebe4 100644 --- a/stage/base_stage.py +++ b/stage/base_stage.py @@ -14,10 +14,12 @@ class BaseStage(ABC): def __init__(self, args: argparse.Namespace, + trail: int, agents: Optional[list[BaseAgent]] = None) -> None: self.args = args + self.trial = trail self.agents: list[BaseAgent] = agents or [] - self.logger = logger.get_trial_logger() + self.logger = logger.get_trial_logger(trial=trail) def __repr__(self) -> str: return self.__class__.__name__ From 97f0c052a5cddf6220f6c5bc4d7088b978da6829 Mon Sep 17 00:00:00 2001 From: maoyixie Date: Fri, 22 Nov 2024 15:08:09 +0800 Subject: [PATCH 3/3] lint --- agent/base_agent.py | 20 ++++++++-- agent/prototyper.py | 88 ++++++++++++++++++++++++++++++------------- pipeline.py | 6 ++- run_one_experiment.py | 4 +- 4 files changed, 84 insertions(+), 34 deletions(-) diff --git a/agent/base_agent.py b/agent/base_agent.py index e0fe83ea6b..2867c099f4 100644 --- a/agent/base_agent.py +++ b/agent/base_agent.py @@ -42,13 +42,20 @@ def get_tool(self, tool_name: str) -> Optional[BaseTool]: return tool return None - def chat_llm(self, cur_round: int, client: Any, prompt: Prompt, trial: int) -> str: + def chat_llm(self, cur_round: int, client: Any, prompt: Prompt, + trial: int) -> str: """Chat with LLM.""" logger.info('%s', - cur_round, prompt.get(), cur_round, trial=trial) + cur_round, + prompt.get(), + cur_round, + trial=trial) response = self.llm.chat_llm(client=client, prompt=prompt) logger.info('%s', - cur_round, response, cur_round, trial=trial) + cur_round, + response, + cur_round, + trial=trial) return response def _parse_tag(self, response: str, tag: str) -> str: @@ -89,7 +96,12 @@ def _container_handle_invalid_tool_usage(self, tool: BaseTool) -> Prompt: f'interaction protocols:\n{tool.tutorial()}') return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([]) - def _sleep_random_duration(self, trial: int, min_sec: int = 1, max_sec: int = 60, ) -> None: + def _sleep_random_duration( + self, + trial: int, + min_sec: int = 1, + max_sec: int = 60, + ) -> None: """Sleeps for a random duration between min_sec and max_sec. Agents uses this to avoid exceeding quota limit (e.g., LLM query frequency).""" duration = random.randint(min_sec, max_sec) diff --git a/agent/prototyper.py b/agent/prototyper.py index 56f32a2d5a..cb0c34b446 100644 --- a/agent/prototyper.py +++ b/agent/prototyper.py @@ -10,7 +10,6 @@ from agent.base_agent import BaseAgent from data_prep.project_context.context_introspector import ContextRetriever from experiment.benchmark import Benchmark -from llm_toolkit.prompt_builder import EXAMPLES as EXAMPLE_FUZZ_TARGETS from llm_toolkit.prompt_builder import (DefaultTemplateBuilder, PrototyperTemplateBuilder) from llm_toolkit.prompts import Prompt @@ -48,11 +47,15 @@ def _update_fuzz_target_and_build_script(self, cur_round: int, response: str, self._parse_tag(response, 'fuzz target')) build_result.fuzz_target_source = fuzz_target_source if fuzz_target_source: - logger.debug('ROUND %02d Parsed fuzz target from LLM: %s', cur_round, - fuzz_target_source, trial=build_result.trial) + logger.debug('ROUND %02d Parsed fuzz target from LLM: %s', + cur_round, + fuzz_target_source, + trial=build_result.trial) else: logger.error('ROUND %02d No fuzz target source code in conclusion: %s', - cur_round, response, trial=build_result.trial) + cur_round, + response, + trial=build_result.trial) build_script_source = self._filter_code( self._parse_tag(response, 'build script')) @@ -60,11 +63,15 @@ def _update_fuzz_target_and_build_script(self, cur_round: int, response: str, build_result.build_script_source = build_script_source.replace( 'source /src/chronos.sh', '') if build_script_source: - logger.debug('ROUND %02d Parsed build script from LLM: %s', cur_round, - build_script_source, trial=build_result.trial) + logger.debug('ROUND %02d Parsed build script from LLM: %s', + cur_round, + build_script_source, + trial=build_result.trial) else: - logger.debug('ROUND %02d No build script in conclusion: %s', cur_round, - response, trial=build_result.trial) + logger.debug('ROUND %02d No build script in conclusion: %s', + cur_round, + response, + trial=build_result.trial) def _update_build_result(self, build_result: BuildResult, compile_process: sp.CompletedProcess, status: bool, @@ -84,13 +91,15 @@ def _validate_fuzz_target_and_build_script(self, cur_round: int, # 2. Recompile with the modified build script, if any. build_script_source = build_result.build_script_source - logger.info('First compile fuzz target without modifying build script.', trial=build_result.trial) + logger.info('First compile fuzz target without modifying build script.', + trial=build_result.trial) build_result.build_script_source = '' self._validate_fuzz_target_and_build_script_via_compile( cur_round, build_result) if not build_result.success and build_script_source: - logger.info('Then compile fuzz target with modified build script.', trial=build_result.trial) + logger.info('Then compile fuzz target with modified build script.', + trial=build_result.trial) build_result.build_script_source = build_script_source self._validate_fuzz_target_and_build_script_via_compile( cur_round, build_result) @@ -106,10 +115,13 @@ def _validate_fuzz_target_references_function( function_referenced = (disassemble_result.returncode == 0 and benchmark.function_name in disassemble_result.stdout) logger.debug('ROUND %02d Final fuzz target function referenced: %s', - cur_round, function_referenced, trial=trial) + cur_round, + function_referenced, + trial=trial) if not function_referenced: logger.debug('ROUND %02d Final fuzz target function not referenced', - cur_round, trial=trial) + cur_round, + trial=trial) return function_referenced def _validate_fuzz_target_and_build_script_via_compile( @@ -133,21 +145,29 @@ def _validate_fuzz_target_and_build_script_via_compile( file_content=build_result.build_script_source)) # Recompile. - logger.info('===== ROUND %02d Recompile =====', cur_round, trial=build_result.trial) + logger.info('===== ROUND %02d Recompile =====', + cur_round, + trial=build_result.trial) start_time = time.time() compile_process = compilation_tool.compile() end_time = time.time() - logger.debug('ROUND %02d compilation time: %s', cur_round, - timedelta(seconds=end_time - start_time), trial=build_result.trial) + logger.debug('ROUND %02d compilation time: %s', + cur_round, + timedelta(seconds=end_time - start_time), + trial=build_result.trial) compile_succeed = compile_process.returncode == 0 - logger.debug('ROUND %02d Fuzz target compiles: %s', cur_round, - compile_succeed, trial=build_result.trial) + logger.debug('ROUND %02d Fuzz target compiles: %s', + cur_round, + compile_succeed, + trial=build_result.trial) # Double-check binary. ls_result = compilation_tool.execute(f'ls /out/{benchmark.target_name}') binary_exists = ls_result.returncode == 0 - logger.debug('ROUND %02d Final fuzz target binary exists: %s', cur_round, - binary_exists, trial=build_result.trial) + logger.debug('ROUND %02d Final fuzz target binary exists: %s', + cur_round, + binary_exists, + trial=build_result.trial) # Validate if function-under-test is referenced by the fuzz target. function_referenced = self._validate_fuzz_target_references_function( @@ -164,18 +184,24 @@ def _container_handle_conclusion( build_result: BuildResult) -> Optional[Prompt]: """Runs a compilation tool to validate the new fuzz target and build script from LLM.""" - logger.info('----- ROUND %02d Received conclusion -----', cur_round, trial=build_result.trial) + logger.info('----- ROUND %02d Received conclusion -----', + cur_round, + trial=build_result.trial) self._update_fuzz_target_and_build_script(cur_round, response, build_result) self._validate_fuzz_target_and_build_script(cur_round, build_result) if build_result.success: - logger.info('***** Prototyper succeded in %02d rounds *****', cur_round, trial=build_result.trial) + logger.info('***** Prototyper succeded in %02d rounds *****', + cur_round, + trial=build_result.trial) return None if not build_result.compiles: compile_log = self.llm.truncate_prompt(build_result.compile_log) - logger.info('***** Failed to recompile in %02d rounds *****', cur_round, trial=build_result.trial) + logger.info('***** Failed to recompile in %02d rounds *****', + cur_round, + trial=build_result.trial) prompt_text = ( 'Failed to build fuzz target. Here is the fuzz target, build script, ' 'compliation command, and other compilation runtime output. Analyze ' @@ -205,7 +231,9 @@ def _container_handle_conclusion( elif not build_result.is_function_referenced: logger.info( '***** Fuzz target does not reference function-under-test in %02d ' - 'rounds *****', cur_round, trial=build_result.trial) + 'rounds *****', + cur_round, + trial=build_result.trial) prompt_text = ( 'The fuzz target builds successfully, but the target function ' f'`{build_result.benchmark.function_signature}` was not used by ' @@ -229,8 +257,10 @@ def _container_tool_reaction(self, cur_round: int, response: str, return self._container_handle_conclusion(cur_round, response, build_result) # Other responses are invalid. - logger.warning('ROUND %02d Invalid response from LLM: %s', cur_round, - response, trial=build_result.trial) + logger.warning('ROUND %02d Invalid response from LLM: %s', + cur_round, + response, + trial=build_result.trial) return self._container_handle_invalid_tool_usage(self.inspect_tool) def execute(self, result_history: list[Result]) -> BuildResult: @@ -250,13 +280,17 @@ def execute(self, result_history: list[Result]) -> BuildResult: try: client = self.llm.get_chat_client(model=self.llm.get_model()) while prompt and cur_round < MAX_ROUND: - response = self.chat_llm(cur_round, client=client, prompt=prompt, trial=last_result.trial) + response = self.chat_llm(cur_round, + client=client, + prompt=prompt, + trial=last_result.trial) prompt = self._container_tool_reaction(cur_round, response, build_result) cur_round += 1 finally: # Cleanup: stop and remove the container logger.debug('Stopping and removing the inspect container %s', - self.inspect_tool.container_id, trial=last_result.trial) + self.inspect_tool.container_id, + trial=last_result.trial) self.inspect_tool.terminate() return build_result diff --git a/pipeline.py b/pipeline.py index 2b441c5c67..3c534efb35 100644 --- a/pipeline.py +++ b/pipeline.py @@ -32,10 +32,12 @@ def __init__(self, self.trial = trial self.logger = logger.get_trial_logger(trial=trial) self.logger.debug('Pipeline Initialized') - self.writing_stage: WritingStage = WritingStage(args, trial, writing_stage_agents) + self.writing_stage: WritingStage = WritingStage(args, trial, + writing_stage_agents) self.execution_stage: ExecutionStage = ExecutionStage( args, trial, evaluation_stage_agents) - self.analysis_stage: AnalysisStage = AnalysisStage(args, trial, analysis_stage_agents) + self.analysis_stage: AnalysisStage = AnalysisStage(args, trial, + analysis_stage_agents) def _terminate(self, result_history: list[Result]) -> bool: """Validates if the termination conditions have been satisfied.""" diff --git a/run_one_experiment.py b/run_one_experiment.py index 47326b964e..10cbc59eaa 100644 --- a/run_one_experiment.py +++ b/run_one_experiment.py @@ -325,7 +325,9 @@ def _fuzzing_pipeline(benchmark: Benchmark, model: models.LLM, trial_logger = logger.get_trial_logger(trial=trial, level=logging.DEBUG) trial_logger.info('Trial Starts') p = pipeline.Pipeline( - args=args, trial=trial, writing_stage_agents=[Prototyper(trial=trial, llm=model)]) + args=args, + trial=trial, + writing_stage_agents=[Prototyper(trial=trial, llm=model)]) results = p.execute(result_history=[ Result(benchmark=benchmark, trial=trial, work_dirs=work_dirs) ])