@@ -199,14 +199,22 @@ class LLInputValidator(BaseDatasetValidator):
199
199
"""
200
200
201
201
input_variable_names : Optional [List [str ]] = None
202
+ context_column_name : Optional [str ] = None
203
+ question_column_name : Optional [str ] = None
202
204
203
205
def _validate_inputs (self ):
204
206
"""Validates LLM inputs."""
205
207
# Setting the attributes needed for the validations
206
208
self .input_variable_names = self .dataset_config .get ("inputVariableNames" )
209
+ self .context_column_name = self .dataset_config .get ("contextColumnName" )
210
+ self .question_column_name = self .dataset_config .get ("questionColumnName" )
207
211
208
212
if self .input_variable_names :
209
213
self ._validate_input_variables ()
214
+ if self .context_column_name :
215
+ self ._validate_context ()
216
+ if self .question_column_name :
217
+ self ._validate_question ()
210
218
211
219
def _validate_input_variables (self ):
212
220
"""Validates the data in the input variables columns."""
@@ -234,6 +242,44 @@ def _validate_input_variables(self):
234
242
"`inputVariableNames` do not exceed the maximum character limit."
235
243
)
236
244
245
+ def _validate_context (self ):
246
+ """Validations on the ground truth column."""
247
+ if self .context_column_name not in self .dataset_df .columns :
248
+ self .failed_validations .append (
249
+ f"The context column `{ self .context_column_name } ` specified as"
250
+ " `contextColumnName` is not in the dataset."
251
+ )
252
+ elif not hasattr (self .dataset_df [self .context_column_name ], "str" ):
253
+ self .failed_validations .append (
254
+ f"The context column `{ self .context_column_name } ` specified as"
255
+ " `contextColumnName` is not a string column."
256
+ )
257
+ elif exceeds_character_limit (self .dataset_df , self .context_column_name ):
258
+ self .failed_validations .append (
259
+ f"The ground truth column `{ self .context_column_name } ` specified as"
260
+ " `contextColumnName` contains strings that exceed the "
261
+ f" { constants .MAXIMUM_CHARACTER_LIMIT } character limit."
262
+ )
263
+
264
+ def _validate_question (self ):
265
+ """Validations on the ground truth column."""
266
+ if self .question_column_name not in self .dataset_df .columns :
267
+ self .failed_validations .append (
268
+ f"The question column `{ self .question_column_name } ` specified as"
269
+ " `questionColumnName` is not in the dataset."
270
+ )
271
+ elif not hasattr (self .dataset_df [self .question_column_name ], "str" ):
272
+ self .failed_validations .append (
273
+ f"The question column `{ self .question_column_name } ` specified as"
274
+ " `questionColumnName` is not a string column."
275
+ )
276
+ elif exceeds_character_limit (self .dataset_df , self .question_column_name ):
277
+ self .failed_validations .append (
278
+ f"The ground truth column `{ self .question_column_name } ` specified as"
279
+ " `questionColumnName` contains strings that exceed the "
280
+ f" { constants .MAXIMUM_CHARACTER_LIMIT } character limit."
281
+ )
282
+
237
283
@staticmethod
238
284
def _input_variables_not_castable_to_str (
239
285
dataset_df : pd .DataFrame ,
0 commit comments