22import random
33from collections .abc import Iterable , Iterator
44from pathlib import Path
5- from typing import Any , Literal , Optional , Union
5+ from typing import Any , Optional , TypedDict , Union
66
77import yaml
88from datasets import (
@@ -63,6 +63,26 @@ class SyntheticDatasetConfig(BaseModel):
6363 gt = 0 ,
6464 default = None ,
6565 )
66+ turns : int = Field (
67+ description = "The number of turns in the conversation." ,
68+ gt = 0 ,
69+ default = 1 ,
70+ )
71+ turns_stdev : Optional [int ] = Field (
72+ description = "The standard deviation of the number of turns." ,
73+ gt = 0 ,
74+ default = None ,
75+ )
76+ turns_min : Optional [int ] = Field (
77+ description = "The minimum number of turns in the conversation." ,
78+ gt = 0 ,
79+ default = None ,
80+ )
81+ turns_max : Optional [int ] = Field (
82+ description = "The maximum number of turns in the conversation." ,
83+ gt = 0 ,
84+ default = None ,
85+ )
6686 samples : int = Field (
6787 description = "The number of samples to generate for the dataset." ,
6888 gt = 0 ,
@@ -118,14 +138,13 @@ def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig":
118138 return SyntheticDatasetConfig (** config_dict )
119139
120140
121- class SyntheticTextItemsGenerator (
122- Iterable [
123- dict [
124- Literal ["prompt" , "prompt_tokens_count" , "output_tokens_count" ],
125- Union [str , int ],
126- ]
127- ]
128- ):
# Row type yielded by SyntheticTextItemsGenerator.__iter__: one generated sample.
# The three lists are filled in lockstep, with one entry appended per turn.
141+ class SyntheticDatasetRow (TypedDict ):
# Prompt text for each turn, as returned by _create_prompt.
142+ prompt : list [str ]
# Prompt token count drawn from the prompt-token sampler for each turn.
143+ prompt_tokens_count : list [int ]
# Output token count drawn from the output-token sampler for each turn.
144+ output_tokens_count : list [int ]
145+
146+
147+ class SyntheticTextItemsGenerator (Iterable [SyntheticDatasetRow ]):
129148 def __init__ (
130149 self ,
131150 config : SyntheticDatasetConfig ,
@@ -141,12 +160,7 @@ def __init__(
141160
142161 def __iter__ (
143162 self ,
144- ) -> Iterator [
145- dict [
146- Literal ["prompt" , "prompt_tokens_count" , "output_tokens_count" ],
147- Union [str , int ],
148- ]
149- ]:
163+ ) -> Iterator [SyntheticDatasetRow ]:
150164 prompt_tokens_sampler = IntegerRangeSampler (
151165 average = self .config .prompt_tokens ,
152166 variance = self .config .prompt_tokens_stdev ,
@@ -161,20 +175,33 @@ def __iter__(
161175 max_value = self .config .output_tokens_max ,
162176 random_seed = self .random_seed + 1 , # ensure diff dist from prompts
163177 )
178+ turns_sampler = IntegerRangeSampler (
179+ average = self .config .turns ,
180+ variance = self .config .turns_stdev ,
181+ min_value = self .config .turns_min ,
182+ max_value = self .config .turns_max ,
183+ random_seed = self .random_seed + 7 , # ensure diff dist
184+ )
164185 # ensure diff distribution from output tokens
165186 rand = random .Random (self .random_seed + 2 ) # noqa: S311
166187
167- for _ , prompt_tokens , output_tokens in zip (
168- range (self .config .samples ),
169- prompt_tokens_sampler ,
170- output_tokens_sampler ,
171- ):
172- start_index = rand .randint (0 , len (self .text_creator .words ))
173- yield {
174- "prompt" : self ._create_prompt (prompt_tokens , start_index ),
175- "prompt_tokens_count" : prompt_tokens ,
176- "output_tokens_count" : output_tokens ,
188+ for _ , turns in zip (range (self .config .samples ), turns_sampler ):
189+ row : SyntheticDatasetRow = {
190+ "prompt" : [],
191+ "prompt_tokens_count" : [],
192+ "output_tokens_count" : [],
177193 }
194+ for _ , prompt_tokens , output_tokens in zip (
195+ range (turns ),
196+ prompt_tokens_sampler ,
197+ output_tokens_sampler ,
198+ ):
199+ start_index = rand .randint (0 , len (self .text_creator .words ))
200+ row ["prompt" ].append (self ._create_prompt (prompt_tokens , start_index ))
201+ row ["prompt_tokens_count" ].append (prompt_tokens )
202+ row ["output_tokens_count" ].append (output_tokens )
203+
204+ yield row
178205
179206 def _create_prompt (self , prompt_tokens : int , start_index : int ) -> str :
180207 if prompt_tokens <= 0 :
0 commit comments