Skip to content

Commit 478e8a2

Browse files
authored
fix: fix openai prompt behavior on RAI errors (#2279)
1 parent 392f601 commit 478e8a2

File tree

4 files changed

+52
-12
lines changed

4 files changed

+52
-12
lines changed

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import org.apache.http.entity.AbstractHttpEntity
1313
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
1414
import org.apache.spark.ml.util.Identifiable
1515
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
16+
import org.apache.spark.sql.Row.unapplySeq
17+
import org.apache.spark.sql.catalyst.encoders.RowEncoder
1618
import org.apache.spark.sql.functions.udf
1719
import org.apache.spark.sql.types.{DataType, StructType}
1820
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T}
@@ -78,7 +80,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
7880

7981
def setSystemPrompt(value: String): this.type = set(systemPrompt, value)
8082

81-
private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " +
83+
private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " +
8284
"Follow their instructions carefully and be brief if they don't say otherwise."
8385

8486
setDefault(
@@ -100,6 +102,27 @@ class OpenAIPrompt(override val uid: String) extends Transformer
100102
"promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
101103
"systemPrompt")
102104

105+
/** Copies content-filter information from the completion output into the error column.
  *
  * For every row, the struct in `outputCol` is decoded into a [[ChatCompletionResponse]] and its
  * first choice is inspected. A choice whose message content is null is treated as filtered; for
  * such rows the value in `errorCol` is replaced by a two-field Row holding the choice's
  * `finish_reason` and a null second field (presumably matching ErrorUtils.ErrorSchema's
  * `response` / `status` layout -- confirm). All other rows are returned unchanged, including
  * rows whose output is null or whose response carries no choices.
  *
  * @param df        frame produced by the chat completion transformer
  * @param errorCol  name of the column that receives the error struct
  * @param outputCol name of the column holding the raw chat completion response
  * @return a frame with the same schema as `df`
  */
private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = {
  val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter
  df.map({ row =>
    // headOption (rather than head) keeps a response with an empty `choices` list from
    // failing the whole job with a NoSuchElementException.
    val filteredChoice = Option(row.getAs[Row](outputCol))
      .flatMap(outputRow => openAIResultFromRow(outputRow).choices.headOption)
      .filter(choice => Option(choice.message.content).isEmpty)

    filteredChoice match {
      case Some(choice) =>
        val updatedRowSeq = row.toSeq.updated(
          row.fieldIndex(errorCol),
          Row(choice.finish_reason, null) //scalastyle:ignore null
        )
        Row.fromSeq(updatedRowSeq)
      case None =>
        row
    }
  })(RowEncoder(df.schema))
}
125+
103126
override def transform(dataset: Dataset[_]): DataFrame = {
104127
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._
105128

@@ -120,8 +143,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer
120143
val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
121144
val completionNamed = chatCompletion.setMessagesCol(messageColName)
122145

123-
val results = completionNamed
124-
.transform(dfTemplated)
146+
val transformed = addRAIErrors(
147+
completionNamed.transform(dfTemplated), chatCompletion.getErrorCol, chatCompletion.getOutputCol)
148+
149+
val results = transformed
125150
.withColumn(getOutputCol,
126151
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
127152
.getField("message").getField("content")))
@@ -155,19 +180,19 @@ class OpenAIPrompt(override val uid: String) extends Transformer
155180
}, dataset.columns.length)
156181
}
157182

158-
private val legacyModels = Set("ada","babbage", "curie", "davinci",
183+
private val legacyModels = Set("ada", "babbage", "curie", "davinci",
159184
"text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003",
160185
"code-cushman-001", "code-davinci-002")
161186

162187
private def openAICompletion: OpenAIServicesBase = {
163188

164189
val completion: OpenAIServicesBase =
165-
if (legacyModels.contains(getDeploymentName)) {
166-
new OpenAICompletion()
167-
}
168-
else {
169-
new OpenAIChatCompletion()
170-
}
190+
if (legacyModels.contains(getDeploymentName)) {
191+
new OpenAICompletion()
192+
}
193+
else {
194+
new OpenAIChatCompletion()
195+
}
171196
// apply all parameters
172197
extractParamMap().toSeq
173198
.filter(p => !localParamNames.contains(p.param.name))

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAISchemas.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,15 @@ case class OpenAIChatChoice(message: OpenAIMessage,
4141
index: Long,
4242
finish_reason: String)
4343

44+
/** Token counts attached to a chat completion response.
  *
  * Field names are snake_case, presumably so they line up with the `usage` object of the
  * service's JSON payload -- confirm against the API reference before renaming.
  */
case class OpenAIUsage(completion_tokens: Long,
                       prompt_tokens: Long,
                       total_tokens: Long)
45+
4446
case class ChatCompletionResponse(id: String,
4547
`object`: String,
4648
created: String,
4749
model: String,
48-
choices: Seq[OpenAIChatChoice])
50+
choices: Seq[OpenAIChatChoice],
51+
system_fingerprint: Option[String],
52+
usage: Option[OpenAIUsage])
4953

5054
object ChatCompletionResponse extends SparkBindings[ChatCompletionResponse]
5155

cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.Secrets.getAccessToken
77
import com.microsoft.azure.synapse.ml.core.test.base.Flaky
88
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
99
import org.apache.spark.ml.util.MLReadable
10-
import org.apache.spark.sql.DataFrame
10+
import org.apache.spark.sql.{DataFrame, Row}
1111
import org.apache.spark.sql.functions.col
1212
import org.scalactic.Equality
1313

@@ -35,6 +35,16 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
3535
(null, "none") //scalastyle:ignore null
3636
).toDF("text", "category")
3737

38+
// Checks that the error column of the first collected row is populated (non-null) for this
// prompt; the prompt is presumably expected to trip the service's content filter.
test("RAI Usage") {
  val configuredPrompt = prompt
    .setDeploymentName(deploymentNameGpt4)
    .setPromptTemplate("Tell me about a graphically disgusting movie in detail")
  val firstRow = configuredPrompt
    .transform(df)
    .select(prompt.getErrorCol)
    .collect()
    .head
  val errorValue = firstRow.getAs[Row](0)
  assert(Option(errorValue).nonEmpty)
}
47+
3848
test("Basic Usage") {
3949
val nonNullCount = prompt
4050
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")

core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ trait HasErrorCol extends Params {
3030
}
3131

3232
object ErrorUtils extends Serializable {
33+
3334
val ErrorSchema: StructType = new StructType()
3435
.add("response", StringType, nullable = true)
3536
.add("status", StatusLineData.schema, nullable = true)

0 commit comments

Comments
 (0)