diff --git a/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/tools.py b/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/tools.py index f24fe572..a876fa53 100644 --- a/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/tools.py +++ b/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/tools.py @@ -20,6 +20,7 @@ from google.adk.tools.tool_context import ToolContext +from ap2.types.mandate import CartMandate from ap2.types.payment_request import PAYMENT_METHOD_DATA_DATA_KEY from common.a2a_message_builder import A2aMessageBuilder from common import artifact_utils @@ -41,7 +42,7 @@ async def get_payment_methods( Returns: A dictionary of the user's applicable payment methods. """ - cart_mandate = tool_context.state["cart_mandate"] + cart_mandate = CartMandate.model_validate(tool_context.state["cart_mandate"]) message_builder = ( A2aMessageBuilder() .set_context_id(tool_context.state["shopping_context_id"]) diff --git a/samples/python/src/roles/shopping_agent/subagents/shopper/tools.py b/samples/python/src/roles/shopping_agent/subagents/shopper/tools.py index cee16eff..5c23ba72 100644 --- a/samples/python/src/roles/shopping_agent/subagents/shopper/tools.py +++ b/samples/python/src/roles/shopping_agent/subagents/shopper/tools.py @@ -65,7 +65,7 @@ def create_intent_mandate( datetime.now(timezone.utc) + timedelta(days=1) ).isoformat(), ) - tool_context.state["intent_mandate"] = intent_mandate + tool_context.state["intent_mandate"] = intent_mandate.model_dump() return intent_mandate @@ -84,9 +84,10 @@ async def find_products( Raises: RuntimeError: If the merchant agent fails to provide products. """ - intent_mandate = tool_context.state["intent_mandate"] - if not intent_mandate: + intent_mandate_data = tool_context.state["intent_mandate"] + if not intent_mandate_data: raise RuntimeError("No IntentMandate found in tool context state.") + intent_mandate = IntentMandate.model_validate(intent_mandate_data) risk_data = _collect_risk_data(tool_context) if not risk_data: raise RuntimeError("No risk data found in tool context state.") @@ -106,7 +107,7 @@ async def find_products( tool_context.state["shopping_context_id"] = task.context_id cart_mandates = _parse_cart_mandates(task.artifacts) - tool_context.state["cart_mandates"] = cart_mandates + tool_context.state["cart_mandates"] = [cm.model_dump() for cm in cart_mandates] return cart_mandates @@ -117,7 +118,10 @@ def update_chosen_cart_mandate(cart_id: str, tool_context: ToolContext) -> str: cart_id: The ID of the chosen cart. tool_context: The ADK supplied tool context. """ - cart_mandates: list[CartMandate] = tool_context.state.get("cart_mandates", []) + cart_mandates = [ + CartMandate.model_validate(cm) + for cm in tool_context.state.get("cart_mandates", []) + ] for cart in cart_mandates: print( f"Checking cart with ID: {cart.contents.id} with chosen ID: {cart_id}" diff --git a/samples/python/src/roles/shopping_agent/tools.py b/samples/python/src/roles/shopping_agent/tools.py index 1d8b7bf1..280bf1e2 100644 --- a/samples/python/src/roles/shopping_agent/tools.py +++ b/samples/python/src/roles/shopping_agent/tools.py @@ -76,8 +76,8 @@ async def update_cart( _parse_cart_mandates(task.artifacts) ) - tool_context.state["cart_mandate"] = updated_cart_mandate - tool_context.state["shipping_address"] = shipping_address + tool_context.state["cart_mandate"] = updated_cart_mandate.model_dump() + tool_context.state["shipping_address"] = shipping_address.model_dump() return updated_cart_mandate @@ -163,7 +163,7 @@ def store_receipt_if_present(task, tool_context: ToolContext) -> None: ) if payment_receipts: payment_receipt = artifact_utils.only(payment_receipts) - tool_context.state["payment_receipt"] = payment_receipt + tool_context.state["payment_receipt"] = payment_receipt.model_dump() def create_payment_mandate( @@ -181,10 +181,12 @@ def create_payment_mandate( Returns: The payment mandate. """ - cart_mandate = tool_context.state["cart_mandate"] + cart_mandate = CartMandate.model_validate(tool_context.state["cart_mandate"]) payment_request = cart_mandate.contents.payment_request - shipping_address = tool_context.state["shipping_address"] + shipping_address = ContactAddress.model_validate( + tool_context.state["shipping_address"] + ) payment_method = os.environ.get("PAYMENT_METHOD", "CARD") if payment_method == "x402": @@ -215,7 +217,7 @@ def create_payment_mandate( ), ) - tool_context.state["payment_mandate"] = payment_mandate + tool_context.state["payment_mandate"] = payment_mandate.model_dump() return payment_mandate @@ -238,8 +240,12 @@ def sign_mandates_on_user_device(tool_context: ToolContext) -> str: Returns: A string representing the simulated user authorization signature (JWT). """ - payment_mandate: PaymentMandate = tool_context.state["payment_mandate"] - cart_mandate: CartMandate = tool_context.state["cart_mandate"] + payment_mandate = PaymentMandate.model_validate( + tool_context.state["payment_mandate"] + ) + cart_mandate = CartMandate.model_validate( + tool_context.state["cart_mandate"] + ) cart_mandate_hash = _generate_cart_mandate_hash(cart_mandate) payment_mandate_hash = _generate_payment_mandate_hash( payment_mandate.payment_mandate_contents @@ -250,7 +256,7 @@ def sign_mandates_on_user_device(tool_context: ToolContext) -> str: payment_mandate.user_authorization = ( cart_mandate_hash + "_" + payment_mandate_hash ) - tool_context.state["signed_payment_mandate"] = payment_mandate + tool_context.state["signed_payment_mandate"] = payment_mandate.model_dump() return payment_mandate.user_authorization