Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 6 additions & 22 deletions packages/firebase_ai/firebase_ai/lib/src/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -209,25 +209,6 @@ final class UsageMetadata {
final List<ModalityTokenCount>? candidatesTokensDetails;
}

/// Constructe a UsageMetadata with all it's fields.
///
/// Expose access to the private constructor for use within the package..
UsageMetadata createUsageMetadata({
required int? promptTokenCount,
required int? candidatesTokenCount,
required int? totalTokenCount,
required int? thoughtsTokenCount,
required List<ModalityTokenCount>? promptTokensDetails,
required List<ModalityTokenCount>? candidatesTokensDetails,
}) =>
UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
thoughtsTokenCount: thoughtsTokenCount,
promptTokensDetails: promptTokensDetails,
candidatesTokensDetails: candidatesTokensDetails);

/// Response candidate generated from a [GenerativeModel].
final class Candidate {
// TODO: token count?
Expand Down Expand Up @@ -1194,7 +1175,7 @@ final class VertexSerialization implements SerializationStrategy {
};
final usageMedata = switch (jsonObject) {
{'usageMetadata': final usageMetadata?} =>
_parseUsageMetadata(usageMetadata),
parseUsageMetadata(usageMetadata),
{'totalTokens': final int totalTokens} =>
UsageMetadata._(totalTokenCount: totalTokens),
_ => null,
Expand Down Expand Up @@ -1324,7 +1305,10 @@ PromptFeedback _parsePromptFeedback(Object jsonObject) {
};
}

UsageMetadata _parseUsageMetadata(Object jsonObject) {
/// Parses a UsageMetadata from a JSON object.
///
/// Expose access to the private helper for use within the package.
UsageMetadata parseUsageMetadata(Object jsonObject) {
if (jsonObject is! Map<String, Object?>) {
throw unhandledFormat('UsageMetadata', jsonObject);
}
Expand Down Expand Up @@ -1355,7 +1339,7 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) {
candidatesTokensDetails.map(_parseModalityTokenCount).toList(),
_ => null,
};
return createUsageMetadata(
return UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
Expand Down
40 changes: 4 additions & 36 deletions packages/firebase_ai/firebase_ai/lib/src/developer/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ import '../api.dart'
SearchEntryPoint,
Segment,
SerializationStrategy,
UsageMetadata,
WebGroundingChunk,
createUsageMetadata;
parseUsageMetadata;
import '../content.dart' show Content, parseContent;
import '../error.dart';
import '../tool.dart' show Tool, ToolConfig;
Expand Down Expand Up @@ -117,13 +116,13 @@ final class DeveloperSerialization implements SerializationStrategy {
_parsePromptFeedback(promptFeedback),
_ => null,
};
final usageMedata = switch (jsonObject) {
final usageMetadata = switch (jsonObject) {
{'usageMetadata': final usageMetadata?} =>
_parseUsageMetadata(usageMetadata),
parseUsageMetadata(usageMetadata),
_ => null,
};
return GenerateContentResponse(candidates, promptFeedback,
usageMetadata: usageMedata);
usageMetadata: usageMetadata);
}

@override
Expand Down Expand Up @@ -238,37 +237,6 @@ PromptFeedback _parsePromptFeedback(Object jsonObject) {
};
}

UsageMetadata _parseUsageMetadata(Object jsonObject) {
if (jsonObject is! Map<String, Object?>) {
throw unhandledFormat('UsageMetadata', jsonObject);
}
final promptTokenCount = switch (jsonObject) {
{'promptTokenCount': final int promptTokenCount} => promptTokenCount,
_ => null,
};
final candidatesTokenCount = switch (jsonObject) {
{'candidatesTokenCount': final int candidatesTokenCount} =>
candidatesTokenCount,
_ => null,
};
final totalTokenCount = switch (jsonObject) {
{'totalTokenCount': final int totalTokenCount} => totalTokenCount,
_ => null,
};
final thoughtsTokenCount = switch (jsonObject) {
{'thoughtsTokenCount': final int thoughtsTokenCount} => thoughtsTokenCount,
_ => null,
};
return createUsageMetadata(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
thoughtsTokenCount: thoughtsTokenCount,
promptTokensDetails: null,
candidatesTokensDetails: null,
);
}

SafetyRating _parseSafetyRating(Object? jsonObject) {
return switch (jsonObject) {
{
Expand Down
41 changes: 41 additions & 0 deletions packages/firebase_ai/firebase_ai/test/developer_api_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ void main() {
'candidatesTokenCount': 5,
'totalTokenCount': 15,
'thoughtsTokenCount': 3,
'promptTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 10}
],
'candidatesTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 25}
],
}
};
final response =
Expand All @@ -49,6 +55,15 @@ void main() {
expect(response.usageMetadata!.candidatesTokenCount, 5);
expect(response.usageMetadata!.totalTokenCount, 15);
expect(response.usageMetadata!.thoughtsTokenCount, 3);
expect(response.usageMetadata!.promptTokensDetails, isNotNull);
expect(response.usageMetadata!.promptTokensDetails, hasLength(1));
expect(
response.usageMetadata!.promptTokensDetails!.first.tokenCount, 10);
expect(response.usageMetadata!.candidatesTokensDetails, isNotNull);
expect(response.usageMetadata!.candidatesTokensDetails, hasLength(1));
expect(
response.usageMetadata!.candidatesTokensDetails!.first.tokenCount,
25);
});

test('parses usageMetadata when thoughtsTokenCount is missing', () {
Expand All @@ -69,6 +84,12 @@ void main() {
'candidatesTokenCount': 5,
'totalTokenCount': 15,
// thoughtsTokenCount is missing
'promptTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 10}
],
'candidatesTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 25}
],
}
};
final response =
Expand Down Expand Up @@ -303,6 +324,26 @@ void main() {
});
});

test('parses usageMetadata when token details are missing', () {
final jsonResponse = {
'usageMetadata': {
'promptTokenCount': 10,
'candidatesTokenCount': 25,
'totalTokenCount': 35,
}
};

final response =
DeveloperSerialization().parseGenerateContentResponse(jsonResponse);

expect(response.usageMetadata, isNotNull);
expect(response.usageMetadata!.promptTokenCount, 10);
expect(response.usageMetadata!.candidatesTokenCount, 25);
expect(response.usageMetadata!.totalTokenCount, 35);
expect(response.usageMetadata!.promptTokensDetails, isNull);
expect(response.usageMetadata!.candidatesTokensDetails, isNull);
});

test('parses inlineData part correctly', () {
final inlineData = Uint8List.fromList([1, 2, 3, 4]);
final jsonResponse = {
Expand Down
Loading