Skip to content

Commit af8b45c

Browse files
author
Alex Wang
committed
feat: Add replay mode detection for default logger
- Add visited-operation tracking (visit_operation) to the logger
- Replace execution_arn with execution_state in LogInfo
- Add an _is_replay check so that the logger does not emit records while in replay mode
1 parent 636d757 commit af8b45c

File tree

6 files changed

+148
-52
lines changed

6 files changed

+148
-52
lines changed

src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ def _execute_item_in_child_context(
381381
executor_context._parent_id, # noqa: SLF001
382382
name,
383383
)
384+
child_context.logger.visit_operation(operation_id=operation_id)
384385

385386
def run_in_child_handler():
386387
return self.execute_item(child_context, executable)

src/aws_durable_execution_sdk_python/context.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ def __init__(
176176
self._step_counter: OrderedCounter = OrderedCounter()
177177

178178
log_info = LogInfo(
179-
execution_arn=state.durable_execution_arn, parent_id=parent_id
179+
execution_state=state,
180+
parent_id=parent_id,
180181
)
181182
self._log_info = log_info
182183
self.logger: Logger = logger or Logger.from_log_info(
@@ -205,7 +206,8 @@ def create_child_context(self, parent_id: str) -> DurableContext:
205206
parent_id=parent_id,
206207
logger=self.logger.with_log_info(
207208
LogInfo(
208-
execution_arn=self.state.durable_execution_arn, parent_id=parent_id
209+
execution_state=self.state,
210+
parent_id=parent_id,
209211
)
210212
),
211213
)
@@ -269,6 +271,7 @@ def create_callback(
269271
if not config:
270272
config = CallbackConfig()
271273
operation_id: str = self._create_step_id()
274+
self.logger.visit_operation(operation_id=operation_id)
272275
callback_id: str = create_callback_handler(
273276
state=self.state,
274277
operation_identifier=OperationIdentifier(
@@ -302,12 +305,14 @@ def invoke(
302305
Returns:
303306
The result of the invoked function
304307
"""
308+
operation_id = self._create_step_id()
309+
self.logger.visit_operation(operation_id=operation_id)
305310
return invoke_handler(
306311
function_name=function_name,
307312
payload=payload,
308313
state=self.state,
309314
operation_identifier=OperationIdentifier(
310-
operation_id=self._create_step_id(),
315+
operation_id=operation_id,
311316
parent_id=self._parent_id,
312317
name=name,
313318
),
@@ -325,6 +330,7 @@ def map(
325330
map_name: str | None = self._resolve_step_name(name, func)
326331

327332
operation_id = self._create_step_id()
333+
self.logger.visit_operation(operation_id=operation_id)
328334
operation_identifier = OperationIdentifier(
329335
operation_id=operation_id, parent_id=self._parent_id, name=map_name
330336
)
@@ -367,6 +373,7 @@ def parallel(
367373
"""Execute multiple callables in parallel."""
368374
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
369375
operation_id = self._create_step_id()
376+
self.logger.visit_operation(operation_id=operation_id)
370377
parallel_context = self.create_child_context(parent_id=operation_id)
371378
operation_identifier = OperationIdentifier(
372379
operation_id=operation_id, parent_id=self._parent_id, name=name
@@ -420,6 +427,7 @@ def run_in_child_context(
420427
step_name: str | None = self._resolve_step_name(name, func)
421428
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
422429
operation_id = self._create_step_id()
430+
self.logger.visit_operation(operation_id=operation_id)
423431

424432
def callable_with_child_context():
425433
return func(self.create_child_context(parent_id=operation_id))
@@ -441,13 +449,15 @@ def step(
441449
) -> T:
442450
step_name = self._resolve_step_name(name, func)
443451
logger.debug("Step name: %s", step_name)
452+
operation_id = self._create_step_id()
453+
self.logger.visit_operation(operation_id=operation_id)
444454

445455
return step_handler(
446456
func=func,
447457
config=config,
448458
state=self.state,
449459
operation_identifier=OperationIdentifier(
450-
operation_id=self._create_step_id(),
460+
operation_id=operation_id,
451461
parent_id=self._parent_id,
452462
name=step_name,
453463
),
@@ -465,11 +475,14 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
465475
if seconds < 1:
466476
msg = "duration must be at least 1 second"
467477
raise ValidationError(msg)
478+
operation_id = self._create_step_id()
479+
self.logger.visit_operation(operation_id=operation_id)
480+
468481
wait_handler(
469482
seconds=seconds,
470483
state=self.state,
471484
operation_identifier=OperationIdentifier(
472-
operation_id=self._create_step_id(),
485+
operation_id=operation_id,
473486
parent_id=self._parent_id,
474487
name=name,
475488
),
@@ -515,12 +528,15 @@ def wait_for_condition(
515528
msg = "`config` is required for wait_for_condition"
516529
raise ValidationError(msg)
517530

531+
operation_id = self._create_step_id()
532+
self.logger.visit_operation(operation_id=operation_id)
533+
518534
return wait_for_condition_handler(
519535
check=check,
520536
config=config,
521537
state=self.state,
522538
operation_identifier=OperationIdentifier(
523-
operation_id=self._create_step_id(),
539+
operation_id=operation_id,
524540
parent_id=self._parent_id,
525541
name=name,
526542
),

src/aws_durable_execution_sdk_python/logger.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,34 @@
55
from dataclasses import dataclass
66
from typing import TYPE_CHECKING
77

8+
from aws_durable_execution_sdk_python.lambda_service import OperationType
89
from aws_durable_execution_sdk_python.types import LoggerInterface
910

1011
if TYPE_CHECKING:
11-
from collections.abc import Mapping, MutableMapping
12+
from collections.abc import Callable, Mapping, MutableMapping
1213

14+
from aws_durable_execution_sdk_python.context import ExecutionState
1315
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
1416

1517

1618
@dataclass(frozen=True)
1719
class LogInfo:
18-
execution_arn: str
20+
execution_state: ExecutionState
1921
parent_id: str | None = None
2022
operation_id: str | None = None
2123
name: str | None = None
2224
attempt: int | None = None
2325

2426
@classmethod
2527
def from_operation_identifier(
26-
cls, execution_arn: str, op_id: OperationIdentifier, attempt: int | None = None
28+
cls,
29+
execution_state: ExecutionState,
30+
op_id: OperationIdentifier,
31+
attempt: int | None = None,
2732
) -> LogInfo:
2833
"""Create new log info from an execution arn, OperationIdentifier and attempt."""
2934
return cls(
30-
execution_arn=execution_arn,
35+
execution_state=execution_state,
3136
parent_id=op_id.parent_id,
3237
operation_id=op_id.operation_id,
3338
name=op_id.name,
@@ -37,7 +42,7 @@ def from_operation_identifier(
3742
def with_parent_id(self, parent_id: str) -> LogInfo:
3843
"""Clone the log info with a new parent id."""
3944
return LogInfo(
40-
execution_arn=self.execution_arn,
45+
execution_state=self.execution_state,
4146
parent_id=parent_id,
4247
operation_id=self.operation_id,
4348
name=self.name,
@@ -47,15 +52,23 @@ def with_parent_id(self, parent_id: str) -> LogInfo:
4752

4853
class Logger(LoggerInterface):
4954
def __init__(
50-
self, logger: LoggerInterface, default_extra: Mapping[str, object]
55+
self,
56+
logger: LoggerInterface,
57+
default_extra: Mapping[str, object],
58+
execution_state: ExecutionState,
59+
visited_operations: set[str] | None = None,
5160
) -> None:
5261
self._logger = logger
5362
self._default_extra = default_extra
63+
self._execution_state = execution_state
64+
self._visited_operations = visited_operations or set()
5465

5566
@classmethod
5667
def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger:
5768
"""Create a new logger with the given LogInfo."""
58-
extra: MutableMapping[str, object] = {"execution_arn": info.execution_arn}
69+
extra: MutableMapping[str, object] = {
70+
"execution_arn": info.execution_state.durable_execution_arn
71+
}
5972
if info.parent_id:
6073
extra["parent_id"] = info.parent_id
6174
if info.name:
@@ -65,7 +78,9 @@ def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger:
6578
extra["attempt"] = info.attempt + 1
6679
if info.operation_id:
6780
extra["operation_id"] = info.operation_id
68-
return cls(logger, extra)
81+
return cls(
82+
logger=logger, default_extra=extra, execution_state=info.execution_state
83+
)
6984

7085
def with_log_info(self, info: LogInfo) -> Logger:
7186
"""Clone the existing logger with new LogInfo."""
@@ -81,29 +96,52 @@ def get_logger(self) -> LoggerInterface:
8196
def debug(
8297
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
8398
) -> None:
84-
merged_extra = {**self._default_extra, **(extra or {})}
85-
self._logger.debug(msg, *args, extra=merged_extra)
99+
self._log(self._logger.debug, msg, *args, extra=extra)
86100

87101
def info(
88102
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
89103
) -> None:
90-
merged_extra = {**self._default_extra, **(extra or {})}
91-
self._logger.info(msg, *args, extra=merged_extra)
104+
self._log(self._logger.info, msg, *args, extra=extra)
92105

93106
def warning(
94107
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
95108
) -> None:
96-
merged_extra = {**self._default_extra, **(extra or {})}
97-
self._logger.warning(msg, *args, extra=merged_extra)
109+
self._log(self._logger.warning, msg, *args, extra=extra)
98110

99111
def error(
100112
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
101113
) -> None:
102-
merged_extra = {**self._default_extra, **(extra or {})}
103-
self._logger.error(msg, *args, extra=merged_extra)
114+
self._log(self._logger.error, msg, *args, extra=extra)
104115

105116
def exception(
106117
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
107118
) -> None:
119+
self._log(self._logger.exception, msg, *args, extra=extra)
120+
121+
def _log(
122+
self,
123+
log_func: Callable,
124+
msg: object,
125+
*args: object,
126+
extra: Mapping[str, object] | None = None,
127+
):
128+
if not self._should_log():
129+
return
108130
merged_extra = {**self._default_extra, **(extra or {})}
109-
self._logger.exception(msg, *args, extra=merged_extra)
131+
log_func(msg, *args, extra=merged_extra)
132+
133+
def visit_operation(self, operation_id: str):
134+
self._visited_operations.add(operation_id)
135+
136+
def _is_replay(self) -> bool:
137+
if not self._execution_state or not self._execution_state.operations:
138+
return False
139+
140+
return any(
141+
operation_id in self._visited_operations
142+
for operation_id, operation in self._execution_state.operations.items()
143+
if operation.operation_type != OperationType.EXECUTION
144+
)
145+
146+
def _should_log(self) -> bool:
147+
return not self._is_replay()

src/aws_durable_execution_sdk_python/operation/step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def step_handler(
128128
step_context = StepContext(
129129
logger=context_logger.with_log_info(
130130
LogInfo.from_operation_identifier(
131-
execution_arn=state.durable_execution_arn,
131+
execution_state=state,
132132
op_id=operation_identifier,
133133
attempt=attempt,
134134
)

src/aws_durable_execution_sdk_python/operation/wait_for_condition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def wait_for_condition_handler(
133133
check_context = WaitForConditionCheckContext(
134134
logger=context_logger.with_log_info(
135135
LogInfo.from_operation_identifier(
136-
execution_arn=state.durable_execution_arn,
136+
execution_state=state,
137137
op_id=operation_identifier,
138138
attempt=attempt,
139139
)

0 commit comments

Comments
 (0)