Skip to content

Commit 97a81a2

Browse files
authored
Send GenerateContentRequest in CountTokensRequest (#175)
1 parent d8b1fbb commit 97a81a2

File tree

4 files changed

+158
-4
lines changed

4 files changed

+158
-4
lines changed

Sources/GoogleAI/CountTokensRequest.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import Foundation
1717
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
1818
struct CountTokensRequest {
1919
let model: String
20-
let contents: [ModelContent]
20+
let generateContentRequest: GenerateContentRequest
2121
let options: RequestOptions
2222
}
2323

@@ -42,7 +42,7 @@ public struct CountTokensResponse {
4242
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
4343
extension CountTokensRequest: Encodable {
4444
enum CodingKeys: CodingKey {
45-
case contents
45+
case generateContentRequest
4646
}
4747
}
4848

Sources/GoogleAI/GenerateContentRequest.swift

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ struct GenerateContentRequest {
3131
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
3232
extension GenerateContentRequest: Encodable {
3333
enum CodingKeys: String, CodingKey {
34+
case model
3435
case contents
3536
case generationConfig
3637
case safetySettings

Sources/GoogleAI/GenerativeModel.swift

+11-2
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,18 @@ public final class GenerativeModel {
325325
public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws
326326
-> CountTokensResponse {
327327
do {
328-
let countTokensRequest = try CountTokensRequest(
328+
let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
329+
contents: content(),
330+
generationConfig: generationConfig,
331+
safetySettings: safetySettings,
332+
tools: tools,
333+
toolConfig: toolConfig,
334+
systemInstruction: systemInstruction,
335+
isStreaming: false,
336+
options: requestOptions)
337+
let countTokensRequest = CountTokensRequest(
329338
model: modelResourceName,
330-
contents: content(),
339+
generateContentRequest: generateContentRequest,
331340
options: requestOptions
332341
)
333342
return try await generativeAIService.loadRequest(request: countTokensRequest)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
import XCTest
17+
18+
@testable import GoogleGenerativeAI
19+
20+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
21+
final class GenerateContentRequestTests: XCTestCase {
22+
let encoder = JSONEncoder()
23+
let role = "test-role"
24+
let prompt = "test-prompt"
25+
let modelName = "test-model-name"
26+
27+
override func setUp() {
28+
encoder.outputFormatting = .init(
29+
arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes
30+
)
31+
}
32+
33+
// MARK: GenerateContentRequest Encoding
34+
35+
func testEncodeRequest_allFieldsIncluded() throws {
36+
let content = [ModelContent(role: role, parts: prompt)]
37+
let request = GenerateContentRequest(
38+
model: modelName,
39+
contents: content,
40+
generationConfig: GenerationConfig(temperature: 0.5),
41+
safetySettings: [SafetySetting(
42+
harmCategory: .dangerousContent,
43+
threshold: .blockLowAndAbove
44+
)],
45+
tools: [Tool(functionDeclarations: [FunctionDeclaration(
46+
name: "test-function-name",
47+
description: "test-function-description",
48+
parameters: nil
49+
)])],
50+
toolConfig: ToolConfig(functionCallingConfig: FunctionCallingConfig(mode: .auto)),
51+
systemInstruction: ModelContent(role: "system", parts: "test-system-instruction"),
52+
isStreaming: false,
53+
options: RequestOptions()
54+
)
55+
56+
let jsonData = try encoder.encode(request)
57+
58+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
59+
XCTAssertEqual(json, """
60+
{
61+
"contents" : [
62+
{
63+
"parts" : [
64+
{
65+
"text" : "\(prompt)"
66+
}
67+
],
68+
"role" : "\(role)"
69+
}
70+
],
71+
"generationConfig" : {
72+
"temperature" : 0.5
73+
},
74+
"model" : "\(modelName)",
75+
"safetySettings" : [
76+
{
77+
"category" : "HARM_CATEGORY_DANGEROUS_CONTENT",
78+
"threshold" : "BLOCK_LOW_AND_ABOVE"
79+
}
80+
],
81+
"systemInstruction" : {
82+
"parts" : [
83+
{
84+
"text" : "test-system-instruction"
85+
}
86+
],
87+
"role" : "system"
88+
},
89+
"toolConfig" : {
90+
"functionCallingConfig" : {
91+
"mode" : "AUTO"
92+
}
93+
},
94+
"tools" : [
95+
{
96+
"functionDeclarations" : [
97+
{
98+
"description" : "test-function-description",
99+
"name" : "test-function-name",
100+
"parameters" : {
101+
"type" : "OBJECT"
102+
}
103+
}
104+
]
105+
}
106+
]
107+
}
108+
""")
109+
}
110+
111+
func testEncodeRequest_optionalFieldsOmitted() throws {
112+
let content = [ModelContent(role: role, parts: prompt)]
113+
let request = GenerateContentRequest(
114+
model: modelName,
115+
contents: content,
116+
generationConfig: nil,
117+
safetySettings: nil,
118+
tools: nil,
119+
toolConfig: nil,
120+
systemInstruction: nil,
121+
isStreaming: false,
122+
options: RequestOptions()
123+
)
124+
125+
let jsonData = try encoder.encode(request)
126+
127+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
128+
XCTAssertEqual(json, """
129+
{
130+
"contents" : [
131+
{
132+
"parts" : [
133+
{
134+
"text" : "\(prompt)"
135+
}
136+
],
137+
"role" : "\(role)"
138+
}
139+
],
140+
"model" : "\(modelName)"
141+
}
142+
""")
143+
}
144+
}

0 commit comments

Comments
 (0)