Skip to content

Commit 8437057

Browse files
committed
Allow specifying question and context columns in LLM dataset configs
1 parent a87a1d1 commit 8437057

File tree

3 files changed

+52
-1
lines changed

3 files changed

+52
-1
lines changed

openlayer/schemas.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,12 @@ class LLMInputSchema(BaseDatasetSchema):
101101
inputVariableNames = ma.fields.List(
102102
ma.fields.Str(validate=COLUMN_NAME_VALIDATION_LIST), required=True
103103
)
104+
contextColumnName = ma.fields.Str(
105+
validate=COLUMN_NAME_VALIDATION_LIST, allow_none=True, load_default=None
106+
)
107+
questionColumnName = ma.fields.Str(
108+
validate=COLUMN_NAME_VALIDATION_LIST, allow_none=True, load_default=None
109+
)
104110

105111

106112
class TabularInputSchema(BaseDatasetSchema):

openlayer/validators/dataset_validators.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,14 +199,22 @@ class LLInputValidator(BaseDatasetValidator):
199199
"""
200200

201201
input_variable_names: Optional[List[str]] = None
202+
context_column_name: Optional[str] = None
203+
question_column_name: Optional[str] = None
202204

203205
def _validate_inputs(self):
204206
"""Validates LLM inputs."""
205207
# Setting the attributes needed for the validations
206208
self.input_variable_names = self.dataset_config.get("inputVariableNames")
209+
self.context_column_name = self.dataset_config.get("contextColumnName")
210+
self.question_column_name = self.dataset_config.get("questionColumnName")
207211

208212
if self.input_variable_names:
209213
self._validate_input_variables()
214+
if self.context_column_name:
215+
self._validate_context()
216+
if self.question_column_name:
217+
self._validate_question()
210218

211219
def _validate_input_variables(self):
212220
"""Validates the data in the input variables columns."""
@@ -234,6 +242,44 @@ def _validate_input_variables(self):
234242
"`inputVariableNames` do not exceed the maximum character limit."
235243
)
236244

245+
def _validate_context(self):
246+
"""Validations on the ground truth column."""
247+
if self.context_column_name not in self.dataset_df.columns:
248+
self.failed_validations.append(
249+
f"The context column `{self.context_column_name}` specified as"
250+
" `contextColumnName` is not in the dataset."
251+
)
252+
elif not hasattr(self.dataset_df[self.context_column_name], "str"):
253+
self.failed_validations.append(
254+
f"The context column `{self.context_column_name}` specified as"
255+
" `contextColumnName` is not a string column."
256+
)
257+
elif exceeds_character_limit(self.dataset_df, self.context_column_name):
258+
self.failed_validations.append(
259+
f"The ground truth column `{self.context_column_name}` specified as"
260+
" `contextColumnName` contains strings that exceed the "
261+
f" {constants.MAXIMUM_CHARACTER_LIMIT} character limit."
262+
)
263+
264+
def _validate_question(self):
265+
"""Validations on the ground truth column."""
266+
if self.question_column_name not in self.dataset_df.columns:
267+
self.failed_validations.append(
268+
f"The question column `{self.question_column_name}` specified as"
269+
" `questionColumnName` is not in the dataset."
270+
)
271+
elif not hasattr(self.dataset_df[self.question_column_name], "str"):
272+
self.failed_validations.append(
273+
f"The question column `{self.question_column_name}` specified as"
274+
" `questionColumnName` is not a string column."
275+
)
276+
elif exceeds_character_limit(self.dataset_df, self.question_column_name):
277+
self.failed_validations.append(
278+
f"The ground truth column `{self.question_column_name}` specified as"
279+
" `questionColumnName` contains strings that exceed the "
280+
f" {constants.MAXIMUM_CHARACTER_LIMIT} character limit."
281+
)
282+
237283
@staticmethod
238284
def _input_variables_not_castable_to_str(
239285
dataset_df: pd.DataFrame,

setup.cfg

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ install_requires =
4747
openai
4848
pandas
4949
pybars3
50-
requests
5150
requests_toolbelt
5251
requests>=2.28.2
5352
tabulate

0 commit comments

Comments
 (0)