@@ -13,6 +13,8 @@ import org.apache.http.entity.AbstractHttpEntity
1313import org .apache .spark .ml .param .{BooleanParam , Param , ParamMap , ParamValidators }
1414import org .apache .spark .ml .util .Identifiable
1515import org .apache .spark .ml .{ComplexParamsReadable , ComplexParamsWritable , Transformer }
16+ import org .apache .spark .sql .Row .unapplySeq
17+ import org .apache .spark .sql .catalyst .encoders .RowEncoder
1618import org .apache .spark .sql .functions .udf
1719import org .apache .spark .sql .types .{DataType , StructType }
1820import org .apache .spark .sql .{Column , DataFrame , Dataset , Row , functions => F , types => T }
@@ -78,7 +80,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
7880
7981 def setSystemPrompt (value : String ): this .type = set(systemPrompt, value)
8082
81- private val defaultSystemPrompt = " You are an AI chatbot who wants to answer user's questions and complete tasks. " +
83+ private val defaultSystemPrompt = " You are an AI chatbot who wants to answer user's questions and complete tasks. " +
8284 " Follow their instructions carefully and be brief if they don't say otherwise."
8385
8486 setDefault(
@@ -100,6 +102,27 @@ class OpenAIPrompt(override val uid: String) extends Transformer
100102 " promptTemplate" , " outputCol" , " postProcessing" , " postProcessingOptions" , " dropPrompt" , " dropMessages" ,
101103 " systemPrompt" )
102104
105+ private def addRAIErrors (df : DataFrame , errorCol : String , outputCol : String ): DataFrame = {
106+ val openAIResultFromRow = ChatCompletionResponse .makeFromRowConverter
107+ df.map({ row =>
108+ val originalOutput = Option (row.getAs[Row ](outputCol))
109+ .map({ row => openAIResultFromRow(row).choices.head })
110+ val isFiltered = originalOutput
111+ .map(output => Option (output.message.content).isEmpty)
112+ .getOrElse(false )
113+
114+ if (isFiltered) {
115+ val updatedRowSeq = row.toSeq.updated(
116+ row.fieldIndex(errorCol),
117+ Row (originalOutput.get.finish_reason, null ) // scalastyle:ignore null
118+ )
119+ Row .fromSeq(updatedRowSeq)
120+ } else {
121+ row
122+ }
123+ })(RowEncoder (df.schema))
124+ }
125+
103126 override def transform (dataset : Dataset [_]): DataFrame = {
104127 import com .microsoft .azure .synapse .ml .core .schema .DatasetExtensions ._
105128
@@ -120,8 +143,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer
120143 val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
121144 val completionNamed = chatCompletion.setMessagesCol(messageColName)
122145
123- val results = completionNamed
124- .transform(dfTemplated)
146+ val transformed = addRAIErrors(
147+ completionNamed.transform(dfTemplated), chatCompletion.getErrorCol, chatCompletion.getOutputCol)
148+
149+ val results = transformed
125150 .withColumn(getOutputCol,
126151 getParser.parse(F .element_at(F .col(completionNamed.getOutputCol).getField(" choices" ), 1 )
127152 .getField(" message" ).getField(" content" )))
@@ -155,19 +180,19 @@ class OpenAIPrompt(override val uid: String) extends Transformer
155180 }, dataset.columns.length)
156181 }
157182
158- private val legacyModels = Set (" ada" ," babbage" , " curie" , " davinci" ,
183+ private val legacyModels = Set (" ada" , " babbage" , " curie" , " davinci" ,
159184 " text-ada-001" , " text-babbage-001" , " text-curie-001" , " text-davinci-002" , " text-davinci-003" ,
160185 " code-cushman-001" , " code-davinci-002" )
161186
162187 private def openAICompletion : OpenAIServicesBase = {
163188
164189 val completion : OpenAIServicesBase =
165- if (legacyModels.contains(getDeploymentName)) {
166- new OpenAICompletion ()
167- }
168- else {
169- new OpenAIChatCompletion ()
170- }
190+ if (legacyModels.contains(getDeploymentName)) {
191+ new OpenAICompletion ()
192+ }
193+ else {
194+ new OpenAIChatCompletion ()
195+ }
171196 // apply all parameters
172197 extractParamMap().toSeq
173198 .filter(p => ! localParamNames.contains(p.param.name))
0 commit comments