Skip to content

Commit a167a4e

Browse files
committed
Add base model into client
1 parent bbf59c6 commit a167a4e

File tree

1 file changed

+143
-0
lines changed

1 file changed

+143
-0
lines changed

openlayer/model_runners/base_model.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""Base class for an Openlayer model."""
2+
3+
import abc
4+
import argparse
5+
import inspect
6+
import json
7+
import os
8+
import time
9+
from dataclasses import dataclass, field
10+
from typing import Any, Dict, Tuple
11+
12+
import pandas as pd
13+
from ..tracing import tracer
14+
15+
16+
@dataclass
17+
class RunReturn:
18+
output: Any
19+
other_fields: Dict[str, Any] = field(default_factory=dict)
20+
21+
22+
class OpenlayerModel(abc.ABC):
23+
"""Base class for an Openlayer model."""
24+
25+
def run_from_cli(self):
26+
# Create the parser
27+
parser = argparse.ArgumentParser(description="Run data through a model.")
28+
29+
# Add the --dataset-path argument
30+
parser.add_argument(
31+
"--dataset-path", type=str, required=True, help="Path to the dataset"
32+
)
33+
parser.add_argument(
34+
"--dataset-name", type=str, required=True, help="Name of the dataset"
35+
)
36+
parser.add_argument(
37+
"--output-path",
38+
type=str,
39+
required=False,
40+
help="Path to dump the results",
41+
default="output",
42+
)
43+
44+
# Parse the arguments
45+
args = parser.parse_args()
46+
47+
return self.batch(
48+
dataset_path=args.dataset_path,
49+
dataset_name=args.dataset_name,
50+
output_dir=args.output_path,
51+
)
52+
53+
def batch(self, dataset_path: str, dataset_name: str, output_dir: str):
54+
# Load the dataset into a pandas DataFrame
55+
df = pd.read_json(dataset_path, orient="records")
56+
57+
# Call the model's run_batch method, passing in the DataFrame
58+
output_df, config = self.run_batch_from_df(df)
59+
output_dir = os.path.join(output_dir, dataset_name)
60+
self.write_output_to_directory(output_df, config, output_dir)
61+
62+
def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
63+
"""Function that runs the model and returns the result."""
64+
# Ensure the 'output' column exists
65+
if "output" not in df.columns:
66+
df["output"] = None
67+
68+
# Get the signature of the 'run' method
69+
run_signature = inspect.signature(self.run)
70+
71+
for index, row in df.iterrows():
72+
# Filter row_dict to only include keys that are valid parameters
73+
# for the 'run' method
74+
row_dict = row.to_dict()
75+
filtered_kwargs = {
76+
k: v for k, v in row_dict.items() if k in run_signature.parameters
77+
}
78+
79+
# Call the run method with filtered kwargs
80+
output = self.run(**filtered_kwargs)
81+
82+
df.at[index, "output"] = output.output
83+
84+
for k, v in output.other_fields.items():
85+
if k not in df.columns:
86+
df[k] = None
87+
df.at[index, k] = v
88+
89+
trace = tracer.get_current_trace()
90+
if trace:
91+
steps = trace.to_dict()
92+
df.at[index, "steps"] = steps
93+
# also need cost, latency, tokens, timestamp
94+
95+
config = {}
96+
config["outputColumnName"] = "output"
97+
config["inputVariableNames"] = list(run_signature.parameters.keys())
98+
config["metadata"] = {
99+
"output_timestamp": time.time(),
100+
}
101+
102+
# pull the config info from trace if it exists, otherwise manually construct it
103+
# with the bare minimum
104+
# costColumnName, latencyColumnName, numOfTokenColumnName, timestampColumnName
105+
106+
return df, config
107+
108+
def write_output_to_directory(self, output_df, config, output_dir, fmt="json"):
109+
"""
110+
Writes the output DataFrame to a file in the specified directory based on the
111+
given format.
112+
113+
:param output_df: DataFrame to write.
114+
:param output_dir: Directory where the output file will be saved.
115+
:param fmt: Format of the output file ('csv' or 'json').
116+
"""
117+
os.makedirs(
118+
output_dir, exist_ok=True
119+
) # Create the directory if it doesn't exist
120+
121+
# Determine the filename based on the dataset name and format
122+
filename = f"dataset.{fmt}"
123+
output_path = os.path.join(output_dir, filename)
124+
125+
# Write the config to a json file
126+
config_path = os.path.join(output_dir, "config.json")
127+
with open(config_path, "w", encoding="utf-8") as f:
128+
json.dump(config, f, indent=4)
129+
130+
# Write the DataFrame to the file based on the specified format
131+
if fmt == "csv":
132+
output_df.to_csv(output_path, index=False)
133+
elif fmt == "json":
134+
output_df.to_json(output_path, orient="records", indent=4)
135+
else:
136+
raise ValueError("Unsupported format. Please choose 'csv' or 'json'.")
137+
138+
print(f"Output written to {output_path}")
139+
140+
@abc.abstractmethod
141+
def run(self, **kwargs) -> RunReturn:
142+
"""Function that runs the model and returns the result."""
143+
pass

0 commit comments

Comments
 (0)