Skip to content

Fix multiple threads sharing one logger #720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,20 @@ def get_tool(self, tool_name: str) -> Optional[BaseTool]:
return tool
return None

def chat_llm(self, cur_round: int, client: Any, prompt: Prompt) -> str:
def chat_llm(self, cur_round: int, client: Any, prompt: Prompt,
trial: int) -> str:
"""Chat with LLM."""
logger.info('<CHAT PROMPT:ROUND %02d>%s</CHAT PROMPT:ROUND %02d>',
cur_round, prompt.get(), cur_round)
cur_round,
prompt.get(),
cur_round,
trial=trial)
response = self.llm.chat_llm(client=client, prompt=prompt)
logger.info('<CHAT RESPONSE:ROUND %02d>%s</CHAT RESPONSE:ROUND %02d>',
cur_round, response, cur_round)
cur_round,
response,
cur_round,
trial=trial)
return response

def _parse_tag(self, response: str, tag: str) -> str:
Expand Down Expand Up @@ -89,11 +96,16 @@ def _container_handle_invalid_tool_usage(self, tool: BaseTool) -> Prompt:
f'interaction protocols:\n{tool.tutorial()}')
return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])

def _sleep_random_duration(self, min_sec: int = 1, max_sec: int = 60) -> None:
def _sleep_random_duration(
self,
trial: int,
min_sec: int = 1,
max_sec: int = 60,
) -> None:
"""Sleeps for a random duration between min_sec and max_sec. Agents uses
this to avoid exceeding quota limit (e.g., LLM query frequency)."""
duration = random.randint(min_sec, max_sec)
logger.debug('Sleeping for %d before the next query', duration)
logger.debug('Sleeping for %d before the next query', duration, trial=trial)
time.sleep(duration)

@classmethod
Expand Down
94 changes: 64 additions & 30 deletions agent/prototyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from agent.base_agent import BaseAgent
from data_prep.project_context.context_introspector import ContextRetriever
from experiment.benchmark import Benchmark
from llm_toolkit.prompt_builder import EXAMPLES as EXAMPLE_FUZZ_TARGETS
from llm_toolkit.prompt_builder import (DefaultTemplateBuilder,
PrototyperTemplateBuilder)
from llm_toolkit.prompts import Prompt
Expand Down Expand Up @@ -48,23 +47,31 @@ def _update_fuzz_target_and_build_script(self, cur_round: int, response: str,
self._parse_tag(response, 'fuzz target'))
build_result.fuzz_target_source = fuzz_target_source
if fuzz_target_source:
logger.debug('ROUND %02d Parsed fuzz target from LLM: %s', cur_round,
fuzz_target_source)
logger.debug('ROUND %02d Parsed fuzz target from LLM: %s',
cur_round,
fuzz_target_source,
trial=build_result.trial)
else:
logger.error('ROUND %02d No fuzz target source code in conclusion: %s',
cur_round, response)
cur_round,
response,
trial=build_result.trial)

build_script_source = self._filter_code(
self._parse_tag(response, 'build script'))
# Sometimes LLM adds chronos, which makes no sense for new build scripts.
build_result.build_script_source = build_script_source.replace(
'source /src/chronos.sh', '')
if build_script_source:
logger.debug('ROUND %02d Parsed build script from LLM: %s', cur_round,
build_script_source)
logger.debug('ROUND %02d Parsed build script from LLM: %s',
cur_round,
build_script_source,
trial=build_result.trial)
else:
logger.debug('ROUND %02d No build script in conclusion: %s', cur_round,
response)
logger.debug('ROUND %02d No build script in conclusion: %s',
cur_round,
response,
trial=build_result.trial)

def _update_build_result(self, build_result: BuildResult,
compile_process: sp.CompletedProcess, status: bool,
Expand All @@ -84,20 +91,22 @@ def _validate_fuzz_target_and_build_script(self, cur_round: int,
# 2. Recompile with the modified build script, if any.
build_script_source = build_result.build_script_source

logger.info('First compile fuzz target without modifying build script.')
logger.info('First compile fuzz target without modifying build script.',
trial=build_result.trial)
build_result.build_script_source = ''
self._validate_fuzz_target_and_build_script_via_compile(
cur_round, build_result)

if not build_result.success and build_script_source:
logger.info('Then compile fuzz target with modified build script.')
logger.info('Then compile fuzz target with modified build script.',
trial=build_result.trial)
build_result.build_script_source = build_script_source
self._validate_fuzz_target_and_build_script_via_compile(
cur_round, build_result)

def _validate_fuzz_target_references_function(
self, compilation_tool: ProjectContainerTool, benchmark: Benchmark,
cur_round: int) -> bool:
cur_round: int, trial: int) -> bool:
"""Validates if the LLM generated fuzz target assembly code references
function-under-test."""
disassemble_result = compilation_tool.execute(
Expand All @@ -106,10 +115,13 @@ def _validate_fuzz_target_references_function(
function_referenced = (disassemble_result.returncode == 0 and
benchmark.function_name in disassemble_result.stdout)
logger.debug('ROUND %02d Final fuzz target function referenced: %s',
cur_round, function_referenced)
cur_round,
function_referenced,
trial=trial)
if not function_referenced:
logger.debug('ROUND %02d Final fuzz target function not referenced',
cur_round)
cur_round,
trial=trial)
return function_referenced

def _validate_fuzz_target_and_build_script_via_compile(
Expand All @@ -133,25 +145,33 @@ def _validate_fuzz_target_and_build_script_via_compile(
file_content=build_result.build_script_source))

# Recompile.
logger.info('===== ROUND %02d Recompile =====', cur_round)
logger.info('===== ROUND %02d Recompile =====',
cur_round,
trial=build_result.trial)
start_time = time.time()
compile_process = compilation_tool.compile()
end_time = time.time()
logger.debug('ROUND %02d compilation time: %s', cur_round,
timedelta(seconds=end_time - start_time))
logger.debug('ROUND %02d compilation time: %s',
cur_round,
timedelta(seconds=end_time - start_time),
trial=build_result.trial)
compile_succeed = compile_process.returncode == 0
logger.debug('ROUND %02d Fuzz target compiles: %s', cur_round,
compile_succeed)
logger.debug('ROUND %02d Fuzz target compiles: %s',
cur_round,
compile_succeed,
trial=build_result.trial)

# Double-check binary.
ls_result = compilation_tool.execute(f'ls /out/{benchmark.target_name}')
binary_exists = ls_result.returncode == 0
logger.debug('ROUND %02d Final fuzz target binary exists: %s', cur_round,
binary_exists)
logger.debug('ROUND %02d Final fuzz target binary exists: %s',
cur_round,
binary_exists,
trial=build_result.trial)

# Validate if function-under-test is referenced by the fuzz target.
function_referenced = self._validate_fuzz_target_references_function(
compilation_tool, benchmark, cur_round)
compilation_tool, benchmark, cur_round, build_result.trial)

compilation_tool.terminate()
self._update_build_result(build_result,
Expand All @@ -164,18 +184,24 @@ def _container_handle_conclusion(
build_result: BuildResult) -> Optional[Prompt]:
"""Runs a compilation tool to validate the new fuzz target and build script
from LLM."""
logger.info('----- ROUND %02d Received conclusion -----', cur_round)
logger.info('----- ROUND %02d Received conclusion -----',
cur_round,
trial=build_result.trial)

self._update_fuzz_target_and_build_script(cur_round, response, build_result)

self._validate_fuzz_target_and_build_script(cur_round, build_result)
if build_result.success:
logger.info('***** Prototyper succeded in %02d rounds *****', cur_round)
logger.info('***** Prototyper succeded in %02d rounds *****',
cur_round,
trial=build_result.trial)
return None

if not build_result.compiles:
compile_log = self.llm.truncate_prompt(build_result.compile_log)
logger.info('***** Failed to recompile in %02d rounds *****', cur_round)
logger.info('***** Failed to recompile in %02d rounds *****',
cur_round,
trial=build_result.trial)
prompt_text = (
'Failed to build fuzz target. Here is the fuzz target, build script, '
'compliation command, and other compilation runtime output. Analyze '
Expand Down Expand Up @@ -205,7 +231,9 @@ def _container_handle_conclusion(
elif not build_result.is_function_referenced:
logger.info(
'***** Fuzz target does not reference function-under-test in %02d '
'rounds *****', cur_round)
'rounds *****',
cur_round,
trial=build_result.trial)
prompt_text = (
'The fuzz target builds successfully, but the target function '
f'`{build_result.benchmark.function_signature}` was not used by '
Expand All @@ -229,14 +257,16 @@ def _container_tool_reaction(self, cur_round: int, response: str,
return self._container_handle_conclusion(cur_round, response,
build_result)
# Other responses are invalid.
logger.warning('ROUND %02d Invalid response from LLM: %s', cur_round,
response)
logger.warning('ROUND %02d Invalid response from LLM: %s',
cur_round,
response,
trial=build_result.trial)
return self._container_handle_invalid_tool_usage(self.inspect_tool)

def execute(self, result_history: list[Result]) -> BuildResult:
"""Executes the agent based on previous result."""
logger.info('Executing Prototyper')
last_result = result_history[-1]
logger.info('Executing Prototyper', trial=last_result.trial)
benchmark = last_result.benchmark
self.inspect_tool = ProjectContainerTool(benchmark, name='inspect')
self.inspect_tool.compile(extra_commands=' && rm -rf /out/* > /dev/null')
Expand All @@ -250,13 +280,17 @@ def execute(self, result_history: list[Result]) -> BuildResult:
try:
client = self.llm.get_chat_client(model=self.llm.get_model())
while prompt and cur_round < MAX_ROUND:
response = self.chat_llm(cur_round, client=client, prompt=prompt)
response = self.chat_llm(cur_round,
client=client,
prompt=prompt,
trial=last_result.trial)
prompt = self._container_tool_reaction(cur_round, response,
build_result)
cur_round += 1
finally:
# Cleanup: stop and remove the container
logger.debug('Stopping and removing the inspect container %s',
self.inspect_tool.container_id)
self.inspect_tool.container_id,
trial=last_result.trial)
self.inspect_tool.terminate()
return build_result
71 changes: 34 additions & 37 deletions logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

FINAL_RESULT_JSON = 'result.json'

_trial_logger = None


class CustomLoggerAdapter(logging.LoggerAdapter):
"""A note-taker to log and record experiment status, key info, and final
Expand Down Expand Up @@ -61,76 +59,76 @@ def write_chat_history(self, result: Result) -> None:

def debug(msg: object,
*args: object,
trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
return get_trial_logger().debug(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)
return get_trial_logger(trial=trial).debug(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)


def info(msg: object,
*args: object,
trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
return get_trial_logger().info(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)
return get_trial_logger(trial=trial).info(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)


def warning(msg: object,
*args: object,
trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
return get_trial_logger().warning(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)
return get_trial_logger(trial=trial).warning(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)


def error(msg: object,
*args: object,
trial: int,
exc_info=None,
stack_info: bool = False,
stacklevel: int = 1,
extra: Mapping[str, object] | None = None,
**kwargs: object) -> None:
return get_trial_logger().error(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)
return get_trial_logger(trial=trial).error(msg,
*args,
exc_info=exc_info,
stack_info=stack_info,
stacklevel=stacklevel,
extra=extra,
**kwargs)


def get_trial_logger(name: str = __name__,
trial: int = 0,
level=logging.DEBUG) -> CustomLoggerAdapter:
"""Sets up or retrieves the singleton instance of CustomLoggerAdapter."""
global _trial_logger
if _trial_logger:
return _trial_logger

"""Sets up or retrieves a thread-local CustomLoggerAdapter for each thread."""
logger = logging.getLogger(name)
if not logger.handlers:
formatter = logging.Formatter(
Expand All @@ -143,5 +141,4 @@ def get_trial_logger(name: str = __name__,
logger.setLevel(level)
logger.propagate = False

_trial_logger = CustomLoggerAdapter(logger, {'trial': trial})
return _trial_logger
return CustomLoggerAdapter(logger, {'trial': trial})
Loading