Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace custom SSE reader with source for System.Net.ServerSentEvents #33

Merged
merged 5 commits into from
Jun 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -3,6 +3,8 @@
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net.ServerSentEvents;
using System.Threading;
using System.Threading.Tasks;

@@ -31,7 +33,7 @@ public override IAsyncEnumerator<StreamingUpdate> GetAsyncEnumerator(Cancellatio

private sealed class AsyncStreamingUpdateEnumerator : IAsyncEnumerator<StreamingUpdate>
{
private const string _terminalData = "[DONE]";
private static ReadOnlySpan<byte> TerminalData => "[DONE]"u8;

private readonly Func<Task<ClientResult>> _getResultAsync;
private readonly AsyncStreamingUpdateCollection _enumerable;
@@ -44,7 +46,7 @@ private sealed class AsyncStreamingUpdateEnumerator : IAsyncEnumerator<Streaming
// // get _updates from sse event
// foreach (var update in _updates) { ... }
// }
private IAsyncEnumerator<ServerSentEvent>? _events;
private IAsyncEnumerator<SseItem<byte[]>>? _events;
private IEnumerator<StreamingUpdate>? _updates;

private StreamingUpdate? _current;
@@ -84,7 +86,7 @@ async ValueTask<bool> IAsyncEnumerator<StreamingUpdate>.MoveNextAsync()

if (await _events.MoveNextAsync().ConfigureAwait(false))
{
if (_events.Current.Data == _terminalData)
if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData))
{
_current = default;
return false;
@@ -104,7 +106,7 @@ async ValueTask<bool> IAsyncEnumerator<StreamingUpdate>.MoveNextAsync()
return false;
}

private async Task<IAsyncEnumerator<ServerSentEvent>> CreateEventEnumeratorAsync()
private async Task<IAsyncEnumerator<SseItem<byte[]>>> CreateEventEnumeratorAsync()
{
ClientResult result = await _getResultAsync().ConfigureAwait(false);
PipelineResponse response = result.GetRawResponse();
@@ -115,7 +117,7 @@ private async Task<IAsyncEnumerator<ServerSentEvent>> CreateEventEnumeratorAsync
throw new InvalidOperationException("Unable to create result from response with null ContentStream");
}

AsyncServerSentEventEnumerable enumerable = new(response.ContentStream);
IAsyncEnumerable<SseItem<byte[]>> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).EnumerateAsync();
return enumerable.GetAsyncEnumerator(_cancellationToken);
}

3 changes: 2 additions & 1 deletion src/Custom/Assistants/Streaming/StreamingUpdate.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System.Net.ServerSentEvents;
using System.Text.Json;

namespace OpenAI.Assistants;
@@ -38,7 +39,7 @@ internal StreamingUpdate(StreamingUpdateReason updateKind)
UpdateKind = updateKind;
}

internal static IEnumerable<StreamingUpdate> FromEvent(ServerSentEvent sseItem)
internal static IEnumerable<StreamingUpdate> FromEvent(SseItem<byte[]> sseItem)
{
StreamingUpdateReason updateKind = StreamingUpdateReasonExtensions.FromSseEventLabel(sseItem.EventType);
using JsonDocument dataDocument = JsonDocument.Parse(sseItem.Data);
11 changes: 6 additions & 5 deletions src/Custom/Assistants/Streaming/StreamingUpdateCollection.cs
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.ServerSentEvents;

#nullable enable

@@ -30,7 +31,7 @@ public override IEnumerator<StreamingUpdate> GetEnumerator()

private sealed class StreamingUpdateEnumerator : IEnumerator<StreamingUpdate>
{
private const string _terminalData = "[DONE]";
private static ReadOnlySpan<byte> TerminalData => "[DONE]"u8;

private readonly Func<ClientResult> _getResult;
private readonly StreamingUpdateCollection _enumerable;
@@ -42,7 +43,7 @@ private sealed class StreamingUpdateEnumerator : IEnumerator<StreamingUpdate>
// // get _updates from sse event
// foreach (var update in _updates) { ... }
// }
private IEnumerator<ServerSentEvent>? _events;
private IEnumerator<SseItem<byte[]>>? _events;
private IEnumerator<StreamingUpdate>? _updates;

private StreamingUpdate? _current;
@@ -81,7 +82,7 @@ public bool MoveNext()

if (_events.MoveNext())
{
if (_events.Current.Data == _terminalData)
if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData))
{
_current = default;
return false;
@@ -101,7 +102,7 @@ public bool MoveNext()
return false;
}

private IEnumerator<ServerSentEvent> CreateEventEnumerator()
private IEnumerator<SseItem<byte[]>> CreateEventEnumerator()
{
ClientResult result = _getResult();
PipelineResponse response = result.GetRawResponse();
@@ -112,7 +113,7 @@ private IEnumerator<ServerSentEvent> CreateEventEnumerator()
throw new InvalidOperationException("Unable to create result from response with null ContentStream");
}

ServerSentEventEnumerable enumerable = new(response.ContentStream);
IEnumerable<SseItem<byte[]>> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).Enumerate();
return enumerable.GetEnumerator();
}

Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.ServerSentEvents;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
@@ -32,7 +33,7 @@ public override IAsyncEnumerator<StreamingChatCompletionUpdate> GetAsyncEnumerat

private sealed class AsyncStreamingChatUpdateEnumerator : IAsyncEnumerator<StreamingChatCompletionUpdate>
{
private const string _terminalData = "[DONE]";
private static ReadOnlySpan<byte> TerminalData => "[DONE]"u8;

private readonly Func<Task<ClientResult>> _getResultAsync;
private readonly AsyncStreamingChatCompletionUpdateCollection _enumerable;
@@ -45,7 +46,7 @@ private sealed class AsyncStreamingChatUpdateEnumerator : IAsyncEnumerator<Strea
// // get _updates from sse event
// foreach (var update in _updates) { ... }
// }
private IAsyncEnumerator<ServerSentEvent>? _events;
private IAsyncEnumerator<SseItem<byte[]>>? _events;
private IEnumerator<StreamingChatCompletionUpdate>? _updates;

private StreamingChatCompletionUpdate? _current;
@@ -85,7 +86,7 @@ async ValueTask<bool> IAsyncEnumerator<StreamingChatCompletionUpdate>.MoveNextAs

if (await _events.MoveNextAsync().ConfigureAwait(false))
{
if (_events.Current.Data == _terminalData)
if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData))
{
_current = default;
return false;
@@ -106,7 +107,7 @@ async ValueTask<bool> IAsyncEnumerator<StreamingChatCompletionUpdate>.MoveNextAs
return false;
}

private async Task<IAsyncEnumerator<ServerSentEvent>> CreateEventEnumeratorAsync()
private async Task<IAsyncEnumerator<SseItem<byte[]>>> CreateEventEnumeratorAsync()
{
ClientResult result = await _getResultAsync().ConfigureAwait(false);
PipelineResponse response = result.GetRawResponse();
@@ -117,7 +118,7 @@ private async Task<IAsyncEnumerator<ServerSentEvent>> CreateEventEnumeratorAsync
throw new InvalidOperationException("Unable to create result from response with null ContentStream");
}

AsyncServerSentEventEnumerable enumerable = new(response.ContentStream);
IAsyncEnumerable<SseItem<byte[]>> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).EnumerateAsync();
return enumerable.GetAsyncEnumerator(_cancellationToken);
}

Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.ServerSentEvents;
using System.Text.Json;

#nullable enable
@@ -31,7 +32,7 @@ public override IEnumerator<StreamingChatCompletionUpdate> GetEnumerator()

private sealed class StreamingChatUpdateEnumerator : IEnumerator<StreamingChatCompletionUpdate>
{
private const string _terminalData = "[DONE]";
private static ReadOnlySpan<byte> TerminalData => "[DONE]"u8;

private readonly Func<ClientResult> _getResult;
private readonly StreamingChatCompletionUpdateCollection _enumerable;
@@ -43,7 +44,7 @@ private sealed class StreamingChatUpdateEnumerator : IEnumerator<StreamingChatCo
// // get _updates from sse event
// foreach (var update in _updates) { ... }
// }
private IEnumerator<ServerSentEvent>? _events;
private IEnumerator<SseItem<byte[]>>? _events;
private IEnumerator<StreamingChatCompletionUpdate>? _updates;

private StreamingChatCompletionUpdate? _current;
@@ -82,7 +83,7 @@ public bool MoveNext()

if (_events.MoveNext())
{
if (_events.Current.Data == _terminalData)
if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData))
{
_current = default;
return false;
@@ -103,7 +104,7 @@ public bool MoveNext()
return false;
}

private IEnumerator<ServerSentEvent> CreateEventEnumerator()
private IEnumerator<SseItem<byte[]>> CreateEventEnumerator()
{
ClientResult result = _getResult();
PipelineResponse response = result.GetRawResponse();
@@ -114,7 +115,7 @@ private IEnumerator<ServerSentEvent> CreateEventEnumerator()
throw new InvalidOperationException("Unable to create result from response with null ContentStream");
}

ServerSentEventEnumerable enumerable = new(response.ContentStream);
IEnumerable<SseItem<byte[]>> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).Enumerate();
return enumerable.GetEnumerator();
}

7 changes: 7 additions & 0 deletions src/OpenAI.csproj
Original file line number Diff line number Diff line change
@@ -41,6 +41,13 @@
<NoWarn>$(NoWarn),0169</NoWarn>
</PropertyGroup>

<PropertyGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<!-- Allow use of unsafe code, for System.Net.ServerSentEvents polyfill on netstandard2.0
TODO https://github.com/openai/openai-dotnet/issues/41: Remove once polyfill for
System.Net.ServerSentEvents is removed in favor of referencing the package -->
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(GITHUB_ACTIONS)' == 'true'">
<!-- Normalize stored file paths in symbols when in a CI build. -->
<ContinuousIntegrationBuild>true</ContinuousIntegrationBuild>
82 changes: 0 additions & 82 deletions src/Utility/AsyncServerSentEventEnumerable.cs

This file was deleted.

24 changes: 0 additions & 24 deletions src/Utility/ServerSentEvent.cs

This file was deleted.

Loading