diff --git a/getstream/video/rtc/__init__.py b/getstream/video/rtc/__init__.py index e0fe36b4..42d1746a 100644 --- a/getstream/video/rtc/__init__.py +++ b/getstream/video/rtc/__init__.py @@ -47,7 +47,7 @@ async def discover_location(): async def join( - call: Call, user_id: Optional[str] = None, create=True, **kwargs + call: Call, user_id: Optional[str] = None, create=True, fast_join=False, **kwargs ) -> ConnectionManager: """ Join a call. This method will: @@ -60,6 +60,7 @@ async def join( call: The call to join user_id: The user id to join with create: Whether to create the call if it doesn't exist + fast_join: Whether to use fast join (default: False) **kwargs: Additional arguments to pass to the join call request Returns: @@ -67,7 +68,7 @@ async def join( """ # Return ConnectionManager instance that handles everything internally # when used as an async context manager and async iterator - return ConnectionManager(call=call, user_id=user_id, create=create, **kwargs) + return ConnectionManager(call=call, user_id=user_id, create=create, fast_join=fast_join, **kwargs) __all__ = [ diff --git a/getstream/video/rtc/connection_manager.py b/getstream/video/rtc/connection_manager.py index 26e6d0de..477d7a0f 100644 --- a/getstream/video/rtc/connection_manager.py +++ b/getstream/video/rtc/connection_manager.py @@ -8,6 +8,7 @@ import aiortc from getstream.common import telemetry +from getstream.stream_response import StreamResponse from getstream.utils import StreamAsyncIOEventEmitter from getstream.video.rtc.coordinator.ws import StreamAPIWS from getstream.video.rtc.pb.stream.video.sfu.event import events_pb2 @@ -22,6 +23,7 @@ ConnectionOptions, connect_websocket, join_call, + fast_join_call, watch_call, ) from getstream.video.rtc.track_util import ( @@ -53,6 +55,7 @@ def __init__( user_id: Optional[str] = None, create: bool = True, subscription_config: Optional[SubscriptionConfig] = None, + fast_join: bool = False, **kwargs: Any, ): super().__init__() @@ -61,6 +64,7 @@ def __init__( self.call: Call = call self.user_id: Optional[str] = user_id self.create: bool = create + self.fast_join: bool = fast_join self.kwargs: Dict[str, Any] = kwargs self.running: bool = False self.session_id: str = str(uuid.uuid4()) @@ -269,21 +273,38 @@ async def _connect_internal( "coordinator-join-call", ) as span: if not (ws_url or token): - join_response = await join_call( - self.call, - self.user_id, - "auto", - self.create, - self.local_sfu, - **self.kwargs, - ) - ws_url = join_response.data.credentials.server.ws_endpoint - token = join_response.data.credentials.token - self.join_response = join_response - logger.debug(f"coordinator join response: {join_response.data}") - span.set_attribute( - "credentials", join_response.data.credentials.to_json() - ) + if self.fast_join: + # Use fast join to get multiple edge credentials + fast_join_response = await fast_join_call( + self.call, + self.user_id, + "auto", + self.create, + self.local_sfu, + **self.kwargs, + ) + logger.debug( + f"Received {len(fast_join_response.data.credentials)} edge credentials for fast join" + ) + + self._fast_join_response = fast_join_response + else: + # Use regular join + join_response = await join_call( + self.call, + self.user_id, + "auto", + self.create, + self.local_sfu, + **self.kwargs, + ) + ws_url = join_response.data.credentials.server.ws_endpoint + token = join_response.data.credentials.token + self.join_response = join_response + logger.debug(f"coordinator join response: {join_response.data}") + span.set_attribute( + "credentials", join_response.data.credentials.to_json() + ) # Use provided session_id or current one current_session_id = session_id or self.session_id @@ -295,12 +316,38 @@ async def _connect_internal( with telemetry.start_as_current_span( "sfu-signaling-ws-connect", ) as span: - self._ws_client, sfu_event = await connect_websocket( - token=token, - ws_url=ws_url, - session_id=current_session_id, - options=self._connection_options, - ) + # Handle fast join or regular join + if self.fast_join and hasattr(self, "_fast_join_response"): + # Fast join - race multiple edges + self._ws_client, sfu_event, selected_cred = await self._race_edges( + self._fast_join_response.data.credentials, current_session_id + ) + + # Use the selected credentials + ws_url = selected_cred.server.ws_endpoint + token = selected_cred.token + + #map it to standard join call object so that retry/migration can happen + self.join_response = StreamResponse( + response=self._fast_join_response._StreamResponse__response, + data=JoinCallResponse( + call=self._fast_join_response.data.call, + members=self._fast_join_response.data.members, + credentials=selected_cred, + stats_options=self._fast_join_response.data.stats_options, + duration=self._fast_join_response.data.duration, + ) + ) + + span.set_attribute("credentials", selected_cred.to_json()) + else: + # Regular join - connect to single edge + self._ws_client, sfu_event = await connect_websocket( + token=token, + ws_url=ws_url, + session_id=current_session_id, + options=self._connection_options, + ) self._ws_client.on_wildcard("*", _log_event) self._ws_client.on_event("ice_trickle", self._on_ice_trickle) @@ -530,3 +577,55 @@ async def _restore_published_tracks(self): await self._peer_manager.restore_published_tracks() except Exception as e: logger.error("Failed to restore published tracks", exc_info=e) + + async def _race_edges(self, credentials_list, session_id): + """Try multiple edge WebSocket connections sequentially and return the first successful one. + + This method iterates through edge URLs one by one, attempting to connect to each. + The first edge that successfully connects is used, and the iteration stops. + + Args: + credentials_list: List of Credentials to try + session_id: Session ID for the connection + + Returns: + Tuple of (WebSocket client, SFU event, selected Credentials) + + Raises: + SfuConnectionError: If all edge connections fail + """ + if not credentials_list: + raise SfuConnectionError("No edge credentials provided for racing") + + logger.info(f"Trying {len(credentials_list)} edge connections sequentially") + + errors = [] + + # Try each edge sequentially + for cred in credentials_list: + logger.debug(f"Trying edge {cred.server.edge_name} at {cred.server.ws_endpoint}") + + try: + # Attempt to connect to this edge + ws_client, sfu_event = await connect_websocket( + token=cred.token, + ws_url=cred.server.ws_endpoint, + session_id=session_id, + options=self._connection_options, + ) + + # Success! Return the connection and credentials + logger.info( + f"Edge {cred.server.edge_name} connected successfully" + ) + return ws_client, sfu_event, cred + + except Exception as e: + errors.append((cred.server.edge_name, str(e))) + # Continue to next edge + + # All connections failed + error_msg = "All edge connections failed:\n" + "\n".join( + f" - {edge}: {error}" for edge, error in errors + ) + raise SfuConnectionError(error_msg) diff --git a/getstream/video/rtc/connection_utils.py b/getstream/video/rtc/connection_utils.py index 7d333240..35cc2860 100644 --- a/getstream/video/rtc/connection_utils.py +++ b/getstream/video/rtc/connection_utils.py @@ -21,7 +21,7 @@ from getstream.models import CallRequest from getstream.utils import build_body_dict, build_query_param from getstream.video.async_call import Call -from getstream.video.rtc.models import JoinCallResponse +from getstream.video.rtc.models import JoinCallResponse, FastJoinCallResponse from getstream.video.rtc.pb.stream.video.sfu.event import events_pb2 from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import ( TRACK_TYPE_AUDIO, @@ -41,6 +41,7 @@ "SfuConnectionError", "ConnectionOptions", "join_call", + "fast_join_call", "join_call_coordinator_request", "create_join_request", "prepare_video_track_info", @@ -157,6 +158,63 @@ async def join_call( raise SfuConnectionError(f"Failed to join call: {e}") +async def fast_join_call( + call: Call, + user_id: str, + location: str, + create: bool, + local_sfu: bool, + **kwargs, +) -> StreamResponse[FastJoinCallResponse]: + """Join call via coordinator API using fast join to get multiple edge credentials. + + This function requests multiple edge URLs from the coordinator. The caller + is responsible for racing these edges to find the fastest connection. + + Args: + call: The call to join + user_id: The user ID to join the call with + location: The preferred location + create: Whether to create the call if it doesn't exist + local_sfu: Whether to use local SFU for development + **kwargs: Additional arguments to pass to the join call request + + Returns: + A StreamResponse containing FastJoinCallResponse with multiple edge credentials + + Raises: + SfuConnectionError: If the coordinator request fails + """ + try: + # Import here to avoid circular dependency + from getstream.video.rtc.coordinator_api import fast_join_call_coordinator_request + + # Get multiple edge credentials from coordinator + fast_join_response = await fast_join_call_coordinator_request( + call, + user_id, + location=location, + create=create, + **kwargs, + ) + + if local_sfu: + # Override all credentials with local SFU for development + for cred in fast_join_response.data.credentials: + cred.server.url = "http://127.0.0.1:3031/twirp" + cred.server.ws_endpoint = "ws://127.0.0.1:3031/ws" + + logger.debug( + f"Received {len(fast_join_response.data.credentials)} edge credentials for fast join" + ) + + return fast_join_response + + except Exception as e: + logger.error(f"Failed to fast join call via coordinator: {e}") + raise SfuConnectionError(f"Failed to fast join call: {e}") + + async def join_call_coordinator_request( call: Call, user_id: str, diff --git a/getstream/video/rtc/coordinator_api.py b/getstream/video/rtc/coordinator_api.py index 7647d92d..f40515a2 100644 --- a/getstream/video/rtc/coordinator_api.py +++ b/getstream/video/rtc/coordinator_api.py @@ -11,7 +11,7 @@ from getstream.utils import build_body_dict # Import the types we need from __init__ without creating circular imports -from getstream.video.rtc.models import JoinCallResponse +from getstream.video.rtc.models import JoinCallResponse, FastJoinCallResponse logger = logging.getLogger("getstream.video.rtc.coordinator") @@ -79,3 +79,68 @@ async def join_call_coordinator_request( path_params=path_params, json=json_body, ) + + +async def fast_join_call_coordinator_request( + call: Call, + user_id: str, + create: bool = False, + data: Optional[CallRequest] = None, + ring: Optional[bool] = None, + notify: Optional[bool] = None, + video: Optional[bool] = None, + location: Optional[str] = None, +) -> StreamResponse[FastJoinCallResponse]: + """Make a fast join request to get multiple edge credentials from the coordinator. + + Args: + call: The call to join + user_id: The user ID to join the call with + create: Whether to create the call if it doesn't exist + data: Additional call data if creating + ring: Whether to ring other users + notify: Whether to notify other users + video: Whether to enable video + location: The preferred location + + Returns: + A response containing the call information and an array of credentials for multiple edges + """ + # Create a token for this user + token = call.client.stream.create_token(user_id=user_id) + + # Create a new client with this token + client = call.client.stream.__class__( + api_key=call.client.stream.api_key, + api_secret=call.client.stream.api_secret, + base_url=call.client.stream.base_url, + ) + + # Set up authentication + client.token = token + client.headers["Authorization"] = token + client.client.headers["Authorization"] = token + + # Prepare path parameters for the request + path_params = { + "type": call.call_type, + "id": call.id, + } + + # Build the request body + json_body = build_body_dict( + location=location or "FRA", # Default to Frankfurt if not specified + create=create, + notify=notify, + ring=ring, + video=video, + data=data, + ) + + # Make the POST request to fast join the call + return await client.post( + "/api/v2/video/call/{type}/{id}/fast_join", + FastJoinCallResponse, + path_params=path_params, + json=json_body, + ) diff --git a/getstream/video/rtc/models.py b/getstream/video/rtc/models.py index a5d9c7b1..c5290565 100644 --- a/getstream/video/rtc/models.py +++ b/getstream/video/rtc/models.py @@ -50,3 +50,12 @@ class JoinCallResponse(DataClassJsonMixin): credentials: Credentials = dc_field(metadata=dc_config(field_name="credentials")) stats_options: dict = dc_field(metadata=dc_config(field_name="stats_options")) duration: str = dc_field(metadata=dc_config(field_name="duration")) + + +@dataclass +class FastJoinCallResponse(DataClassJsonMixin): + call: CallResponse = dc_field(metadata=dc_config(field_name="call")) + members: List[MemberResponse] = dc_field(metadata=dc_config(field_name="members")) + credentials: List[Credentials] = dc_field(metadata=dc_config(field_name="credentials")) + stats_options: dict = dc_field(metadata=dc_config(field_name="stats_options")) + duration: str = dc_field(metadata=dc_config(field_name="duration"))