Skip to content

Commit d8b1fbb

Browse files
authored
Add responseSchema to GenerationConfig (#176)
1 parent edc9de3 commit d8b1fbb

File tree

2 files changed

+49
-4
lines changed

2 files changed

+49
-4
lines changed

Sources/GoogleAI/GenerationConfig.swift

+10-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ public struct GenerationConfig {
7070
/// - `application/json`: JSON response in the candidates.
7171
public let responseMIMEType: String?
7272

73+
/// Output response schema of the generated candidate text.
74+
///
75+
/// - Note: This only applies when the specified ``responseMIMEType`` supports a schema; currently
76+
/// this is limited to `application/json`.
77+
public let responseSchema: Schema?
78+
7379
/// Creates a new `GenerationConfig` value.
7480
///
7581
/// - Parameters:
@@ -80,9 +86,11 @@ public struct GenerationConfig {
8086
/// - maxOutputTokens: See ``maxOutputTokens``.
8187
/// - stopSequences: See ``stopSequences``.
8288
/// - responseMIMEType: See ``responseMIMEType``.
89+
/// - responseSchema: See ``responseSchema``.
8390
public init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil,
8491
candidateCount: Int? = nil, maxOutputTokens: Int? = nil,
85-
stopSequences: [String]? = nil, responseMIMEType: String? = nil) {
92+
stopSequences: [String]? = nil, responseMIMEType: String? = nil,
93+
responseSchema: Schema? = nil) {
8694
// Explicit init because otherwise if we re-arrange the above variables it changes the API
8795
// surface.
8896
self.temperature = temperature
@@ -92,6 +100,7 @@ public struct GenerationConfig {
92100
self.maxOutputTokens = maxOutputTokens
93101
self.stopSequences = stopSequences
94102
self.responseMIMEType = responseMIMEType
103+
self.responseSchema = responseSchema
95104
}
96105
}
97106

Tests/GoogleAITests/GenerationConfigTests.swift

+39-3
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,20 @@ final class GenerationConfigTests: XCTestCase {
4848
let candidateCount = 2
4949
let maxOutputTokens = 256
5050
let stopSequences = ["END", "DONE"]
51-
let responseMIMEType = "text/plain"
51+
let responseMIMEType = "application/json"
52+
let schemaType = DataType.object
53+
let fieldName = "test-field"
54+
let fieldType = DataType.string
55+
let responseSchema = Schema(type: schemaType, properties: [fieldName: Schema(type: fieldType)])
5256
let generationConfig = GenerationConfig(
5357
temperature: temperature,
5458
topP: topP,
5559
topK: topK,
5660
candidateCount: candidateCount,
5761
maxOutputTokens: maxOutputTokens,
5862
stopSequences: stopSequences,
59-
responseMIMEType: responseMIMEType
63+
responseMIMEType: responseMIMEType,
64+
responseSchema: responseSchema
6065
)
6166

6267
let jsonData = try encoder.encode(generationConfig)
@@ -67,6 +72,14 @@ final class GenerationConfigTests: XCTestCase {
6772
"candidateCount" : \(candidateCount),
6873
"maxOutputTokens" : \(maxOutputTokens),
6974
"responseMIMEType" : "\(responseMIMEType)",
75+
"responseSchema" : {
76+
"properties" : {
77+
"\(fieldName)" : {
78+
"type" : "\(fieldType.rawValue)"
79+
}
80+
},
81+
"type" : "\(schemaType.rawValue)"
82+
},
7083
"stopSequences" : [
7184
"END",
7285
"DONE"
@@ -79,7 +92,7 @@ final class GenerationConfigTests: XCTestCase {
7992
}
8093

8194
func testEncodeGenerationConfig_responseMIMEType() throws {
82-
let mimeType = "image/jpeg"
95+
let mimeType = "text/plain"
8396
let generationConfig = GenerationConfig(responseMIMEType: mimeType)
8497

8598
let jsonData = try encoder.encode(generationConfig)
@@ -91,4 +104,27 @@ final class GenerationConfigTests: XCTestCase {
91104
}
92105
""")
93106
}
107+
108+
func testEncodeGenerationConfig_responseMIMETypeWithSchema() throws {
109+
let mimeType = "application/json"
110+
let schemaType = DataType.array
111+
let arrayItemType = DataType.integer
112+
let schema = Schema(type: schemaType, items: Schema(type: arrayItemType))
113+
let generationConfig = GenerationConfig(responseMIMEType: mimeType, responseSchema: schema)
114+
115+
let jsonData = try encoder.encode(generationConfig)
116+
117+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
118+
XCTAssertEqual(json, """
119+
{
120+
"responseMIMEType" : "\(mimeType)",
121+
"responseSchema" : {
122+
"items" : {
123+
"type" : "\(arrayItemType.rawValue)"
124+
},
125+
"type" : "\(schemaType.rawValue)"
126+
}
127+
}
128+
""")
129+
}
94130
}

0 commit comments

Comments
 (0)