diff --git a/agent/base_agent.py b/agent/base_agent.py
index 070ece578d..2867c099f4 100644
--- a/agent/base_agent.py
+++ b/agent/base_agent.py
@@ -42,13 +42,20 @@ def get_tool(self, tool_name: str) -> Optional[BaseTool]:
return tool
return None
- def chat_llm(self, cur_round: int, client: Any, prompt: Prompt) -> str:
+ def chat_llm(self, cur_round: int, client: Any, prompt: Prompt,
+ trial: int) -> str:
"""Chat with LLM."""
logger.info('%s',
- cur_round, prompt.get(), cur_round)
+ cur_round,
+ prompt.get(),
+ cur_round,
+ trial=trial)
response = self.llm.chat_llm(client=client, prompt=prompt)
logger.info('%s',
- cur_round, response, cur_round)
+ cur_round,
+ response,
+ cur_round,
+ trial=trial)
return response
def _parse_tag(self, response: str, tag: str) -> str:
@@ -89,11 +96,16 @@ def _container_handle_invalid_tool_usage(self, tool: BaseTool) -> Prompt:
f'interaction protocols:\n{tool.tutorial()}')
return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])
- def _sleep_random_duration(self, min_sec: int = 1, max_sec: int = 60) -> None:
+ def _sleep_random_duration(
+ self,
+ trial: int,
+ min_sec: int = 1,
+ max_sec: int = 60,
+ ) -> None:
"""Sleeps for a random duration between min_sec and max_sec. Agents uses
this to avoid exceeding quota limit (e.g., LLM query frequency)."""
duration = random.randint(min_sec, max_sec)
- logger.debug('Sleeping for %d before the next query', duration)
+ logger.debug('Sleeping for %d before the next query', duration, trial=trial)
time.sleep(duration)
@classmethod
diff --git a/agent/prototyper.py b/agent/prototyper.py
index eb852f6050..cb0c34b446 100644
--- a/agent/prototyper.py
+++ b/agent/prototyper.py
@@ -10,7 +10,6 @@
from agent.base_agent import BaseAgent
from data_prep.project_context.context_introspector import ContextRetriever
from experiment.benchmark import Benchmark
-from llm_toolkit.prompt_builder import EXAMPLES as EXAMPLE_FUZZ_TARGETS
from llm_toolkit.prompt_builder import (DefaultTemplateBuilder,
PrototyperTemplateBuilder)
from llm_toolkit.prompts import Prompt
@@ -48,11 +47,15 @@ def _update_fuzz_target_and_build_script(self, cur_round: int, response: str,
self._parse_tag(response, 'fuzz target'))
build_result.fuzz_target_source = fuzz_target_source
if fuzz_target_source:
- logger.debug('ROUND %02d Parsed fuzz target from LLM: %s', cur_round,
- fuzz_target_source)
+ logger.debug('ROUND %02d Parsed fuzz target from LLM: %s',
+ cur_round,
+ fuzz_target_source,
+ trial=build_result.trial)
else:
logger.error('ROUND %02d No fuzz target source code in conclusion: %s',
- cur_round, response)
+ cur_round,
+ response,
+ trial=build_result.trial)
build_script_source = self._filter_code(
self._parse_tag(response, 'build script'))
@@ -60,11 +63,15 @@ def _update_fuzz_target_and_build_script(self, cur_round: int, response: str,
build_result.build_script_source = build_script_source.replace(
'source /src/chronos.sh', '')
if build_script_source:
- logger.debug('ROUND %02d Parsed build script from LLM: %s', cur_round,
- build_script_source)
+ logger.debug('ROUND %02d Parsed build script from LLM: %s',
+ cur_round,
+ build_script_source,
+ trial=build_result.trial)
else:
- logger.debug('ROUND %02d No build script in conclusion: %s', cur_round,
- response)
+ logger.debug('ROUND %02d No build script in conclusion: %s',
+ cur_round,
+ response,
+ trial=build_result.trial)
def _update_build_result(self, build_result: BuildResult,
compile_process: sp.CompletedProcess, status: bool,
@@ -84,20 +91,22 @@ def _validate_fuzz_target_and_build_script(self, cur_round: int,
# 2. Recompile with the modified build script, if any.
build_script_source = build_result.build_script_source
- logger.info('First compile fuzz target without modifying build script.')
+ logger.info('First compile fuzz target without modifying build script.',
+ trial=build_result.trial)
build_result.build_script_source = ''
self._validate_fuzz_target_and_build_script_via_compile(
cur_round, build_result)
if not build_result.success and build_script_source:
- logger.info('Then compile fuzz target with modified build script.')
+ logger.info('Then compile fuzz target with modified build script.',
+ trial=build_result.trial)
build_result.build_script_source = build_script_source
self._validate_fuzz_target_and_build_script_via_compile(
cur_round, build_result)
def _validate_fuzz_target_references_function(
self, compilation_tool: ProjectContainerTool, benchmark: Benchmark,
- cur_round: int) -> bool:
+ cur_round: int, trial: int) -> bool:
"""Validates if the LLM generated fuzz target assembly code references
function-under-test."""
disassemble_result = compilation_tool.execute(
@@ -106,10 +115,13 @@ def _validate_fuzz_target_references_function(
function_referenced = (disassemble_result.returncode == 0 and
benchmark.function_name in disassemble_result.stdout)
logger.debug('ROUND %02d Final fuzz target function referenced: %s',
- cur_round, function_referenced)
+ cur_round,
+ function_referenced,
+ trial=trial)
if not function_referenced:
logger.debug('ROUND %02d Final fuzz target function not referenced',
- cur_round)
+ cur_round,
+ trial=trial)
return function_referenced
def _validate_fuzz_target_and_build_script_via_compile(
@@ -133,25 +145,33 @@ def _validate_fuzz_target_and_build_script_via_compile(
file_content=build_result.build_script_source))
# Recompile.
- logger.info('===== ROUND %02d Recompile =====', cur_round)
+ logger.info('===== ROUND %02d Recompile =====',
+ cur_round,
+ trial=build_result.trial)
start_time = time.time()
compile_process = compilation_tool.compile()
end_time = time.time()
- logger.debug('ROUND %02d compilation time: %s', cur_round,
- timedelta(seconds=end_time - start_time))
+ logger.debug('ROUND %02d compilation time: %s',
+ cur_round,
+ timedelta(seconds=end_time - start_time),
+ trial=build_result.trial)
compile_succeed = compile_process.returncode == 0
- logger.debug('ROUND %02d Fuzz target compiles: %s', cur_round,
- compile_succeed)
+ logger.debug('ROUND %02d Fuzz target compiles: %s',
+ cur_round,
+ compile_succeed,
+ trial=build_result.trial)
# Double-check binary.
ls_result = compilation_tool.execute(f'ls /out/{benchmark.target_name}')
binary_exists = ls_result.returncode == 0
- logger.debug('ROUND %02d Final fuzz target binary exists: %s', cur_round,
- binary_exists)
+ logger.debug('ROUND %02d Final fuzz target binary exists: %s',
+ cur_round,
+ binary_exists,
+ trial=build_result.trial)
# Validate if function-under-test is referenced by the fuzz target.
function_referenced = self._validate_fuzz_target_references_function(
- compilation_tool, benchmark, cur_round)
+ compilation_tool, benchmark, cur_round, build_result.trial)
compilation_tool.terminate()
self._update_build_result(build_result,
@@ -164,18 +184,24 @@ def _container_handle_conclusion(
build_result: BuildResult) -> Optional[Prompt]:
"""Runs a compilation tool to validate the new fuzz target and build script
from LLM."""
- logger.info('----- ROUND %02d Received conclusion -----', cur_round)
+ logger.info('----- ROUND %02d Received conclusion -----',
+ cur_round,
+ trial=build_result.trial)
self._update_fuzz_target_and_build_script(cur_round, response, build_result)
self._validate_fuzz_target_and_build_script(cur_round, build_result)
if build_result.success:
- logger.info('***** Prototyper succeded in %02d rounds *****', cur_round)
+ logger.info('***** Prototyper succeded in %02d rounds *****',
+ cur_round,
+ trial=build_result.trial)
return None
if not build_result.compiles:
compile_log = self.llm.truncate_prompt(build_result.compile_log)
- logger.info('***** Failed to recompile in %02d rounds *****', cur_round)
+ logger.info('***** Failed to recompile in %02d rounds *****',
+ cur_round,
+ trial=build_result.trial)
prompt_text = (
'Failed to build fuzz target. Here is the fuzz target, build script, '
'compliation command, and other compilation runtime output. Analyze '
@@ -205,7 +231,9 @@ def _container_handle_conclusion(
elif not build_result.is_function_referenced:
logger.info(
'***** Fuzz target does not reference function-under-test in %02d '
- 'rounds *****', cur_round)
+ 'rounds *****',
+ cur_round,
+ trial=build_result.trial)
prompt_text = (
'The fuzz target builds successfully, but the target function '
f'`{build_result.benchmark.function_signature}` was not used by '
@@ -229,14 +257,16 @@ def _container_tool_reaction(self, cur_round: int, response: str,
return self._container_handle_conclusion(cur_round, response,
build_result)
# Other responses are invalid.
- logger.warning('ROUND %02d Invalid response from LLM: %s', cur_round,
- response)
+ logger.warning('ROUND %02d Invalid response from LLM: %s',
+ cur_round,
+ response,
+ trial=build_result.trial)
return self._container_handle_invalid_tool_usage(self.inspect_tool)
def execute(self, result_history: list[Result]) -> BuildResult:
"""Executes the agent based on previous result."""
- logger.info('Executing Prototyper')
last_result = result_history[-1]
+ logger.info('Executing Prototyper', trial=last_result.trial)
benchmark = last_result.benchmark
self.inspect_tool = ProjectContainerTool(benchmark, name='inspect')
self.inspect_tool.compile(extra_commands=' && rm -rf /out/* > /dev/null')
@@ -250,13 +280,17 @@ def execute(self, result_history: list[Result]) -> BuildResult:
try:
client = self.llm.get_chat_client(model=self.llm.get_model())
while prompt and cur_round < MAX_ROUND:
- response = self.chat_llm(cur_round, client=client, prompt=prompt)
+ response = self.chat_llm(cur_round,
+ client=client,
+ prompt=prompt,
+ trial=last_result.trial)
prompt = self._container_tool_reaction(cur_round, response,
build_result)
cur_round += 1
finally:
# Cleanup: stop and remove the container
logger.debug('Stopping and removing the inspect container %s',
- self.inspect_tool.container_id)
+ self.inspect_tool.container_id,
+ trial=last_result.trial)
self.inspect_tool.terminate()
return build_result
diff --git a/logger.py b/logger.py
index ea747da865..a660ec7a05 100644
--- a/logger.py
+++ b/logger.py
@@ -10,8 +10,6 @@
FINAL_RESULT_JSON = 'result.json'
-_trial_logger = None
-
class CustomLoggerAdapter(logging.LoggerAdapter):
"""A note-taker to log and record experiment status, key info, and final
@@ -61,76 +59,76 @@ def write_chat_history(self, result: Result) -> None:
def debug(msg: object,
*args: object,
+ trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
- return get_trial_logger().debug(msg,
- *args,
- exc_info=exc_info,
- stack_info=stack_info,
- stacklevel=stacklevel,
- extra=extra,
- **kwargs)
+ return get_trial_logger(trial=trial).debug(msg,
+ *args,
+ exc_info=exc_info,
+ stack_info=stack_info,
+ stacklevel=stacklevel,
+ extra=extra,
+ **kwargs)
def info(msg: object,
*args: object,
+ trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
- return get_trial_logger().info(msg,
- *args,
- exc_info=exc_info,
- stack_info=stack_info,
- stacklevel=stacklevel,
- extra=extra,
- **kwargs)
+ return get_trial_logger(trial=trial).info(msg,
+ *args,
+ exc_info=exc_info,
+ stack_info=stack_info,
+ stacklevel=stacklevel,
+ extra=extra,
+ **kwargs)
def warning(msg: object,
*args: object,
+ trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
- return get_trial_logger().warning(msg,
- *args,
- exc_info=exc_info,
- stack_info=stack_info,
- stacklevel=stacklevel,
- extra=extra,
- **kwargs)
+ return get_trial_logger(trial=trial).warning(msg,
+ *args,
+ exc_info=exc_info,
+ stack_info=stack_info,
+ stacklevel=stacklevel,
+ extra=extra,
+ **kwargs)
def error(msg: object,
*args: object,
+ trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
- return get_trial_logger().error(msg,
- *args,
- exc_info=exc_info,
- stack_info=stack_info,
- stacklevel=stacklevel,
- extra=extra,
- **kwargs)
+ return get_trial_logger(trial=trial).error(msg,
+ *args,
+ exc_info=exc_info,
+ stack_info=stack_info,
+ stacklevel=stacklevel,
+ extra=extra,
+ **kwargs)
def get_trial_logger(name: str = __name__,
trial: int = 0,
level=logging.DEBUG) -> CustomLoggerAdapter:
- """Sets up or retrieves the singleton instance of CustomLoggerAdapter."""
- global _trial_logger
- if _trial_logger:
- return _trial_logger
-
+ """Returns a new CustomLoggerAdapter for |trial|, wrapping logger |name|."""
logger = logging.getLogger(name)
if not logger.handlers:
formatter = logging.Formatter(
@@ -143,5 +141,4 @@ def get_trial_logger(name: str = __name__,
logger.setLevel(level)
logger.propagate = False
- _trial_logger = CustomLoggerAdapter(logger, {'trial': trial})
- return _trial_logger
+ return CustomLoggerAdapter(logger, {'trial': trial})
diff --git a/pipeline.py b/pipeline.py
index 10b083baab..3c534efb35 100644
--- a/pipeline.py
+++ b/pipeline.py
@@ -24,16 +24,19 @@ class Pipeline():
def __init__(self,
args: argparse.Namespace,
+ trial: int,
writing_stage_agents: Optional[list[BaseAgent]] = None,
evaluation_stage_agents: Optional[list[BaseAgent]] = None,
analysis_stage_agents: Optional[list[BaseAgent]] = None):
self.args = args
- self.logger = logger.get_trial_logger()
+ self.trial = trial
+ self.logger = logger.get_trial_logger(trial=trial)
self.logger.debug('Pipeline Initialized')
- self.writing_stage: WritingStage = WritingStage(args, writing_stage_agents)
+ self.writing_stage: WritingStage = WritingStage(args, trial,
+ writing_stage_agents)
self.execution_stage: ExecutionStage = ExecutionStage(
- args, evaluation_stage_agents)
- self.analysis_stage: AnalysisStage = AnalysisStage(args,
+ args, trial, evaluation_stage_agents)
+ self.analysis_stage: AnalysisStage = AnalysisStage(args, trial,
analysis_stage_agents)
def _terminate(self, result_history: list[Result]) -> bool:
diff --git a/run_one_experiment.py b/run_one_experiment.py
index 0f91daf746..10cbc59eaa 100644
--- a/run_one_experiment.py
+++ b/run_one_experiment.py
@@ -325,7 +325,9 @@ def _fuzzing_pipeline(benchmark: Benchmark, model: models.LLM,
trial_logger = logger.get_trial_logger(trial=trial, level=logging.DEBUG)
trial_logger.info('Trial Starts')
p = pipeline.Pipeline(
- args=args, writing_stage_agents=[Prototyper(trial=trial, llm=model)])
+ args=args,
+ trial=trial,
+ writing_stage_agents=[Prototyper(trial=trial, llm=model)])
results = p.execute(result_history=[
Result(benchmark=benchmark, trial=trial, work_dirs=work_dirs)
])
diff --git a/stage/base_stage.py b/stage/base_stage.py
index 1f8f4dfdbd..dd8421ebe4 100644
--- a/stage/base_stage.py
+++ b/stage/base_stage.py
@@ -14,10 +14,12 @@ class BaseStage(ABC):
def __init__(self,
args: argparse.Namespace,
+ trial: int,
agents: Optional[list[BaseAgent]] = None) -> None:
self.args = args
+ self.trial = trial
self.agents: list[BaseAgent] = agents or []
- self.logger = logger.get_trial_logger()
+ self.logger = logger.get_trial_logger(trial=trial)
def __repr__(self) -> str:
return self.__class__.__name__