Skip to content

Commit

Permalink
fix(streaming): avoid invalid deser type error
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Jan 27, 2025
1 parent d0f98a4 commit 6ae0e9c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
16 changes: 12 additions & 4 deletions src/anthropic/lib/streaming/_beta_messages.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, cast
from typing import TYPE_CHECKING, Any, Type, Callable, cast
from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never

import httpx
from pydantic import BaseModel

from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._models import build, construct_type
from ..._models import build, construct_type, construct_type_unchecked
from ._beta_types import (
BetaTextEvent,
BetaCitationEvent,
Expand Down Expand Up @@ -372,8 +372,16 @@ def accumulate_event(
event: BetaRawMessageStreamEvent,
current_snapshot: BetaMessage | None,
) -> BetaMessage:
if not isinstance(event, BaseModel): # pyright: ignore[reportUnnecessaryIsInstance]
raise TypeError(f"Unexpected event runtime type - {event}")
if not isinstance(cast(Any, event), BaseModel):
event = cast( # pyright: ignore[reportUnnecessaryCast]
BetaRawMessageStreamEvent,
construct_type_unchecked(
type_=cast(Type[BetaRawMessageStreamEvent], BetaRawMessageStreamEvent),
value=event,
),
)
if not isinstance(cast(Any, event), BaseModel):
raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}")

if current_snapshot is None:
if event.type == "message_start":
Expand Down
16 changes: 12 additions & 4 deletions src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, cast
from typing import TYPE_CHECKING, Any, Type, Callable, cast
from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never

import httpx
Expand All @@ -17,7 +17,7 @@
)
from ...types import Message, ContentBlock, RawMessageStreamEvent
from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._models import build, construct_type
from ..._models import build, construct_type, construct_type_unchecked
from ..._streaming import Stream, AsyncStream


Expand Down Expand Up @@ -372,8 +372,16 @@ def accumulate_event(
event: RawMessageStreamEvent,
current_snapshot: Message | None,
) -> Message:
if not isinstance(event, BaseModel): # pyright: ignore[reportUnnecessaryIsInstance]
raise TypeError(f"Unexpected event runtime type - {event}")
if not isinstance(cast(Any, event), BaseModel):
event = cast( # pyright: ignore[reportUnnecessaryCast]
RawMessageStreamEvent,
construct_type_unchecked(
type_=cast(Type[RawMessageStreamEvent], RawMessageStreamEvent),
value=event,
),
)
if not isinstance(cast(Any, event), BaseModel):
raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}")

if current_snapshot is None:
if event.type == "message_start":
Expand Down

0 comments on commit 6ae0e9c

Please sign in to comment.