Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
</Choose>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.7.0" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.8.0" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
</Choose>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.7.0" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.8.0" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
<group targetFramework="net472">
<dependency id="AWSSDK.Core" version="4.0.0.4" />
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.3" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.7.0" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.8.0" />
</group>
<group targetFramework="netstandard2.0">
<dependency id="AWSSDK.Core" version="4.0.0.4" />
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.3" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.7.0" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.8.0" />
</group>
<group targetFramework="net8.0">
<dependency id="AWSSDK.Core" version="4.0.0.4" />
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.3" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.7.0" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.8.0" />
</group>
</dependencies>
</metadata>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

using Microsoft.Extensions.AI;
using System;
using System.Diagnostics.CodeAnalysis;

namespace Amazon.BedrockRuntime;

Expand Down Expand Up @@ -53,4 +54,18 @@ public static IEmbeddingGenerator<string, Embedding<float>> AsIEmbeddingGenerato
this IAmazonBedrockRuntime runtime, string? defaultModelId = null, int? defaultModelDimensions = null) =>
runtime is not null ? new BedrockEmbeddingGenerator(runtime, defaultModelId, defaultModelDimensions) :
throw new ArgumentNullException(nameof(runtime));

/// <summary>Gets an <see cref="IImageGenerator"/> for the specified <see cref="IAmazonBedrockRuntime"/> instance.</summary>
/// <param name="runtime">The runtime instance to be represented as an <see cref="IImageGenerator"/>.</param>
/// <param name="defaultModelId">
/// The default model ID to use when no model is specified in a request. If not specified,
/// a model must be provided in the <see cref="ImageGenerationOptions.ModelId"/> passed to <see cref="IImageGenerator.GenerateAsync"/>.
/// </param>
/// <returns>An <see cref="IImageGenerator"/> instance representing the <see cref="IAmazonBedrockRuntime"/> instance.</returns>
/// <exception cref="ArgumentNullException"><paramref name="runtime"/> is <see langword="null"/>.</exception>
[Experimental("MEAI001")]
public static IImageGenerator AsIImageGenerator(
this IAmazonBedrockRuntime runtime, string? defaultModelId = null) =>
runtime is not null ? new BedrockImageGenerator(runtime, defaultModelId) :
throw new ArgumentNullException(nameof(runtime));
}
39 changes: 36 additions & 3 deletions extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ public async Task<ChatResponse> GetResponseAsync(

ChatMessage result = new()
{
CreatedAt = DateTimeOffset.UtcNow,
RawRepresentation = response.Output?.Message,
Role = ChatRole.Assistant,
MessageId = Guid.NewGuid().ToString("N"),
Expand All @@ -97,6 +98,21 @@ public async Task<ChatResponse> GetResponseAsync(
result.Contents.Add(new TextContent(text) { RawRepresentation = content });
}

if (content.CitationsContent is { } citations)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im not a huge fan of the ! null checks everywhere (unless there was a specific reason you needed them?). wondering if we can do something like below instead?

i havent tested it but it gives the general idea

if (content.CitationsContent is { } citations)
{
    for (int i = 0; i < Math.Min(citations.Citations?.Count ?? 0, citations.Content?.Count ?? 0); i++)
    {
        var contentItem = citations.Content?[i];
        var citationItem = citations.Citations?[i];

        if (contentItem != null && citationItem != null)
        {
            TextContent tc = new(contentItem.Text) { RawRepresentation = contentItem };
            tc.Annotations = new[]
            {
                new CitationAnnotation
                {
                    Title = citationItem.Title,
                    Snippet = citationItem.SourceContent?.Select(c => c.Text).FirstOrDefault(),
                }
            };
            result.Contents.Add(tc);
        }
    }
}

same in other places

{
int count = Math.Min(citations.Citations?.Count ?? 0, citations.Content?.Count ?? 0);
for (int i = 0; i < count; i++)
{
TextContent tc = new(citations.Content![i]?.Text) { RawRepresentation = citations.Content![i] };
tc.Annotations = [new CitationAnnotation()
{
Title = citations.Citations![i].Title,
Snippet = citations.Citations![i].SourceContent?.Select(c => c.Text).FirstOrDefault(),
}];
result.Contents.Add(tc);
Copy link
Preview

Copilot AI Aug 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential null reference exception when citations.Content[i] is null. The null-forgiving operator ! is used but citations.Content[i] could still be null even after the count check.

Suggested change
result.Contents.Add(tc);
if (citations.Content![i] != null)
{
TextContent tc = new(citations.Content[i]?.Text) { RawRepresentation = citations.Content[i] };
tc.Annotations = [new CitationAnnotation()
{
Title = citations.Citations![i].Title,
Snippet = citations.Citations![i].SourceContent?.Select(c => c.Text).FirstOrDefault(),
}];
result.Contents.Add(tc);
}

Copilot uses AI. Check for mistakes.

Copy link
Preview

Copilot AI Aug 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential null reference exception when citations.Citations[i] is null. The null-forgiving operator ! is used but citations.Citations[i] could still be null even after the count check.

Suggested change
result.Contents.Add(tc);
var contentItem = citations.Content![i];
var citationItem = citations.Citations![i];
if (contentItem != null && citationItem != null)
{
TextContent tc = new(contentItem.Text) { RawRepresentation = contentItem };
tc.Annotations = [new CitationAnnotation()
{
Title = citationItem.Title,
Snippet = citationItem.SourceContent?.Select(c => c.Text).FirstOrDefault(),
}];
result.Contents.Add(tc);
}

Copilot uses AI. Check for mistakes.

Copy link
Preview

Copilot AI Aug 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential null reference exception when citations.Citations[i] is null. The null-forgiving operator ! is used but citations.Citations[i] could still be null even after the count check.

Suggested change
result.Contents.Add(tc);
var citation = citations.Citations![i];
var contentItem = citations.Content![i];
if (citation != null && contentItem != null)
{
TextContent tc = new(contentItem.Text) { RawRepresentation = contentItem };
tc.Annotations = [new CitationAnnotation()
{
Title = citation.Title,
Snippet = citation.SourceContent?.Select(c => c.Text).FirstOrDefault(),
}];
result.Contents.Add(tc);
}

Copilot uses AI. Check for mistakes.

}
}

if (content.ReasoningContent is { ReasoningText.Text: not null } reasoningContent)
{
TextReasoningContent trc = new(reasoningContent.ReasoningText.Text) { RawRepresentation = content };
Expand Down Expand Up @@ -126,7 +142,11 @@ public async Task<ChatResponse> GetResponseAsync(

if (content.Document is { Source.Bytes: { } documentBytes, Format: { } documentFormat })
{
result.Contents.Add(new DataContent(documentBytes.ToArray(), GetMimeType(documentFormat)) { RawRepresentation = content });
result.Contents.Add(new DataContent(documentBytes.ToArray(), GetMimeType(documentFormat))
{
RawRepresentation = content,
Name = content.Document.Name
});
}

if (content.ToolUse is { } toolUse)
Expand All @@ -143,7 +163,7 @@ public async Task<ChatResponse> GetResponseAsync(

return new(result)
{
CreatedAt = DateTimeOffset.UtcNow,
CreatedAt = result.CreatedAt,
FinishReason = response.StopReason is not null ? GetChatFinishReason(response.StopReason) : null,
RawRepresentation = response,
ResponseId = Guid.NewGuid().ToString("N"),
Expand Down Expand Up @@ -205,14 +225,26 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(

if (contentBlockDelta.Delta.Text is string text)
{
yield return new(ChatRole.Assistant, text)
ChatResponseUpdate textUpdate = new(ChatRole.Assistant, text)
{
CreatedAt = DateTimeOffset.UtcNow,
MessageId = messageId,
RawRepresentation = update,
FinishReason = finishReason,
ResponseId = responseId,
};

if (contentBlockDelta.Delta.Citation is { } citation &&
(citation.Title is not null || citation.SourceContent is { Count: > 0 }))
{
textUpdate.Contents[0].Annotations = [new CitationAnnotation()
{
Title = citation.Title,
Snippet = citation.SourceContent?.Select(c => c.Text).FirstOrDefault(),
}];
}

yield return textUpdate;
}

if (contentBlockDelta.Delta.ReasoningContent is { Text: not null } reasoningContent)
Expand Down Expand Up @@ -468,6 +500,7 @@ private static List<ContentBlock> CreateContents(ChatMessage message)
{
Source = new() { Bytes = new(dc.Data.ToArray()) },
Format = docFormat,
Name = dc.Name ?? "file",
}
});
}
Expand Down
197 changes: 197 additions & 0 deletions extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockImageGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

using Amazon.BedrockRuntime.Model;
using Microsoft.Extensions.AI;
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;

namespace Amazon.BedrockRuntime;

[Experimental("MEAI001")]
internal sealed partial class BedrockImageGenerator : IImageGenerator
{
/// <summary>The wrapped <see cref="IAmazonBedrockRuntime"/> instance.</summary>
private readonly IAmazonBedrockRuntime _runtime;
/// <summary>Default model ID to use when no model is specified in the request.</summary>
private readonly string? _modelId;
/// <summary>Metadata describing the image generator.</summary>
private readonly ImageGeneratorMetadata _metadata;

/// <summary>
/// Initializes a new instance of the <see cref="BedrockImageGenerator"/> class.
/// </summary>
/// <param name="runtime">The <see cref="IAmazonBedrockRuntime"/> instance to wrap.</param>
/// <param name="defaultModelId">Model ID to use as the default when no model ID is specified in a request.</param>
public BedrockImageGenerator(IAmazonBedrockRuntime runtime, string? defaultModelId)
{
Debug.Assert(runtime is not null);

_runtime = runtime!;
_modelId = defaultModelId;

_metadata = new(AmazonBedrockRuntimeExtensions.ProviderName, defaultModelId: defaultModelId);
}

public void Dispose()
{
// Do not dispose of _runtime, as this instance doesn't own it.
}

/// <inheritdoc />

/// <inheritdoc />
public object? GetService(Type serviceType, object? serviceKey)
{
if (serviceType is null)
{
throw new ArgumentNullException(nameof(serviceType));
}

return
serviceKey is not null ? null :
serviceType == typeof(ImageGeneratorMetadata) ? _metadata :
serviceType.IsInstanceOfType(_runtime) ? _runtime :
serviceType.IsInstanceOfType(this) ? this :
null;
}

public async Task<ImageGenerationResponse> GenerateAsync(
ImageGenerationRequest request, ImageGenerationOptions? options = null, CancellationToken cancellationToken = default)
{
if (request is null)
{
throw new ArgumentNullException(nameof(request));
}

int numImages = options?.Count ?? 1;
if (numImages < 1)
{
throw new ArgumentOutOfRangeException(nameof(options), "The number of images must be at least 1.");
}

InvokeModelRequest invokeRequest = options?.RawRepresentationFactory?.Invoke(this) as InvokeModelRequest ?? new();
invokeRequest.ModelId ??= options?.ModelId ?? _modelId;
invokeRequest.Accept ??= "application/json";
invokeRequest.ContentType ??= "application/json";
if (invokeRequest.Body is null)
{
JsonObject body = new();

// Each model has its own way of specifying the prompt and image generation parameters, unfortunately.
// The following logic handles the most common cases today, but may need to be extended for
// future models.

if (invokeRequest.ModelId?.IndexOf("stability", StringComparison.OrdinalIgnoreCase) >= 0)
{
// Stability AI models
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stability models should also permit a starting image we should pass that in for edit. https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-diffusion-3-text-image.html


if (invokeRequest.ModelId?.IndexOf("stable-diffusion", StringComparison.OrdinalIgnoreCase) >= 0)
{
JsonArray textPrompts = new();
for (int i = 0; i < numImages; i++)
{
textPrompts.Add((JsonNode)new JsonObject { ["text"] = request.Prompt ?? "" });
}
body["text_prompts"] = textPrompts;

if (options?.ImageSize?.Width is int width && options.ImageSize?.Height is int height)
{
body["width"] = width;
body["height"] = height;
}
}
else
{
body["prompt"] = request.Prompt ?? "";
}
}
else
{
// Amazon models (e.g. Titan, Nova Canvas)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


body["taskType"] = "TEXT_IMAGE";
body["textToImageParams"] = new JsonObject { ["text"] = request.Prompt ?? "" };

JsonObject imageGenerationConfig = new()
{
["seed"] =
#if NET
Random.Shared.Next(),
#else
new Random().Next(),
#endif
};

if (options?.ImageSize?.Width is int width && options.ImageSize?.Height is int height)
{
imageGenerationConfig["width"] = width;
imageGenerationConfig["height"] = height;
}

if (numImages > 1)
{
imageGenerationConfig["numberOfImages"] = numImages;
}

body["imageGenerationConfig"] = imageGenerationConfig;
}

invokeRequest.Body = new MemoryStream(JsonSerializer.SerializeToUtf8Bytes(body, BedrockJsonContext.Default.JsonNode));
}

InvokeModelResponse rawResponse = await _runtime.InvokeModelAsync(invokeRequest, cancellationToken).ConfigureAwait(false);

ImageGenerationResponse result = new() { RawRepresentation = rawResponse };

using JsonDocument doc = JsonDocument.Parse(rawResponse.Body);
JsonElement root = doc.RootElement;

if (root.TryGetProperty("artifacts", out JsonElement artifactElement) && artifactElement.ValueKind == JsonValueKind.Array)
{
foreach (var element in artifactElement.EnumerateArray())
{
if (element.TryGetProperty("base64", out JsonElement base64Element) &&
base64Element.ValueKind == JsonValueKind.String)
{
result.Contents.Add(new DataContent(Convert.FromBase64String(base64Element.GetString()!), "image/png"));
Copy link
Preview

Copilot AI Aug 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The MIME type is hardcoded as 'image/png' but different models may return different image formats. Consider determining the actual format from the response or making it configurable.

Suggested change
result.Contents.Add(new DataContent(Convert.FromBase64String(base64Element.GetString()!), "image/png"));
string mimeType = "image/png";
if (element.TryGetProperty("mime_type", out JsonElement mimeTypeElement) && mimeTypeElement.ValueKind == JsonValueKind.String)
{
mimeType = mimeTypeElement.GetString()!;
}
else if (element.TryGetProperty("type", out JsonElement typeElement) && typeElement.ValueKind == JsonValueKind.String)
{
mimeType = typeElement.GetString()!;
}
result.Contents.Add(new DataContent(Convert.FromBase64String(base64Element.GetString()!), mimeType));

Copilot uses AI. Check for mistakes.

}
}
}
else if (root.TryGetProperty("images", out JsonElement imagesElement) && imagesElement.ValueKind == JsonValueKind.Array)
{
foreach (var image in imagesElement.EnumerateArray())
{
if (image.ValueKind == JsonValueKind.String)
{
result.Contents.Add(new DataContent(Convert.FromBase64String(image.GetString()!), "image/png"));
Copy link
Preview

Copilot AI Aug 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The MIME type is hardcoded as 'image/png' but different models may return different image formats. Consider determining the actual format from the response or making it configurable.

Suggested change
result.Contents.Add(new DataContent(Convert.FromBase64String(image.GetString()!), "image/png"));
{
var imageBytes = Convert.FromBase64String(image.GetString()!);
var mimeType = ImageFormatHelper.GetImageMimeType(imageBytes);
result.Contents.Add(new DataContent(imageBytes, mimeType));
}

Copilot uses AI. Check for mistakes.

}
}
}

if (result.Contents is not { Count: > 0 })
{
throw new InvalidOperationException("Image generation did not produce any images.");
}

return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

#if !NET
namespace System.Diagnostics.CodeAnalysis;

// Polyfill for [Experimental]

[AttributeUsage(
AttributeTargets.Assembly | AttributeTargets.Module | AttributeTargets.Class | AttributeTargets.Struct |
AttributeTargets.Enum | AttributeTargets.Constructor | AttributeTargets.Method | AttributeTargets.Property |
AttributeTargets.Field | AttributeTargets.Event | AttributeTargets.Interface | AttributeTargets.Delegate,
Inherited = false)]
internal sealed class ExperimentalAttribute : Attribute
{
public ExperimentalAttribute(string diagnosticId) => DiagnosticId = diagnosticId;
public string DiagnosticId { get; }
public string? Message { get; set; }
public string? UrlFormat { get; set; }
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
</PropertyGroup>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add tests for the citation changes?

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.7.0" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.8.0" />
<PackageReference Include="xunit" Version="2.9.2" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2" />
</ItemGroup>
Expand Down