Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize components with interfaces and stub implementations #5

Merged
merged 74 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
db1eadc
save [wip]
viswavi Apr 4, 2023
a96d203
Add integration test
viswavi Apr 5, 2023
e394de4
Update skeletons, now passing integration test
viswavi Apr 5, 2023
50783c8
Fix integration test
viswavi Apr 5, 2023
db8d4d8
Merge branch 'main' into vijay_init_stubs
viswavi Apr 5, 2023
a912646
Run isort on whole package
viswavi Apr 5, 2023
4b2f563
Remove uses of deprecated Typing objects
viswavi Apr 5, 2023
39e4589
Add docstrings to all functions and classes
viswavi Apr 5, 2023
bbd8d1c
Refactor files into modules
viswavi Apr 6, 2023
49d7e88
Separate interfaces from implementations
viswavi Apr 7, 2023
57fdb40
Run black and isort
viswavi Apr 7, 2023
eb61517
Fix type error
viswavi Apr 7, 2023
156587b
Prevent import unused error in init files
viswavi Apr 7, 2023
bfe8dbf
Use absolute imports in init.py files
viswavi Apr 7, 2023
06d0176
Fix flake8 errors
viswavi Apr 7, 2023
4a10bea
Run isort and black
viswavi Apr 7, 2023
05934b5
Use ABC instead of Protocol
viswavi Apr 7, 2023
e92cf6c
Supress import position errors in test
viswavi Apr 7, 2023
ca5cad9
Disable unnecessary pylint warnings
viswavi Apr 7, 2023
fa6b733
Run isort
viswavi Apr 7, 2023
71e4258
Use type instead of Type
viswavi Apr 7, 2023
f79f1c5
Fix all flake8 docstring errors
viswavi Apr 7, 2023
01f4f40
Add tests
viswavi Apr 12, 2023
8e86387
Remove init file from tests file
viswavi Apr 12, 2023
2747172
Add pytest init file to simplify test running
viswavi Apr 12, 2023
fdac050
Move the generate_datasets implementation into the base class
viswavi Apr 12, 2023
d3ccc2b
Use optional annotations instead of Optional type hints
viswavi Apr 12, 2023
8019b98
Remove redundant type hints from docstrings
viswavi Apr 12, 2023
9b205ae
Remove redundant type info from docstrings
viswavi Apr 12, 2023
70c239c
Rename BaseGenerator -> EmptyDatasetGenerator
viswavi Apr 13, 2023
2a02641
Remove redundant I/O types from the evaluator docstring
viswavi Apr 13, 2023
de9d8ac
Add script to run (skeleton) pipeline locally
viswavi Apr 13, 2023
d7efe5a
Add random seed generator class
viswavi Apr 13, 2023
1e5cef4
Add setup.py to contain dependencies
viswavi Apr 13, 2023
a735cbd
Fix setup.py
viswavi Apr 13, 2023
aae496a
Move dependencies from setup.py to pyproject.toml
viswavi Apr 13, 2023
e9b2ca3
Avoid linter errors in __init__.py files
viswavi Apr 14, 2023
d76bf48
Set global seed generator
viswavi Apr 14, 2023
c48772d
Fix __all__ imports in init files
viswavi Apr 14, 2023
657fb8d
Simplify the generate_datasets implementation
viswavi Apr 14, 2023
cff1368
Require the num_examples argument in generate_examples
viswavi Apr 14, 2023
2cdb0aa
Remove redundant IO-type comment
viswavi Apr 14, 2023
7c335fa
Move Trainer arguments to the train_model function call
viswavi Apr 14, 2023
fe6ca2f
Update prompt2model/dataset_generator/base.py
viswavi Apr 14, 2023
8a15487
Move empty dataset generator to empty.py
viswavi Apr 14, 2023
bbbefeb
Move EmptyDatasetGenerator to separate file
viswavi Apr 14, 2023
5019fac
Remove unnecessary todo from DatasetRetriever
viswavi Apr 14, 2023
7c14c22
Move MockRetriever to separate file
viswavi Apr 14, 2023
9692ba5
Move MockEvaluator to separate file
viswavi Apr 14, 2023
3a4929b
Move MockModelSelector to separate file
viswavi Apr 14, 2023
cfbb116
Avoid storing unnecessary state in ModelSelector
viswavi Apr 14, 2023
12c3534
Avoid storing unnecessary state in Evaluator
viswavi Apr 14, 2023
1bf8589
Make default PromptSpec more general
viswavi Apr 14, 2023
99b92fe
Move MockTrainer to separate file
viswavi Apr 14, 2023
079eb4c
Use absolute module path for test script import
viswavi Apr 14, 2023
e42a9c6
Use absolute imports for prompt2model
viswavi Apr 14, 2023
cba644a
Move Trainer into model selector's state
viswavi Apr 14, 2023
53d26e4
Separate the model selector's functionality into two functions
viswavi Apr 14, 2023
96ed204
Create new architecture component for the model executor
viswavi Apr 17, 2023
ead11b7
Add mock class for DatasetRetriever
viswavi Apr 17, 2023
1eba1fa
Add future annotations import for evaluator/mock.py
viswavi Apr 17, 2023
6c2a4ca
Provide a dataset column to the model executor
viswavi Apr 17, 2023
f7a5021
Provide a dataset column to the evaluator class
viswavi Apr 17, 2023
2c498b8
Pass in an output column to the evaluator
viswavi Apr 17, 2023
0df6536
Update prompt2model/dataset_generator/base.py
viswavi Apr 17, 2023
e0bd1ce
Update prompt2model/dataset_generator/base.py
viswavi Apr 17, 2023
814e786
Update prompt2model/dataset_generator/base.py
viswavi Apr 17, 2023
b15976f
Update prompt2model/evaluator/base.py
viswavi Apr 17, 2023
e174a9f
Update prompt2model/model_executor/base.py
viswavi Apr 17, 2023
5a942d6
Rename EmptyDatasetGenerator -> Mock...
viswavi Apr 17, 2023
d146d66
Remove unused import from dataset_generator/base.py
viswavi Apr 17, 2023
f7a720f
Move empty write_metrics function into base class
viswavi Apr 17, 2023
a0e0b7e
Add base implementation of write_metrics into base class
viswavi Apr 17, 2023
93d3e46
Refactor evaluator and mock executor to match new ModelOutput schema
viswavi Apr 17, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ build
prompt2model.egg-info
.vscode
.mypy_cache
*.pyc

8 changes: 0 additions & 8 deletions prompt2model/dataset_generator.py

This file was deleted.

3 changes: 3 additions & 0 deletions prompt2model/dataset_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Import DatasetGenerator classes."""
from dataset_generator.base import EmptyDatasetGenerator # noqa: F401
from dataset_generator.base import DatasetGenerator, DatasetSplit # noqa: F401
viswavi marked this conversation as resolved.
Show resolved Hide resolved
125 changes: 125 additions & 0 deletions prompt2model/dataset_generator/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""An interface for dataset generation."""

from __future__ import annotations # noqa FI58

from abc import ABC, abstractmethod
from enum import Enum

import datasets
import pandas as pd
from prompt_parser import PromptSpec
from utils.rng import ConstantSeedGenerator
viswavi marked this conversation as resolved.
Show resolved Hide resolved


class DatasetSplit(Enum):
"""The split of a dataset."""

TRAIN = "train"
VAL = "val"
TEST = "test"


class DatasetGenerator(ABC):
"""A class for generating datasets from a prompt specification."""

def __init__(
self,
model_config: dict | None = None,
output_dir: str | None = None,
):
"""Construct a dataset generator."""
self.model_config = model_config
self.output_dir = output_dir
self.seed_generator = ConstantSeedGenerator()
viswavi marked this conversation as resolved.
Show resolved Hide resolved

@abstractmethod
def generate_examples(
self,
prompt_spec: PromptSpec,
num_examples: int | None,
viswavi marked this conversation as resolved.
Show resolved Hide resolved
split: DatasetSplit,
) -> datasets.Dataset:
"""Generate data for a single named split of data.

Args:
prompt_spec: A prompt spec (containing a system description).
num_examples: Number of examples in split.
split: Name of dataset split to generate.)
viswavi marked this conversation as resolved.
Show resolved Hide resolved

Returns:
A single dataset split.

"""

def generate_datasets(
self,
prompt_spec: PromptSpec,
num_examples: dict[DatasetSplit, int],
viswavi marked this conversation as resolved.
Show resolved Hide resolved
) -> datasets.DatasetDict:
"""Generate training/validation/testing datasets from a prompt.

Args:
prompt_spec: A prompt specification.
num_examples: Number of examples per split (train/val/test/etc).

Returns:
A DatasetDict containing train, validation, and test splits.
"""
assert num_examples.keys() == {
DatasetSplit.TRAIN,
DatasetSplit.VAL,
DatasetSplit.TEST,
}

train_examples = self.generate_examples(
prompt_spec, num_examples[DatasetSplit.TRAIN], split=DatasetSplit.TRAIN
)
val_examples = self.generate_examples(
prompt_spec, num_examples[DatasetSplit.VAL], split=DatasetSplit.VAL
)
test_examples = self.generate_examples(
prompt_spec, num_examples[DatasetSplit.TEST], split=DatasetSplit.TEST
)

dataset_dict = datasets.DatasetDict(
{
DatasetSplit.TRAIN: train_examples,
DatasetSplit.VAL: val_examples,
DatasetSplit.TEST: test_examples,
}
)
viswavi marked this conversation as resolved.
Show resolved Hide resolved

if self.output_dir:
viswavi marked this conversation as resolved.
Show resolved Hide resolved
dataset_dict.save_to_disk(self.output_dir)

return dataset_dict


class EmptyDatasetGenerator(DatasetGenerator):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could consider moving this to prompt2model/dataset_generator/empty.py

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @viswavi Vijay, I believe this comment has already been addressed, but it seems like the comments have not yet been resolved. Is there anything left that I can do to offer my assistance?

"""A class for generating empty datasets (for testing purposes)."""

def generate_examples(
self,
prompt_spec: PromptSpec,
num_examples: int | None,
viswavi marked this conversation as resolved.
Show resolved Hide resolved
split: DatasetSplit,
) -> datasets.Dataset:
"""Create empty versions of the datasets, for testing.

Args:
prompt_spec: A prompt specification.
num_examples: Number of examples in split.
split: Name of dataset split to generate.)
viswavi marked this conversation as resolved.
Show resolved Hide resolved

Returns:
A single dataset split.

"""
_ = prompt_spec, split # suppress unused variable warnings
if num_examples is None:
raise NotImplementedError
else:
col_values = ["" for i in range(num_examples)]
viswavi marked this conversation as resolved.
Show resolved Hide resolved
# Construct empty-valued dataframe with length matching num_examples.
df = pd.DataFrame.from_dict({"test_col": col_values})
return datasets.Dataset.from_pandas(df)
5 changes: 0 additions & 5 deletions prompt2model/dataset_retriever.py

This file was deleted.

3 changes: 3 additions & 0 deletions prompt2model/dataset_retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Import DatasetRetriever classes."""
from dataset_retriever.base import BaseRetriever # noqa: F401
from dataset_retriever.base import DatasetRetriever # noqa: F401
viswavi marked this conversation as resolved.
Show resolved Hide resolved
41 changes: 41 additions & 0 deletions prompt2model/dataset_retriever/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""An interface for dataset retrieval."""

from abc import ABC, abstractmethod

import datasets
import pandas as pd
from prompt_parser import PromptSpec


# pylint: disable=too-few-public-methods
class DatasetRetriever(ABC):
"""A class for retrieving datasets.

TO IMPLEMENT IN SUBCLASSES:
def __init__(self):
'''Construct a search index from HuggingFace Datasets.'''
viswavi marked this conversation as resolved.
Show resolved Hide resolved
"""

@abstractmethod
def retrieve_datasets(self, prompt_spec: PromptSpec) -> list[datasets.Dataset]:
"""Retrieve datasets from a prompt specification.

Args:
prompt_spec: A prompt spec (containing a system description).

Returns:
A list of retrieved datasets.

"""


class BaseRetriever(DatasetRetriever):
viswavi marked this conversation as resolved.
Show resolved Hide resolved
"""A class for retrieving datasets."""

def __init__(self):
"""Construct a mock dataset retriever."""

def retrieve_datasets(self, prompt_spec: PromptSpec) -> list[datasets.Dataset]:
"""Return a single empty dataset for testing purposes."""
_ = prompt_spec # suppress unused variable warning
return [datasets.Dataset.from_pandas(pd.DataFrame({}))]
8 changes: 0 additions & 8 deletions prompt2model/demo_creator.py

This file was deleted.

23 changes: 23 additions & 0 deletions prompt2model/demo_creator/gradio_creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""An interface for creating Gradio demos automatically."""

import gradio as gr
import transformers
from prompt_parser.base import PromptSpec


def create_gradio(
model: transformers.PreTrainedModel, prompt_spec: PromptSpec
) -> gr.Interface:
"""Create a Gradio interface automatically.

Args:
model: A trained model to expose via a Gradio interface.
prompt_spec: A PromptSpec to help choose the visual interface.

Returns:
A Gradio interface for interacting with the model.

"""
_ = model, prompt_spec # suppress unused variable warnings
dummy_interface = gr.Interface(lambda input: None, "textbox", "label")
return dummy_interface
9 changes: 0 additions & 9 deletions prompt2model/evaluator.py

This file was deleted.

2 changes: 2 additions & 0 deletions prompt2model/evaluator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Import evaluator classes."""
from evaluator.base import BaseEvaluator, Evaluator # noqa: F401
viswavi marked this conversation as resolved.
Show resolved Hide resolved
69 changes: 69 additions & 0 deletions prompt2model/evaluator/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""An interface for automatic model evaluation."""

from __future__ import annotations # noqa FI58

from abc import ABC, abstractmethod
from typing import Any

import datasets
import transformers
from prompt_parser.base import PromptSpec


class Evaluator(ABC):
"""An interface for automatic model evaluation."""

@abstractmethod
def evaluate_model(self, model: transformers.PreTrainedModel) -> dict[str, Any]:
"""Evaluate a model on a test set.

Args:
model: The model to evaluate.

Returns:
A dictionary of metric values to return.

"""
viswavi marked this conversation as resolved.
Show resolved Hide resolved

@abstractmethod
def write_metrics(self, metrics_dict: dict[str, Any], metrics_path: str) -> None:
"""Write or display metrics to a file.

Args:
metrics_dict: A dictionary of metrics to write.
metrics_path: The file path to write metrics to.

"""
viswavi marked this conversation as resolved.
Show resolved Hide resolved


class BaseEvaluator(Evaluator):
viswavi marked this conversation as resolved.
Show resolved Hide resolved
"""A dummy evaluator that always returns the same metric value."""

def __init__(
self,
dataset: datasets.Dataset,
metrics: list[datasets.Metric] | None = None,
prompt_spec: PromptSpec | None = None,
) -> None:
"""Initialize the evaluation setting.

Args:
dataset: The dataset to evaluate metrics on.
metrics: (Optional) The metrics to use.
prompt_spec: (Optional) A PromptSpec to infer the metrics from.

"""
self.test_data = dataset
self.metrics = metrics
self.prompt_spec = prompt_spec
viswavi marked this conversation as resolved.
Show resolved Hide resolved

def evaluate_model(
self,
model: transformers.PreTrainedModel,
) -> dict[str, Any]:
"""Return empty metrics dictionary."""
return {}

def write_metrics(self, metrics_dict: dict[str, Any], metrics_path: str) -> None:
"""Do nothing."""
_ = metrics_dict, metrics_path # suppress unused variable warnings
12 changes: 0 additions & 12 deletions prompt2model/model_selector.py

This file was deleted.

3 changes: 3 additions & 0 deletions prompt2model/model_selector/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Import model selector classes."""
from model_selector.base import DefaultParameterSelector # noqa: F401
from model_selector.base import ModelSelector # noqa: F401
viswavi marked this conversation as resolved.
Show resolved Hide resolved
Loading