33from typing import Any , cast
44from unittest .mock import AsyncMock , patch
55
6+ import pytest
67from psycopg import AsyncConnection
78from psycopg .rows import TupleRow
89from psycopg_pool import AsyncConnectionPool
@@ -23,7 +24,7 @@ def __init__(self, return_value):
2324 async def __aenter__ (self ):
2425 return self .return_value
2526
26- async def __aexit__ (self , exc_type , exc_val , exc_tb ):
27+ async def __aexit__ (self , _exc_type , _exc_val , _exc_tb ):
2728 return None
2829
2930
@@ -54,7 +55,7 @@ def mock_connection() -> AsyncContextManagerMock:
5455
5556 self .mock_pool .connection = mock_connection
5657
57- def mock_cursor_method (* args : Any , ** kwargs : Any ) -> AsyncContextManagerMock :
58+ def mock_cursor_method (* _args : Any , ** _kwargs : Any ) -> AsyncContextManagerMock :
5859 return AsyncContextManagerMock (mock_cursor )
5960
6061 mock_conn .cursor = mock_cursor_method
@@ -353,3 +354,206 @@ async def test_close(self):
353354
354355 self .mock_pool .close .assert_called_once ()
355356 self .assertFalse (self .session ._initialized )
357+
358+ @patch ("agents.extensions.memory.postgres_session.AsyncConnectionPool" )
359+ async def test_from_connection_string_success (self , mock_pool_class ):
360+ """Test creating a session from connection string."""
361+ mock_pool = AsyncMock ()
362+ mock_pool_class .return_value = mock_pool
363+
364+ connection_string = "postgresql://user:pass@host/db"
365+ session_id = "test_session_123"
366+
367+ session = await PostgreSQLSession .from_connection_string (session_id , connection_string )
368+
369+ # Verify pool was created with the connection string
370+ mock_pool_class .assert_called_once_with (connection_string )
371+ mock_pool .open .assert_called_once ()
372+
373+ # Verify session was created with correct parameters
374+ self .assertEqual (session .session_id , session_id )
375+ self .assertEqual (session .pool , mock_pool )
376+ self .assertEqual (session .sessions_table , "agent_sessions" )
377+ self .assertEqual (session .messages_table , "agent_messages" )
378+
379+ @patch ("agents.extensions.memory.postgres_session.AsyncConnectionPool" )
380+ async def test_from_connection_string_custom_tables (self , mock_pool_class ):
381+ """Test creating a session from connection string with custom table names."""
382+ mock_pool = AsyncMock ()
383+ mock_pool_class .return_value = mock_pool
384+
385+ connection_string = "postgresql://user:pass@host/db"
386+ session_id = "test_session_123"
387+ custom_sessions_table = "custom_sessions"
388+ custom_messages_table = "custom_messages"
389+
390+ session = await PostgreSQLSession .from_connection_string (
391+ session_id ,
392+ connection_string ,
393+ sessions_table = custom_sessions_table ,
394+ messages_table = custom_messages_table ,
395+ )
396+
397+ # Verify pool was created with the connection string
398+ mock_pool_class .assert_called_once_with (connection_string )
399+ mock_pool .open .assert_called_once ()
400+
401+ # Verify session was created with correct parameters
402+ self .assertEqual (session .session_id , session_id )
403+ self .assertEqual (session .pool , mock_pool )
404+ self .assertEqual (session .sessions_table , custom_sessions_table )
405+ self .assertEqual (session .messages_table , custom_messages_table )
406+
407+
408+ @pytest .mark .skip (reason = "Integration tests require a running PostgreSQL instance" )
409+ class TestPostgreSQLSessionIntegration (unittest .IsolatedAsyncioTestCase ):
410+ """Integration tests for PostgreSQL session that require a running database."""
411+
412+ # Test connection string - modify as needed for your test database
413+ TEST_CONNECTION_STRING = "postgresql://postgres:password@localhost:5432/test_db"
414+
415+ async def asyncSetUp (self ):
416+ """Set up test session."""
417+ self .session_id = "test_integration_session"
418+ self .session = await PostgreSQLSession .from_connection_string (
419+ self .session_id ,
420+ self .TEST_CONNECTION_STRING ,
421+ sessions_table = "test_sessions" ,
422+ messages_table = "test_messages" ,
423+ )
424+
425+ # Clean up any existing test data
426+ await self .session .clear_session ()
427+
428+ async def asyncTearDown (self ):
429+ """Clean up after tests."""
430+ if hasattr (self , "session" ):
431+ await self .session .clear_session ()
432+ await self .session .close ()
433+
434+ async def test_integration_full_workflow (self ):
435+ """Test complete workflow: add items, get items, pop item, clear session."""
436+ # Initially empty
437+ items = await self .session .get_items ()
438+ self .assertEqual (len (items ), 0 )
439+
440+ # Add some test items
441+ test_items = cast (
442+ list [TResponseInputItem ],
443+ [
444+ {"role" : "user" , "content" : "Hello" , "type" : "message" },
445+ {"role" : "assistant" , "content" : "Hi there!" , "type" : "message" },
446+ {"role" : "user" , "content" : "How are you?" , "type" : "message" },
447+ {"role" : "assistant" , "content" : "I'm doing well, thank you!" , "type" : "message" },
448+ ],
449+ )
450+
451+ for item in test_items :
452+ await self .session .add_items ([item ])
453+
454+ # Verify items were added
455+ stored_items = await self .session .get_items ()
456+ self .assertEqual (len (stored_items ), 4 )
457+ self .assertEqual (stored_items [0 ], test_items [0 ])
458+ self .assertEqual (stored_items [- 1 ], test_items [- 1 ])
459+
460+ # Test with limit
461+ limited_items = await self .session .get_items (limit = 2 )
462+ self .assertEqual (len (limited_items ), 2 )
463+ # Should get the last 2 items in chronological order
464+ self .assertEqual (limited_items [0 ], test_items [2 ])
465+ self .assertEqual (limited_items [1 ], test_items [3 ])
466+
467+ # Test pop_item
468+ popped_item = await self .session .pop_item ()
469+ self .assertEqual (popped_item , test_items [3 ]) # Last item
470+
471+ # Verify item was removed
472+ remaining_items = await self .session .get_items ()
473+ self .assertEqual (len (remaining_items ), 3 )
474+ self .assertEqual (remaining_items [- 1 ], test_items [2 ])
475+
476+ # Test clear_session
477+ await self .session .clear_session ()
478+ final_items = await self .session .get_items ()
479+ self .assertEqual (len (final_items ), 0 )
480+
481+ async def test_integration_multiple_sessions (self ):
482+ """Test that different sessions maintain separate data."""
483+ # Create a second session
484+ session2 = await PostgreSQLSession .from_connection_string (
485+ "test_integration_session_2" ,
486+ self .TEST_CONNECTION_STRING ,
487+ sessions_table = "test_sessions" ,
488+ messages_table = "test_messages" ,
489+ )
490+
491+ try :
492+ # Add different items to each session
493+ items1 = cast (
494+ list [TResponseInputItem ],
495+ [{"role" : "user" , "content" : "Session 1 message" , "type" : "message" }],
496+ )
497+ items2 = cast (
498+ list [TResponseInputItem ],
499+ [{"role" : "user" , "content" : "Session 2 message" , "type" : "message" }],
500+ )
501+
502+ await self .session .add_items (items1 )
503+ await session2 .add_items (items2 )
504+
505+ # Verify sessions have different data
506+ session1_items = await self .session .get_items ()
507+ session2_items = await session2 .get_items ()
508+
509+ self .assertEqual (len (session1_items ), 1 )
510+ self .assertEqual (len (session2_items ), 1 )
511+ self .assertEqual (session1_items [0 ]["content" ], "Session 1 message" ) # type: ignore
512+ self .assertEqual (session2_items [0 ]["content" ], "Session 2 message" ) # type: ignore
513+
514+ finally :
515+ await session2 .clear_session ()
516+ await session2 .close ()
517+
518+ async def test_integration_empty_session_operations (self ):
519+ """Test operations on empty session."""
520+ # Pop from empty session
521+ popped = await self .session .pop_item ()
522+ self .assertIsNone (popped )
523+
524+ # Get items from empty session
525+ items = await self .session .get_items ()
526+ self .assertEqual (len (items ), 0 )
527+
528+ # Get items with limit from empty session
529+ limited_items = await self .session .get_items (limit = 5 )
530+ self .assertEqual (len (limited_items ), 0 )
531+
532+ # Clear empty session (should not error)
533+ await self .session .clear_session ()
534+
535+ async def test_integration_connection_string_with_custom_tables (self ):
536+ """Test creating session with custom table names."""
537+ custom_session = await PostgreSQLSession .from_connection_string (
538+ "custom_table_test" ,
539+ self .TEST_CONNECTION_STRING ,
540+ sessions_table = "custom_sessions_table" ,
541+ messages_table = "custom_messages_table" ,
542+ )
543+
544+ try :
545+ # Test basic functionality with custom tables
546+ test_items = cast (
547+ list [TResponseInputItem ],
548+ [{"role" : "user" , "content" : "Custom table test" , "type" : "message" }],
549+ )
550+
551+ await custom_session .add_items (test_items )
552+ stored_items = await custom_session .get_items ()
553+
554+ self .assertEqual (len (stored_items ), 1 )
555+ self .assertEqual (stored_items [0 ]["content" ], "Custom table test" ) # type: ignore
556+
557+ finally :
558+ await custom_session .clear_session ()
559+ await custom_session .close ()
0 commit comments