diff --git a/.github/workflows/code_quality.yml b/.github/workflows/code_quality.yml
new file mode 100644
index 0000000..adfa784
--- /dev/null
+++ b/.github/workflows/code_quality.yml
@@ -0,0 +1,13 @@
+name: Code Quality
+
+on: [pull_request]
+
+jobs:
+  code-quality:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+      - uses: pre-commit/action@v2.0.3
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..739c3ac
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,40 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: check-case-conflict
+      - id: check-json
+      - id: check-symlinks
+      - id: check-yaml
+      - id: destroyed-symlinks
+      - id: end-of-file-fixer
+        exclude: docs/CNAME
+      - id: fix-byte-order-marker
+      - id: fix-encoding-pragma
+        args: [--remove]
+      - id: mixed-line-ending
+        args: [--fix=lf]
+      - id: requirements-txt-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/psf/black
+    rev: 23.1.0
+    hooks:
+      - id: black
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        name: isort (python)
+  - repo: https://github.com/pycqa/flake8
+    rev: 6.0.0
+    hooks:
+      - id: flake8
+  - repo: https://github.com/codespell-project/codespell
+    rev: v2.2.2
+    hooks:
+      - id: codespell
+        args: [--ignore-words, dictionary.txt, --skip, customToolformer/merges.txt, --skip, customToolformer/vocab.json]
+        additional_dependencies:
+          - tomli
diff --git a/configs/ds_configs/ds_config_gpt_j_z3.json b/configs/ds_configs/ds_config_gpt_j_z3.json
index 0df50da..84c11a5 100644
--- a/configs/ds_configs/ds_config_gpt_j_z3.json
+++ b/configs/ds_configs/ds_config_gpt_j_z3.json
@@ -1,44 +1,44 @@
 {
     "train_batch_size": "auto",
     "fp16": {
-      "enabled": "auto",
-      "min_loss_scale": 1,
-      "loss_scale_window": 1000,
-      "hysteresis": 2,
-      "initial_scale_power": 32
+        "enabled": "auto",
+        "min_loss_scale": 1,
+        "loss_scale_window": 1000,
+        "hysteresis": 2,
+        "initial_scale_power": 32
     },
     "bf16": {
         "enabled": "auto"
     },
     "zero_optimization": {
-      "stage": 3,
-      "offload_param": {
-        "device": "none"
-      },
-      "offload_optimizer": {
-        "device": "none"
-      },
-      "allgather_partitions": true,
-      "allgather_bucket_size": 5e8,
-      "contiguous_gradients": true
+        "stage": 3,
+        "offload_param": {
+            "device": "none"
+        },
+        "offload_optimizer": {
+            "device": "none"
+        },
+        "allgather_partitions": true,
+        "allgather_bucket_size": 5e8,
+        "contiguous_gradients": true
     },
     "optimizer": {
-      "type": "AdamW",
-      "params": {
-        "lr": "auto",
-        "betas": [
-          0.9,
-          0.999
-        ],
-        "eps": 1e-08
-      }
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "betas": [
+                0.9,
+                0.999
+            ],
+            "eps": 1e-08
+        }
     },
     "scheduler": {
-      "type": "WarmupLR",
-      "params": {
-        "warmup_min_lr": 0,
-        "warmup_max_lr": "auto",
-        "warmup_num_steps": 100
-      }
+        "type": "WarmupLR",
+        "params": {
+            "warmup_min_lr": 0,
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": 100
+        }
     }
-  }
\ No newline at end of file
+}
diff --git a/configs/trlx_configs/ds_config_trlx_gpt_j_z3.json b/configs/trlx_configs/ds_config_trlx_gpt_j_z3.json
index e69de29..0967ef4 100644
--- a/configs/trlx_configs/ds_config_trlx_gpt_j_z3.json
+++ b/configs/trlx_configs/ds_config_trlx_gpt_j_z3.json
@@ -0,0 +1 @@
+{}
diff --git a/data_generation/api_checker.py b/data_generation/api_checker.py
index 63d21e7..3586d57 100644
--- a/data_generation/api_checker.py
+++ b/data_generation/api_checker.py
@@ -1,8 +1,9 @@
-from dataclasses import dataclass
-from transformers import PreTrainedTokenizerBase
-import dateutil.parser as dparser
 import random
 import re
+from dataclasses import dataclass
+
+import dateutil.parser as dparser
+from transformers import PreTrainedTokenizerBase
 
 
 @dataclass
@@ -32,11 +33,11 @@ def check_apis_available(
     available = AvailableAPIs()
     # In case we need a different version, found this here:
    # https://stackoverflow.com/questions/28198370/regex-for-validating-correct-input-for-calculator
-    calc_pattern = re.compile("^(\d+[\+\-\*\/]{1})+\d+$")
+    calc_pattern = re.compile(r"^(\d+[\+\-\*\/]{1})+\d+$")
     if len(tokenized_data) < 4096:
         available.retrieval = False
     try:
-        date = dparser.parse(data["url"], fuzzy=True)
+        dparser.parse(data["url"], fuzzy=True)
     except (ValueError, OverflowError):
         available.calendar = False
         available.calculator = False
diff --git a/data_generation/base_api.py b/data_generation/base_api.py
index a32d014..df9851c 100644
--- a/data_generation/base_api.py
+++ b/data_generation/base_api.py
@@ -1,13 +1,13 @@
-import json
 from typing import List
+
 import torch
+from torch import nn
 from transformers import (
-    PreTrainedTokenizerBase,
-    pipeline,
     PreTrainedModel,
+    PreTrainedTokenizerBase,
     TextGenerationPipeline,
+    pipeline,
 )
-from torch import nn
 
 MAX_BATCH_SIZE = 1  # My 3090 is weak 😔
 N = 64  # SEQ Len
@@ -22,7 +22,7 @@ def __init__(
         minimum_percentage: float = 0.1,
     ):
         """
-        Base API Postprocesing class
+        Base API Postprocessing class
 
        :param start_tokens: token representation for [ or other tokens
        :param end_tokens: token representation for ] or other tokens
diff --git a/data_generation/calculator.py b/data_generation/calculator.py
index b2b8f24..836f32b 100644
--- a/data_generation/calculator.py
+++ b/data_generation/calculator.py
@@ -1,14 +1,11 @@
-import torch
-from transformers import (
-    PreTrainedTokenizerBase,
-    PreTrainedModel,
-)
-from tools import Calculator
-from prompts import calculator_prompt
 from typing import List
-from data_generation.base_api import APICallPostprocessing
-import dateutil.parser as dparser
+
+import torch
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+from data_generation.base_api import APICallPostprocessing
+from prompts import calculator_prompt
+from tools import Calculator
 
 # TODO: Per API?
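An aside on the api_checker.py hunk above: the raw-string fix to calc_pattern changes only the escaping (silencing Python's invalid-escape-sequence warning), not the behavior; the pattern still accepts a plain chain of integer operations and nothing else. A quick standalone sanity check (an illustrative sketch, not part of this patch):

    import re

    # Same pattern as in check_apis_available: one or more "<digits><operator>"
    # groups followed by a final run of digits.
    calc_pattern = re.compile(r"^(\d+[\+\-\*\/]{1})+\d+$")

    for s in ["2+3", "12*4-7", "2+", "12"]:
        print(s, bool(calc_pattern.match(s)))
    # 2+3 True, 12*4-7 True, 2+ False, 12 False (at least one operator is required)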
 MAX_BATCH_SIZE = 1  # My 3090 is weak 😔
@@ -68,7 +65,10 @@ def add_api_calls(
                 continue
             if outputs[j]["Calculator"] is None:
                 continue
-            outputs[j]["Calculator_output"] = [outputs[j]["Calculator_text"][1:], str(outputs[j]["Calculator"])]
+            outputs[j]["Calculator_output"] = [
+                outputs[j]["Calculator_text"][1:],
+                str(outputs[j]["Calculator"]),
+            ]
             outputs[j]["Calculator_text"] = (
                 outputs[j]["Calculator_text"]
                 + "->"
@@ -113,7 +113,7 @@ def parse_article(
 ):
     outputs = list()
     tokens = tokenizer(data["text"], return_tensors="pt")["input_ids"]
-    for i in range((tokens.shape[1]-1)//N):
+    for i in range((tokens.shape[1] - 1) // N):
         if (N * (i + 1)) > tokens.shape[1]:
             continue
         input_tokens = tokens[:, (-N * (i + 1) - 1) : (-N * (i) - 1)]
@@ -145,5 +145,7 @@ def parse_article(
         output["index"] += int(tokens.shape[1] + (-N * (i + 1)))
         # filter by score
         if output["Score"] > 0.0:
-            outputs.append([output["Score"], output["index"]] + output["Calculator_output"])
+            outputs.append(
+                [output["Score"], output["index"]] + output["Calculator_output"]
+            )
     return outputs
diff --git a/data_generation/calendar.py b/data_generation/calendar.py
index d7a2209..97d7ea0 100644
--- a/data_generation/calendar.py
+++ b/data_generation/calendar.py
@@ -1,14 +1,12 @@
-import torch
-from transformers import (
-    PreTrainedTokenizerBase,
-    PreTrainedModel,
-)
-from tools import Calendar
-from prompts import calendar_prompt
 from typing import List
-from data_generation.base_api import APICallPostprocessing
+
 import dateutil.parser as dparser
+import torch
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
+from data_generation.base_api import APICallPostprocessing
+from prompts import calendar_prompt
+from tools import Calendar
 
 # TODO: Per API?
 MAX_BATCH_SIZE = 1  # My 3090 is weak 😔
@@ -62,7 +60,10 @@ def add_api_calls(
                 return_tensors="pt",
             )["input_ids"].cuda()
             outputs[j]["Calendar"] = self.calendar(calendar_string)
-            outputs[j]["Calendar_output"] = [outputs[j]["Calendar_text"][1:], outputs[j]["Calendar"]]
+            outputs[j]["Calendar_output"] = [
+                outputs[j]["Calendar_text"][1:],
+                outputs[j]["Calendar"],
+            ]
             outputs[j]["Calendar_text"] = (
                 outputs[j]["Calendar_text"] + "->" + outputs[j]["Calendar"] + "]"
             )
@@ -104,7 +105,7 @@ def parse_article(
 ):
     outputs = list()
     tokens = tokenizer(data["text"], return_tensors="pt")["input_ids"]
-    for i in range((tokens.shape[1]-1)//N):
+    for i in range((tokens.shape[1] - 1) // N):
         if (N * (i + 1)) > tokens.shape[1]:
             continue
         input_tokens = tokens[:, (-N * (i + 1) - 1) : (-N * (i) - 1)]
@@ -112,7 +113,7 @@ def parse_article(
         :,
         int(tokens.shape[1] + (-N * (i + 1))) : int(tokens.shape[1] + (-N * i)),
     ]
-        ret_tokens = tokens[:, : (-N * (i + 1) - 1)]
+        # ret_tokens = tokens[:, : (-N * (i + 1) - 1)]
         print(tokens.shape)
         string = tokenizer.decode(input_tokens[0])
         # print(ret_strings)
@@ -138,5 +139,7 @@ def parse_article(
         output["index"] += int(tokens.shape[1] + (-N * (i + 1)))
         # filter by score
         if output["Score"] > 0.0:
-            outputs.append([output["Score"], output["index"]] + output["Calendar_output"])
+            outputs.append(
+                [output["Score"], output["index"]] + output["Calendar_output"]
+            )
     return outputs
diff --git a/data_generation/llmchain.py b/data_generation/llmchain.py
index 6b07a80..3e6b89a 100644
--- a/data_generation/llmchain.py
+++ b/data_generation/llmchain.py
@@ -1,14 +1,11 @@
-import torch
-from transformers import (
-    PreTrainedTokenizerBase,
-    PreTrainedModel,
-)
-from tools import langchain_llmchain
-from prompts import llmchain_prompt
 from typing import List
-from data_generation.base_api import APICallPostprocessing
+
+import torch
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+from data_generation.base_api import APICallPostprocessing
+from prompts import llmchain_prompt
+from tools import langchain_llmchain
 
 # TODO: Per API?
 MAX_BATCH_SIZE = 1  # My 3090 is weak 😔
@@ -55,9 +52,9 @@ def add_api_calls(
             )
             if ")" in outputs[j]["LLMChain"]:
                 outputs[j]["LLMChain"] = outputs[j]["LLMChain"].split(")")[0]
-            if outputs[j]["LLMChain"][0] == "\"":
+            if outputs[j]["LLMChain"][0] == '"':
                 outputs[j]["LLMChain"] = outputs[j]["LLMChain"][1:]
-            if outputs[j]["LLMChain"][-1] == "\"":
+            if outputs[j]["LLMChain"][-1] == '"':
                 outputs[j]["LLMChain"] = outputs[j]["LLMChain"][:-1]
             outputs[j]["LLMChain_text"] = (
                 "[LLMChain(" + outputs[j]["LLMChain"] + ")"
@@ -67,12 +64,12 @@ def add_api_calls(
                 return_tensors="pt",
             )["input_ids"].cuda()
             outputs[j]["LLMChain"] = str(self.llmchain(outputs[j]["LLMChain"]))
-            outputs[j]["LLMChain_output"] = [outputs[j]["LLMChain_text"][1:], outputs[j]["LLMChain"]]
+            outputs[j]["LLMChain_output"] = [
+                outputs[j]["LLMChain_text"][1:],
+                outputs[j]["LLMChain"],
+            ]
             outputs[j]["LLMChain_text"] = (
-                outputs[j]["LLMChain_text"]
-                + "->"
-                + outputs[j]["LLMChain"]
-                + "]"
+                outputs[j]["LLMChain_text"] + "->" + outputs[j]["LLMChain"] + "]"
             )
             test_inputs = tokenizer(
                 outputs[j]["LLMChain_text"] + "\n",
@@ -113,7 +110,7 @@ def parse_article(
     outputs = list()
     tokens = tokenizer(data["text"], return_tensors="pt")["input_ids"]
     start_step = 0
-    total_steps = tokens.shape[1]//N
+    total_steps = tokens.shape[1] // N
     for i in range(start_step, total_steps):
         input_tokens = tokens[:, (-N * (i + 1) - 1) : (-N * (i) - 1)]
         labels = tokens[
@@ -145,5 +142,7 @@ def parse_article(
         output["index"] += int(tokens.shape[1] + (-N * (i + 1)))
         # filter by score
         if output["Score"] > 1.0:
-            outputs.append([output["Score"], output["index"]] + output["LLMChain_output"])
+            outputs.append(
+                [output["Score"], output["index"]] + output["LLMChain_output"]
+            )
     return outputs
diff --git a/data_generation/retrieval.py b/data_generation/retrieval.py
index 0fc331b..d4c713b 100644
--- a/data_generation/retrieval.py
+++ b/data_generation/retrieval.py
@@ -1,14 +1,13 @@
-import torch
-from transformers import (
-    PreTrainedTokenizerBase,
-    PreTrainedModel,
-)
+from typing import List
+
 import nltk
+import torch
 from nltk import tokenize
-from tools import Retriever
-from prompts import retrieval_prompt
-from typing import List
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
 from data_generation.base_api import APICallPostprocessing
+from prompts import retrieval_prompt
+from tools import Retriever
 
 nltk.download("punkt")
@@ -68,7 +67,10 @@ def add_api_calls(
             outputs[j]["Retrieval"] = self.retriever.retrieval(
                 retrieval_strings, outputs[j]["Retrieval"], 3
             )
-            outputs[j]["Retrieval_output"] = [outputs[j]["Retrieval_text"][1:], ", ".join(outputs[j]["Retrieval"])]
+            outputs[j]["Retrieval_output"] = [
+                outputs[j]["Retrieval_text"][1:],
+                ", ".join(outputs[j]["Retrieval"]),
+            ]
             outputs[j]["Retrieval_text"] = (
                 outputs[j]["Retrieval_text"]
                 + "->"
@@ -113,9 +115,11 @@ def parse_article(
 ):
     outputs = list()
     tokens = tokenizer(data["text"], return_tensors="pt")["input_ids"]
-    start_step = 2048//N
-    ret_skip = 1024//N  # naively assuming the model should be able to look back if it's less than this.
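A note on the retrieval.py hunk being reflowed here, which applies equally to the parse_article loops in calculator.py and calendar.py above: they all walk backwards through the article in N-token windows, with the input window trailing the label window by one position. A toy illustration of that negative-index slicing, with a plain list standing in for one row of a [1, seq_len] token tensor (hypothetical values, not repo code):

    N = 4
    tokens = list(range(13))  # stand-in for one row of a [1, seq_len] tensor

    for i in range((len(tokens) - 1) // N):
        # Same indexing as parse_article: inputs trail the labels by one token.
        input_window = tokens[(-N * (i + 1) - 1) : (-N * i - 1)]
        labels = tokens[len(tokens) - N * (i + 1) : len(tokens) - N * i]
        print(i, input_window, labels)
    # i=0 -> [8, 9, 10, 11] [9, 10, 11, 12]; later windows step toward the start.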
-    total_steps = tokens.shape[1]//N
+    start_step = 2048 // N
+    ret_skip = (
+        1024 // N
+    )  # naively assuming the model should be able to look back if it's less than this.
+    total_steps = tokens.shape[1] // N
     for i in range(start_step, total_steps):
         input_tokens = tokens[:, (-N * (i + 1) - 1) : (-N * (i) - 1)]
         labels = tokens[
@@ -149,5 +153,7 @@ def parse_article(
         output["index"] += int(tokens.shape[1] + (-N * (i + 1)))
         # filter by score
         if output["Score"] > 1.0:
-            outputs.append([output["Score"], output["index"]] + output["Retrieval_output"])
+            outputs.append(
+                [output["Score"], output["index"]] + output["Retrieval_output"]
+            )
     return outputs
diff --git a/data_generator.py b/data_generator.py
index d7ccc1e..0c11e49 100644
--- a/data_generator.py
+++ b/data_generator.py
@@ -1,24 +1,19 @@
+import argparse
+import json
 import os
+import time
 
 import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-)
 from datasets import load_dataset
-from prompts import retrieval_prompt
-from data_generation.retrieval import RetrievalPostprocessing
-from data_generation.calendar import CalendarPostprocessing
-from data_generation.calculator import CalculatorPostprocessing
-from data_generation.api_checker import check_apis_available
-import json
-import time
-import argparse
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
+from data_generation.api_checker import check_apis_available
+from data_generation.retrieval import RetrievalPostprocessing
+from prompts import retrieval_prompt
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='do some continuations')
-    parser.add_argument('--device_id', type=int, default=0)
+    parser = argparse.ArgumentParser(description="do some continuations")
+    parser.add_argument("--device_id", type=int, default=0)
     parser.add_argument("--num_devices", type=int, default=8)
     args = parser.parse_args()
     gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
@@ -46,14 +41,14 @@
     found_examples = 0
     output_dataset = list()
     start_time = time.process_time()
-    num_examples = int(25000.0/float(args.num_devices))
+    num_examples = int(25000.0 / float(args.num_devices))
     start_count = -1
     if os.path.isfile(f"retrieval_data_{args.device_id}.json"):
         with open(f"retrieval_data_{args.device_id}.json") as f:
             output_dataset = json.load(f)
-        start_count = output_dataset[-1]['file_index']
+        start_count = output_dataset[-1]["file_index"]
         for item in output_dataset:
-            num_examples -= len(item['retrieval_outputs'])
+            num_examples -= len(item["retrieval_outputs"])
     while found_examples < num_examples:
         data = next(iter_data)
         if file_counter < start_count:
@@ -70,21 +65,27 @@
             {
                 "file_index": file_counter,
                 "text": data["text"],
-                "retrieval_outputs": data_outputs
+                "retrieval_outputs": data_outputs,
             }
         )
         prev_found = found_examples
         found_examples += len(output_dataset[-1]["retrieval_outputs"])
-        eta_s = (num_examples - found_examples) * (time.process_time()-start_time) / max(1, found_examples)
+        eta_s = (
+            (num_examples - found_examples)
+            * (time.process_time() - start_time)
+            / max(1, found_examples)
+        )
         eta_m = eta_s // 60
         eta_h = eta_m // 60
-        eta_m = eta_m - (eta_h*60)
-        eta_s = eta_s - ((eta_m*60) + (eta_h*60*60))
-        print(f"Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s")
-        if found_examples//100 > prev_found//100:
-            with open(f"retrieval_data_{args.device_id}.json", 'w') as f:
+        eta_m = eta_m - (eta_h * 60)
+        eta_s = eta_s - ((eta_m * 60) + (eta_h * 60 * 60))
+        print(
+            f"Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s"
+        )
+        if found_examples // 100 > prev_found // 100:
+            with open(f"retrieval_data_{args.device_id}.json", "w") as f:
                 json.dump(output_dataset, f, indent=2)
         counter += 1
         file_counter += 1
-    with open(f"retrieval_data_{args.device_id}.json", 'w') as f:
-        json.dump(output_dataset, f, indent=2)
\ No newline at end of file
+    with open(f"retrieval_data_{args.device_id}.json", "w") as f:
+        json.dump(output_dataset, f, indent=2)
diff --git a/data_generator_calc.py b/data_generator_calc.py
index 1174cfa..a4296cb 100644
--- a/data_generator_calc.py
+++ b/data_generator_calc.py
@@ -1,24 +1,19 @@
+import argparse
+import json
 import os
+import time
 
 import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-)
 from datasets import load_dataset
-from prompts import retrieval_prompt
-from data_generation.retrieval import RetrievalPostprocessing
-from data_generation.calendar import CalendarPostprocessing
-from data_generation.calculator import CalculatorPostprocessing
-from data_generation.api_checker import check_apis_available
-import json
-import time
-import argparse
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
+from data_generation.api_checker import check_apis_available
+from data_generation.calculator import CalculatorPostprocessing
+from prompts import retrieval_prompt
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='do some continuations')
-    parser.add_argument('--device_id', type=int, default=0)
+    parser = argparse.ArgumentParser(description="do some continuations")
+    parser.add_argument("--device_id", type=int, default=0)
     parser.add_argument("--num_devices", type=int, default=8)
     args = parser.parse_args()
     gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
@@ -46,14 +41,14 @@
     found_examples = 0
     output_dataset = list()
     start_time = time.process_time()
-    num_examples = int(25000.0/float(args.num_devices))
+    num_examples = int(25000.0 / float(args.num_devices))
     start_count = -1
     if os.path.isfile(f"calc_data_{args.device_id}.json"):
         with open(f"calc_data_{args.device_id}.json") as f:
             output_dataset = json.load(f)
-        start_count = output_dataset[-1]['file_index']
+        start_count = output_dataset[-1]["file_index"]
         for item in output_dataset:
-            num_examples -= len(item['calculator_outputs'])
+            num_examples -= len(item["calculator_outputs"])
     while found_examples < num_examples:
         data = next(iter_data)
         if file_counter < start_count:
@@ -67,32 +62,44 @@
         if test:
             data_outputs = api_handler.parse_article(data, model, gpt_tokenizer)
             if len(data_outputs) == 0:
-                eta_s = (num_examples - found_examples) * (time.process_time() - start_time) / max(1, found_examples)
+                eta_s = (
+                    (num_examples - found_examples)
+                    * (time.process_time() - start_time)
+                    / max(1, found_examples)
+                )
                 eta_m = eta_s // 60
                 eta_h = eta_m // 60
                 eta_m = eta_m - (eta_h * 60)
                 eta_s = eta_s - ((eta_m * 60) + (eta_h * 60 * 60))
-                print(f"device {args.device_id} Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s")
+                print(
+                    f"device {args.device_id} Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s"
+                )
                 continue
         output_dataset.append(
             {
                 "file_index": file_counter,
                 "text": data["text"],
-                "calculator_outputs": data_outputs
+                "calculator_outputs": data_outputs,
             }
         )
         prev_found = found_examples
         found_examples += len(output_dataset[-1]["calculator_outputs"])
-        eta_s = (num_examples - found_examples) * (time.process_time()-start_time) / max(1, found_examples)
+        eta_s = (
+            (num_examples - found_examples)
+            * (time.process_time() - start_time)
+            / max(1, found_examples)
+        )
         eta_m = eta_s // 60
         eta_h = eta_m // 60
-        eta_m = eta_m - (eta_h*60)
-        eta_s = eta_s - ((eta_m*60) + (eta_h*60*60))
-        print(f"device {args.device_id} Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s")
-        if found_examples//100 > prev_found//100:
-            with open(f"calc_data_{args.device_id}.json", 'w') as f:
+        eta_m = eta_m - (eta_h * 60)
+        eta_s = eta_s - ((eta_m * 60) + (eta_h * 60 * 60))
+        print(
+            f"device {args.device_id} Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s"
+        )
+        if found_examples // 100 > prev_found // 100:
+            with open(f"calc_data_{args.device_id}.json", "w") as f:
                 json.dump(output_dataset, f, indent=2)
         counter += 1
         file_counter += 1
-    with open(f"calc_data_{args.device_id}.json", 'w') as f:
-        json.dump(output_dataset, f, indent=2)
\ No newline at end of file
+    with open(f"calc_data_{args.device_id}.json", "w") as f:
+        json.dump(output_dataset, f, indent=2)
diff --git a/data_generator_calendar.py b/data_generator_calendar.py
index 64ea4b9..3983f1e 100644
--- a/data_generator_calendar.py
+++ b/data_generator_calendar.py
@@ -1,24 +1,19 @@
+import argparse
+import json
 import os
+import time
 
 import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-)
 from datasets import load_dataset
-from prompts import retrieval_prompt
-from data_generation.retrieval import RetrievalPostprocessing
-from data_generation.calendar import CalendarPostprocessing
-from data_generation.calculator import CalculatorPostprocessing
-from data_generation.api_checker import check_apis_available
-import json
-import time
-import argparse
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
+from data_generation.api_checker import check_apis_available
+from data_generation.calendar import CalendarPostprocessing
+from prompts import retrieval_prompt
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='do some continuations')
-    parser.add_argument('--device_id', type=int, default=0)
+    parser = argparse.ArgumentParser(description="do some continuations")
+    parser.add_argument("--device_id", type=int, default=0)
     parser.add_argument("--num_devices", type=int, default=8)
     args = parser.parse_args()
     gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
@@ -46,14 +41,14 @@
     found_examples = 0
     output_dataset = list()
     start_time = time.process_time()
-    num_examples = int(25000.0/float(args.num_devices))
+    num_examples = int(25000.0 / float(args.num_devices))
     start_count = -1
     if os.path.isfile(f"calendar_data_{args.device_id}.json"):
         with open(f"calendar_data_{args.device_id}.json") as f:
             output_dataset = json.load(f)
-        start_count = output_dataset[-1]['file_index']
+        start_count = output_dataset[-1]["file_index"]
         for item in output_dataset:
-            num_examples -= len(item['calendar_outputs'])
+            num_examples -= len(item["calendar_outputs"])
     while found_examples < num_examples:
         data = next(iter_data)
         if file_counter < start_count:
@@ -67,32 +62,44 @@
         if test:
             data_outputs = api_handler.parse_article(data, model, gpt_tokenizer)
             if len(data_outputs) == 0:
-                eta_s = (num_examples - found_examples) * (time.process_time() - start_time) / max(1, found_examples)
+                eta_s = (
+                    (num_examples - found_examples)
+                    * (time.process_time() - start_time)
+                    / max(1, found_examples)
+                )
                 eta_m = eta_s // 60
                 eta_h = eta_m // 60
                 eta_m = eta_m - (eta_h * 60)
                 eta_s = eta_s - ((eta_m * 60) + (eta_h * 60 * 60))
-                print(f"device {args.device_id} Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s")
+                print(
+                    f"device {args.device_id} Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s"
+                )
                 continue
         output_dataset.append(
             {
                 "file_index": file_counter,
                 "text": data["text"],
-                "calendar_outputs": data_outputs
+                "calendar_outputs": data_outputs,
             }
         )
         prev_found = found_examples
         found_examples += len(output_dataset[-1]["calendar_outputs"])
-        eta_s = (num_examples - found_examples) * (time.process_time()-start_time) / max(1, found_examples)
+        eta_s = (
+            (num_examples - found_examples)
+            * (time.process_time() - start_time)
+            / max(1, found_examples)
+        )
         eta_m = eta_s // 60
         eta_h = eta_m // 60
-        eta_m = eta_m - (eta_h*60)
-        eta_s = eta_s - ((eta_m*60) + (eta_h*60*60))
-        print(f"device {args.device_id} Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s")
-        if found_examples//100 > prev_found//100:
-            with open(f"calendar_data_{args.device_id}.json", 'w') as f:
+        eta_m = eta_m - (eta_h * 60)
+        eta_s = eta_s - ((eta_m * 60) + (eta_h * 60 * 60))
+        print(
+            f"device {args.device_id} Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s"
+        )
+        if found_examples // 100 > prev_found // 100:
+            with open(f"calendar_data_{args.device_id}.json", "w") as f:
                 json.dump(output_dataset, f, indent=2)
         counter += 1
         file_counter += 1
-    with open(f"calendar_data_{args.device_id}.json", 'w') as f:
-        json.dump(output_dataset, f, indent=2)
\ No newline at end of file
+    with open(f"calendar_data_{args.device_id}.json", "w") as f:
+        json.dump(output_dataset, f, indent=2)
diff --git a/data_generator_llmchain.py b/data_generator_llmchain.py
index 5b2e4e8..b0d353f 100644
--- a/data_generator_llmchain.py
+++ b/data_generator_llmchain.py
@@ -1,25 +1,19 @@
+import argparse
+import json
 import os
+import time
 
 import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-)
 from datasets import load_dataset
-from prompts import retrieval_prompt
-from data_generation.retrieval import RetrievalPostprocessing
-from data_generation.calendar import CalendarPostprocessing
-from data_generation.calculator import CalculatorPostprocessing
-from data_generation.llmchain import LLMChainPostprocessing
-from data_generation.api_checker import check_apis_available
-import json
-import time
-import argparse
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
+from data_generation.api_checker import check_apis_available
+from data_generation.llmchain import LLMChainPostprocessing
+from prompts import retrieval_prompt
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='do some continuations')
-    parser.add_argument('--device_id', type=int, default=0)
+    parser = argparse.ArgumentParser(description="do some continuations")
+    parser.add_argument("--device_id", type=int, default=0)
     parser.add_argument("--num_devices", type=int, default=8)
     args = parser.parse_args()
     gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
@@ -47,14 +41,14 @@
     found_examples = 0
     output_dataset = list()
     start_time = time.process_time()
-    num_examples = int(25000.0/float(args.num_devices))
+    num_examples = int(25000.0 / float(args.num_devices))
     start_count = -1
     if os.path.isfile(f"llmchain_data_{args.device_id}.json"):
         with open(f"llmchain_data_{args.device_id}.json") as f:
             output_dataset = json.load(f)
-        start_count = output_dataset[-1]['file_index']
+        start_count = output_dataset[-1]["file_index"]
         for item in output_dataset:
-            num_examples -= len(item['retrieval_outputs'])
+            num_examples -= len(item["retrieval_outputs"])
     while found_examples < num_examples:
         data = next(iter_data)
         if file_counter < start_count:
@@ -71,23 +65,29 @@
             {
                 "file_index": file_counter,
                 "text": data["text"],
-                "llmchain_outputs": data_outputs
+                "llmchain_outputs": data_outputs,
             }
         )
         prev_found = found_examples
         found_examples += len(output_dataset[-1]["llmchain_outputs"])
-        eta_s = (num_examples - found_examples) * (time.process_time()-start_time) / max(1, found_examples)
+        eta_s = (
+            (num_examples - found_examples)
+            * (time.process_time() - start_time)
+            / max(1, found_examples)
+        )
         eta_m = eta_s // 60
         eta_h = eta_m // 60
-        eta_m = eta_m - (eta_h*60)
-        eta_s = eta_s - ((eta_m*60) + (eta_h*60*60))
-        print(f"Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s")
-        if found_examples//100 > prev_found//100:
-            with open(f"llmchain_data_{args.device_id}.json", 'w') as f:
+        eta_m = eta_m - (eta_h * 60)
+        eta_s = eta_s - ((eta_m * 60) + (eta_h * 60 * 60))
+        print(
+            f"Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s"
+        )
+        if found_examples // 100 > prev_found // 100:
+            with open(f"llmchain_data_{args.device_id}.json", "w") as f:
                 json.dump(output_dataset, f, indent=2)
         counter += 1
         file_counter += 1
         if found_examples > 10:
             break
-    with open(f"llmchain_data_{args.device_id}.json", 'w') as f:
-        json.dump(output_dataset, f, indent=2)
\ No newline at end of file
+    with open(f"llmchain_data_{args.device_id}.json", "w") as f:
+        json.dump(output_dataset, f, indent=2)
diff --git a/data_handling/examine_calculator.py b/data_handling/examine_calculator.py
index 9a25a3b..ccd9946 100644
--- a/data_handling/examine_calculator.py
+++ b/data_handling/examine_calculator.py
@@ -1,9 +1,9 @@
-import matplotlib.pyplot as plt
 import json
-import numpy as np
 
+import matplotlib.pyplot as plt
+import numpy as np
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     scores = dict()
     jsons = list()
     for i in range(8):
@@ -13,8 +13,15 @@
     counter = 0
     for i in range(8):
         for item in jsons[i]:
-            for output in item['calculator_outputs']:
-                if any(["*" in output[2], "/" in output[2], "+" in output[2], "-" in output[2]]):
+            for output in item["calculator_outputs"]:
+                if any(
+                    [
+                        "*" in output[2],
+                        "/" in output[2],
+                        "+" in output[2],
+                        "-" in output[2],
+                    ]
+                ):
                     scores[output[0]] = scores.get(output[0], 0) + 1
                     counter += 1
     print(counter)
@@ -24,10 +31,10 @@
         running_values.append(running_values[-1] + scores[sorted_keys[0]])
     running_values = np.array(running_values)
     sorted_keys = np.array(sorted_keys)
-    plt.plot(sorted_keys, (1.0 - (running_values/running_values[-1])) * counter)
+    plt.plot(sorted_keys, (1.0 - (running_values / running_values[-1])) * counter)
     plt.xlabel("Score")
     plt.ylabel("Percentage of examples left")
     plt.show()
     # Thresholds:
     # 0.25 = 10% 0.58 = 1% 0.075 = 50%
-    # call it 0.25 for now
\ No newline at end of file
+    # call it 0.25 for now
diff --git a/data_handling/examine_calendar.py b/data_handling/examine_calendar.py
index b79ca58..09bf22e 100644
--- a/data_handling/examine_calendar.py
+++ b/data_handling/examine_calendar.py
@@ -1,9 +1,9 @@
-import matplotlib.pyplot as plt
 import json
-import numpy as np
 
+import matplotlib.pyplot as plt
+import numpy as np
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     scores = dict()
     jsons = list()
     for i in range(8):
@@ -13,7 +13,7 @@
     counter = 0
     for i in range(8):
         for item in jsons[i]:
-            for output in item['calendar_outputs']:
+            for output in item["calendar_outputs"]:
                 scores[output[0]] = scores.get(output[0], 0) + 1
                 counter += 1
     sorted_keys = sorted(list(scores.keys()))
@@ -22,10 +22,10 @@
         running_values.append(running_values[-1] + scores[sorted_keys[0]])
     running_values = np.array(running_values)
     sorted_keys = np.array(sorted_keys)
-    plt.plot(sorted_keys, (1.0 - (running_values/running_values[-1])) * counter)
+    plt.plot(sorted_keys, (1.0 - (running_values / running_values[-1])) * counter)
     plt.xlabel("Score")
     plt.ylabel("examples left")
     plt.show()
     # Thresholds:
     # 0.25 = 10% 0.58 = 1% 0.075 = 50%
-    # call it 0.25 for now
\ No newline at end of file
+    # call it 0.25 for now
diff --git a/data_handling/merge_datasets.py b/data_handling/merge_datasets.py
index 8ae6b15..0ef3d45 100644
--- a/data_handling/merge_datasets.py
+++ b/data_handling/merge_datasets.py
@@ -1,39 +1,50 @@
 import json
-import numpy as np
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     combined_data = dict()
     for i in range(8):
         with open(f"calc_data_{i}.json") as f:
             data = json.load(f)
         for item in data:
-            for output in item['calculator_outputs']:
-                if any(["*" in output[2], "/" in output[2], "+" in output[2], "-" in output[2]]):
+            for output in item["calculator_outputs"]:
+                if any(
+                    [
+                        "*" in output[2],
+                        "/" in output[2],
+                        "+" in output[2],
+                        "-" in output[2],
+                    ]
+                ):
                     if output[0] > 0.07:
                         if item["file_index"] not in list(combined_data.keys()):
                             combined_data[item["file_index"]] = dict()
                             combined_data[item["file_index"]]["text"] = item["text"]
                             combined_data[item["file_index"]]["outputs"] = list()
-                        combined_data[item["file_index"]]["outputs"].append([output[1], output[2], output[3]])
+                        combined_data[item["file_index"]]["outputs"].append(
+                            [output[1], output[2], output[3]]
+                        )
         with open(f"calendar_data_{i}.json") as f:
             data = json.load(f)
         for item in data:
-            for output in item['calendar_outputs']:
+            for output in item["calendar_outputs"]:
                 if output[0] > 0.25:
                     if item["file_index"] not in list(combined_data.keys()):
                         combined_data[item["file_index"]] = dict()
                         combined_data[item["file_index"]]["text"] = item["text"]
                         combined_data[item["file_index"]]["outputs"] = list()
-                    combined_data[item["file_index"]]["outputs"].append([output[1], output[2], output[3]])
+                    combined_data[item["file_index"]]["outputs"].append(
+                        [output[1], output[2], output[3]]
+                    )
         with open(f"retrieval_data_{i}.json") as f:
             data = json.load(f)
         for item in data:
-            for output in item['retrieval_outputs']:
+            for output in item["retrieval_outputs"]:
                 if item["file_index"] not in list(combined_data.keys()):
                     combined_data[item["file_index"]] = dict()
                     combined_data[item["file_index"]]["text"] = item["text"]
                     combined_data[item["file_index"]]["outputs"] = list()
-                combined_data[item["file_index"]]["outputs"].append([output[1], output[2], output[3]])
-    with open("../combined_data.json", 'w') as f:
+                combined_data[item["file_index"]]["outputs"].append(
+                    [output[1], output[2], output[3]]
+                )
+    with open("../combined_data.json", "w") as f:
         json.dump(combined_data, f, indent=2)
diff --git a/data_handling/to_hf_dataset.py b/data_handling/to_hf_dataset.py
index 7ada1da..a401dc8 100644
--- a/data_handling/to_hf_dataset.py
+++ b/data_handling/to_hf_dataset.py
@@ -1,16 +1,16 @@
 import json
+
+import tqdm
 from datasets import Dataset
 from transformers import AutoTokenizer
-import tqdm
-
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     with open("../combined_data.json") as f:
         data = json.load(f)
     tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
     hf_training_data = {"text": []}
     for key in tqdm.tqdm(list(data.keys())):
-        sorted_keys = sorted(data[key]["outputs"], key=lambda x:x[0])
+        sorted_keys = sorted(data[key]["outputs"], key=lambda x: x[0])
         tokens = tokenizer(data[key]["text"])["input_ids"]
         output_text = ""
         start = 0
@@ -18,12 +18,17 @@
             continue
         for i in range(len(sorted_keys)):
             if sorted_keys[i][0] != 0:
-                output_text += tokenizer.decode(tokens[start:sorted_keys[i][0]])
+                output_text += tokenizer.decode(tokens[start : sorted_keys[i][0]])
                 start = sorted_keys[i][0]
-            output_text += "" + sorted_keys[i][1] + "" + str(sorted_keys[i][2]) + ""
-            if start < len(tokens)-1:
+            output_text += (
+                ""
+                + sorted_keys[i][1]
+                + ""
+                + str(sorted_keys[i][2])
+                + ""
+            )
+            if start < len(tokens) - 1:
                 output_text += tokenizer.decode(tokens[start:])
         hf_training_data["text"].append(output_text)
     dataset = Dataset.from_dict(hf_training_data)
     dataset.push_to_hub("dmayhem93/toolformer-v0-postprocessed")
-
diff --git a/dictionary.txt b/dictionary.txt
new file mode 100644
index 0000000..e69de29
diff --git a/ds_config_gpt_j.json b/ds_config_gpt_j.json
index fcffb1c..d6dad3c 100644
--- a/ds_config_gpt_j.json
+++ b/ds_config_gpt_j.json
@@ -41,4 +41,4 @@
       "warmup_num_steps": 100
     }
   }
- }
\ No newline at end of file
+ }
diff --git a/flash_attention/flash_attention_gptj_wrapper.py b/flash_attention/flash_attention_gptj_wrapper.py
index 7a83bc1..d139f87 100644
--- a/flash_attention/flash_attention_gptj_wrapper.py
+++ b/flash_attention/flash_attention_gptj_wrapper.py
@@ -1,7 +1,7 @@
 # From: https://github.com/kyleliang919/Long-context-transformers
 import torch
-from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb
 from flash_attn.modules.mha import FlashSelfAttention
+from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb
 
 
 class FlashAttentionWrapper(torch.nn.Module):
diff --git a/flash_attention/flash_attention_neox_wrapper.py b/flash_attention/flash_attention_neox_wrapper.py
index 73c4567..bda8917 100644
--- a/flash_attention/flash_attention_neox_wrapper.py
+++ b/flash_attention/flash_attention_neox_wrapper.py
@@ -1,10 +1,7 @@
 # From: https://github.com/kyleliang919/Long-context-transformers
 import torch
-from transformers.models.gpt_neox.modeling_gpt_neox import (
-    RotaryEmbedding,
-    apply_rotary_pos_emb,
-)
 from flash_attn.modules.mha import FlashSelfAttention
+from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb
 
 
 class FlashAttentionWrapper(torch.nn.Module):
diff --git a/gptj_pytorch.py b/gptj_pytorch.py
index c66103b..3a7e5e7 100644
--- a/gptj_pytorch.py
+++ b/gptj_pytorch.py
@@ -1,6 +1,6 @@
 import torch
-from torch import nn, einsum
 from einops import rearrange
+from torch import einsum, nn
 
 # helpers
diff --git a/poetry.lock b/poetry.lock
new file mode 100644
index 0000000..446d454
--- /dev/null
+++ b/poetry.lock
@@ -0,0 +1,422 @@
+# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand.
+ +[[package]] +name = "attrs" +version = "22.2.0" +description = "Classes Without Boilerplate" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"}, + {file = "attrs-22.2.0.tar.gz", hash = "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99"}, +] + +[package.extras] +cov = ["attrs[tests]", "coverage-enable-subprocess", "coverage[toml] (>=5.3)"] +dev = ["attrs[docs,tests]"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope.interface"] +tests = ["attrs[tests-no-zope]", "zope.interface"] +tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990)", "mypy (>=0.971,<0.990)", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-mypy-plugins", "pytest-xdist[psutil]", "pytest-xdist[psutil]"] + +[[package]] +name = "black" +version = "23.1.0" +description = "The uncompromising code formatter." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "black-23.1.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:b6a92a41ee34b883b359998f0c8e6eb8e99803aa8bf3123bf2b2e6fec505a221"}, + {file = "black-23.1.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:57c18c5165c1dbe291d5306e53fb3988122890e57bd9b3dcb75f967f13411a26"}, + {file = "black-23.1.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:9880d7d419bb7e709b37e28deb5e68a49227713b623c72b2b931028ea65f619b"}, + {file = "black-23.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6663f91b6feca5d06f2ccd49a10f254f9298cc1f7f49c46e498a0771b507104"}, + {file = "black-23.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9afd3f493666a0cd8f8df9a0200c6359ac53940cbde049dcb1a7eb6ee2dd7074"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:bfffba28dc52a58f04492181392ee380e95262af14ee01d4bc7bb1b1c6ca8d27"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c1c476bc7b7d021321e7d93dc2cbd78ce103b84d5a4cf97ed535fbc0d6660648"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:382998821f58e5c8238d3166c492139573325287820963d2f7de4d518bd76958"}, + {file = "black-23.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bf649fda611c8550ca9d7592b69f0637218c2369b7744694c5e4902873b2f3a"}, + {file = "black-23.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:121ca7f10b4a01fd99951234abdbd97728e1240be89fde18480ffac16503d481"}, + {file = "black-23.1.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:a8471939da5e824b891b25751955be52ee7f8a30a916d570a5ba8e0f2eb2ecad"}, + {file = "black-23.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8178318cb74f98bc571eef19068f6ab5613b3e59d4f47771582f04e175570ed8"}, + {file = "black-23.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:a436e7881d33acaf2536c46a454bb964a50eff59b21b51c6ccf5a40601fbef24"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:a59db0a2094d2259c554676403fa2fac3473ccf1354c1c63eccf7ae65aac8ab6"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:0052dba51dec07ed029ed61b18183942043e00008ec65d5028814afaab9a22fd"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:49f7b39e30f326a34b5c9a4213213a6b221d7ae9d58ec70df1c4a307cf2a1580"}, + {file = 
"black-23.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:162e37d49e93bd6eb6f1afc3e17a3d23a823042530c37c3c42eeeaf026f38468"}, + {file = "black-23.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b70eb40a78dfac24842458476135f9b99ab952dd3f2dab738c1881a9b38b753"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:a29650759a6a0944e7cca036674655c2f0f63806ddecc45ed40b7b8aa314b651"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:bb460c8561c8c1bec7824ecbc3ce085eb50005883a6203dcfb0122e95797ee06"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c91dfc2c2a4e50df0026f88d2215e166616e0c80e86004d0003ece0488db2739"}, + {file = "black-23.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a951cc83ab535d248c89f300eccbd625e80ab880fbcfb5ac8afb5f01a258ac9"}, + {file = "black-23.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:0680d4380db3719ebcfb2613f34e86c8e6d15ffeabcf8ec59355c5e7b85bb555"}, + {file = "black-23.1.0-py3-none-any.whl", hash = "sha256:7a0f701d314cfa0896b9001df70a530eb2472babb76086344e688829efd97d32"}, + {file = "black-23.1.0.tar.gz", hash = "sha256:b0bd97bea8903f5a2ba7219257a44e3f1f9d00073d6cc1add68f0beec69692ac"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + +[[package]] +name = "cfgv" +version = "3.3.1" +description = "Validate configuration and produce human readable error messages." +category = "dev" +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "cfgv-3.3.1-py2.py3-none-any.whl", hash = "sha256:c6a0883f3917a037485059700b9e75da2464e6c27051014ad85ba6aaa5884426"}, + {file = "cfgv-3.3.1.tar.gz", hash = "sha256:f5a830efb9ce7a445376bb66ec94c638a9787422f96264c98edc6bdeed8ab736"}, +] + +[[package]] +name = "click" +version = "8.1.3" +description = "Composable command line interface toolkit" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, + {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." 
+category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "distlib" +version = "0.3.6" +description = "Distribution utilities" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.6-py2.py3-none-any.whl", hash = "sha256:f35c4b692542ca110de7ef0bea44d73981caeb34ca0b9b6b2e6d7790dda8f80e"}, + {file = "distlib-0.3.6.tar.gz", hash = "sha256:14bad2d9b04d3a36127ac97f30b12a19268f211063d8f8ee4f47108896e11b46"}, +] + +[[package]] +name = "exceptiongroup" +version = "1.1.0" +description = "Backport of PEP 654 (exception groups)" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.1.0-py3-none-any.whl", hash = "sha256:327cbda3da756e2de031a3107b81ab7b3770a602c4d16ca618298c526f4bec1e"}, + {file = "exceptiongroup-1.1.0.tar.gz", hash = "sha256:bcb67d800a4497e1b404c2dd44fca47d3b7a5e5433dbab67f96c1a685cdfdf23"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "filelock" +version = "3.9.0" +description = "A platform independent file lock." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "filelock-3.9.0-py3-none-any.whl", hash = "sha256:f58d535af89bb9ad5cd4df046f741f8553a418c01a7856bf0d173bbc9f6bd16d"}, + {file = "filelock-3.9.0.tar.gz", hash = "sha256:7b319f24340b51f55a2bf7a12ac0755a9b03e718311dac567a0f4f7fabd2f5de"}, +] + +[package.extras] +docs = ["furo (>=2022.12.7)", "sphinx (>=5.3)", "sphinx-autodoc-typehints (>=1.19.5)"] +testing = ["covdefaults (>=2.2.2)", "coverage (>=7.0.1)", "pytest (>=7.2)", "pytest-cov (>=4)", "pytest-timeout (>=2.1)"] + +[[package]] +name = "identify" +version = "2.5.19" +description = "File identification library for Python" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "identify-2.5.19-py2.py3-none-any.whl", hash = "sha256:3ee3533e7f6f5023157fbebbd5687bb4b698ce6f305259e0d24b2d7d9efb72bc"}, + {file = "identify-2.5.19.tar.gz", hash = "sha256:4102ecd051f6884449e7359e55b38ba6cd7aafb6ef27b8e2b38495a5723ea106"}, +] + +[package.extras] +license = ["ukkonen"] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." 
+category = "dev" +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + +[[package]] +name = "nodeenv" +version = "1.7.0" +description = "Node.js virtual environment builder" +category = "dev" +optional = false +python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" +files = [ + {file = "nodeenv-1.7.0-py2.py3-none-any.whl", hash = "sha256:27083a7b96a25f2f5e1d8cb4b6317ee8aeda3bdd121394e5ac54e498028a042e"}, + {file = "nodeenv-1.7.0.tar.gz", hash = "sha256:e0e7f7dfb85fc5394c6fe1e8fa98131a2473e04311a45afb6508f7cf1836fa2b"}, +] + +[package.dependencies] +setuptools = "*" + +[[package]] +name = "packaging" +version = "23.0" +description = "Core utilities for Python packages" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-23.0-py3-none-any.whl", hash = "sha256:714ac14496c3e68c99c29b00845f7a2b85f3bb6f1078fd9f72fd20f0570002b2"}, + {file = "packaging-23.0.tar.gz", hash = "sha256:b6ad297f8907de0fa2fe1ccbd26fdaf387f5f47c7275fedf8cce89f99446cf97"}, +] + +[[package]] +name = "pathspec" +version = "0.11.0" +description = "Utility library for gitignore style pattern matching of file paths." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pathspec-0.11.0-py3-none-any.whl", hash = "sha256:3a66eb970cbac598f9e5ccb5b2cf58930cd8e3ed86d393d541eaf2d8b1705229"}, + {file = "pathspec-0.11.0.tar.gz", hash = "sha256:64d338d4e0914e91c1792321e6907b5a593f1ab1851de7fc269557a21b30ebbc"}, +] + +[[package]] +name = "platformdirs" +version = "3.1.0" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "platformdirs-3.1.0-py3-none-any.whl", hash = "sha256:13b08a53ed71021350c9e300d4ea8668438fb0046ab3937ac9a29913a1a1350a"}, + {file = "platformdirs-3.1.0.tar.gz", hash = "sha256:accc3665857288317f32c7bebb5a8e482ba717b474f3fc1d18ca7f9214be0cef"}, +] + +[package.extras] +docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] + +[[package]] +name = "pluggy" +version = "1.0.0" +description = "plugin and hook calling mechanisms for python" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, + {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "pre-commit" +version = "3.1.1" +description = "A framework for managing and maintaining multi-language pre-commit hooks." 
+category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pre_commit-3.1.1-py2.py3-none-any.whl", hash = "sha256:b80254e60668e1dd1f5c03a1c9e0413941d61f568a57d745add265945f65bfe8"}, + {file = "pre_commit-3.1.1.tar.gz", hash = "sha256:d63e6537f9252d99f65755ae5b79c989b462d511ebbc481b561db6a297e1e865"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + +[[package]] +name = "pytest" +version = "7.2.2" +description = "pytest: simple powerful testing with Python" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.2.2-py3-none-any.whl", hash = "sha256:130328f552dcfac0b1cec75c12e3f005619dc5f874f0a06e8ff7263f0ee6225e"}, + {file = "pytest-7.2.2.tar.gz", hash = "sha256:c99ab0c73aceb050f68929bc93af19ab6db0558791c6a0715723abe9d0ade9d4"}, +] + +[package.dependencies] +attrs = ">=19.2.0" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] + +[[package]] +name = "pyyaml" +version = "6.0" +description = "YAML parser and emitter for Python" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, + {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, + {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc"}, + {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a80a78046a72361de73f8f395f1f1e49f956c6be882eed58505a15f3e430962b"}, + {file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"}, + {file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"}, + {file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"}, + {file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"}, + {file = 
"PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"}, + {file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"}, + {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"}, + {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"}, + {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:98c4d36e99714e55cfbaaee6dd5badbc9a1ec339ebfc3b1f52e293aee6bb71a4"}, + {file = "PyYAML-6.0-cp36-cp36m-win32.whl", hash = "sha256:0283c35a6a9fbf047493e3a0ce8d79ef5030852c51e9d911a27badfde0605293"}, + {file = "PyYAML-6.0-cp36-cp36m-win_amd64.whl", hash = "sha256:07751360502caac1c067a8132d150cf3d61339af5691fe9e87803040dbc5db57"}, + {file = "PyYAML-6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:819b3830a1543db06c4d4b865e70ded25be52a2e0631ccd2f6a47a2822f2fd7c"}, + {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:473f9edb243cb1935ab5a084eb238d842fb8f404ed2193a915d1784b5a6b5fc0"}, + {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ce82d761c532fe4ec3f87fc45688bdd3a4c1dc5e0b4a19814b9009a29baefd4"}, + {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:231710d57adfd809ef5d34183b8ed1eeae3f76459c18fb4a0b373ad56bedcdd9"}, + {file = "PyYAML-6.0-cp37-cp37m-win32.whl", hash = "sha256:c5687b8d43cf58545ade1fe3e055f70eac7a5a1a0bf42824308d868289a95737"}, + {file = "PyYAML-6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:d15a181d1ecd0d4270dc32edb46f7cb7733c7c508857278d3d378d14d606db2d"}, + {file = "PyYAML-6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0b4624f379dab24d3725ffde76559cff63d9ec94e1736b556dacdfebe5ab6d4b"}, + {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213c60cd50106436cc818accf5baa1aba61c0189ff610f64f4a3e8c6726218ba"}, + {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9fa600030013c4de8165339db93d182b9431076eb98eb40ee068700c9c813e34"}, + {file = "PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287"}, + {file = "PyYAML-6.0-cp38-cp38-win32.whl", hash = "sha256:d4eccecf9adf6fbcc6861a38015c2a64f38b9d94838ac1810a9023a0609e1b78"}, + {file = "PyYAML-6.0-cp38-cp38-win_amd64.whl", hash = "sha256:1e4747bc279b4f613a09eb64bba2ba602d8a6664c6ce6396a4d0cd413a50ce07"}, + {file = "PyYAML-6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:055d937d65826939cb044fc8c9b08889e8c743fdc6a32b33e2390f66013e449b"}, + {file = "PyYAML-6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e61ceaab6f49fb8bdfaa0f92c4b57bcfbea54c09277b1b4f7ac376bfb7a7c174"}, + {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d67d839ede4ed1b28a4e8909735fc992a923cdb84e618544973d7dfc71540803"}, + {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cba8c411ef271aa037d7357a2bc8f9ee8b58b9965831d9e51baf703280dc73d3"}, + {file = 
"PyYAML-6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:40527857252b61eacd1d9af500c3337ba8deb8fc298940291486c465c8b46ec0"}, + {file = "PyYAML-6.0-cp39-cp39-win32.whl", hash = "sha256:b5b9eccad747aabaaffbc6064800670f0c297e52c12754eb1d976c57e4f74dcb"}, + {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"}, + {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, +] + +[[package]] +name = "setuptools" +version = "67.6.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "setuptools-67.6.0-py3-none-any.whl", hash = "sha256:b78aaa36f6b90a074c1fa651168723acbf45d14cb1196b6f02c0fd07f17623b2"}, + {file = "setuptools-67.6.0.tar.gz", hash = "sha256:2ee892cd5f29f3373097f5a814697e397cf3ce313616df0af11231e2ad118077"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + +[[package]] +name = "typing-extensions" +version = "4.5.0" +description = "Backported and Experimental Type Hints for Python 3.7+" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "typing_extensions-4.5.0-py3-none-any.whl", hash = "sha256:fb33085c39dd998ac16d1431ebc293a8b3eedd00fd4a32de0ff79002c19511b4"}, + {file = "typing_extensions-4.5.0.tar.gz", hash = "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb"}, +] + +[[package]] +name = "virtualenv" +version = "20.20.0" +description = "Virtual Python Environment builder" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.20.0-py3-none-any.whl", hash = "sha256:3c22fa5a7c7aa106ced59934d2c20a2ecb7f49b4130b8bf444178a16b880fa45"}, + {file = "virtualenv-20.20.0.tar.gz", hash = "sha256:a8a4b8ca1e28f864b7514a253f98c1d62b64e31e77325ba279248c65fb4fcef4"}, +] + +[package.dependencies] +distlib = ">=0.3.6,<1" +filelock = ">=3.4.1,<4" +platformdirs = ">=2.4,<4" + +[package.extras] +docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-argparse 
(>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=22.12)"] +test = ["covdefaults (>=2.2.2)", "coverage (>=7.1)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23)", "pytest (>=7.2.1)", "pytest-env (>=0.8.1)", "pytest-freezegun (>=0.4.2)", "pytest-mock (>=3.10)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)"] + +[metadata] +lock-version = "2.0" +python-versions = "^3.8" +content-hash = "a519bb35a7a2e7275e5068ea56e9c3492d3ce97820c418ab29ea9353d9f8abfb" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8a68999 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,22 @@ +[tool.poetry] +name = "toolformer" +version = "0.1.0" +description = "" +authors = ["Your Name "] +license = "MIT" +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.8" + +[tool.poetry.group.dev.dependencies] +pre-commit = "^3.1.1" +black = "^23.1.0" +pytest = "^7.2.2" + +[tool.isort] +profile = "black" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements.txt b/requirements.txt index 194f7d3..6b47d19 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ +cohere google-api-python-client -wolframalpha -transformers -openai langchain -cohere \ No newline at end of file +openai +transformers +wolframalpha diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..8332de3 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 120 +extend-ignore = E203 diff --git a/train_gptj_toolformer.py b/train_gptj_toolformer.py index edc3dd9..8d9bf6a 100644 --- a/train_gptj_toolformer.py +++ b/train_gptj_toolformer.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# coding=utf-8 # Copyright 2020 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,9 +31,8 @@ import datasets import evaluate import torch -from datasets import load_dataset - import transformers +from datasets import load_dataset from transformers import ( CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, @@ -53,7 +51,6 @@ from transformers.utils import check_min_version, send_example_telemetry from transformers.utils.versions import require_version - # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
check_min_version("4.26.0.dev0") diff --git a/train_pythia_flash_toolformer.py b/train_pythia_flash_toolformer.py index e508f28..438d014 100644 --- a/train_pythia_flash_toolformer.py +++ b/train_pythia_flash_toolformer.py @@ -1,22 +1,22 @@ # From: https://github.com/kyleliang919/Long-context-transformers -import torch -import numpy as np -import evaluate -from datasets import load_dataset -from transformers import GPTNeoXForCausalLM -from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding -from transformers.trainer_utils import get_last_checkpoint +from dataclasses import dataclass, field from itertools import chain from typing import Optional -from dataclasses import dataclass, field + +import evaluate +import torch +from datasets import load_dataset from transformers import ( AutoTokenizer, + GPTNeoXForCausalLM, HfArgumentParser, Trainer, TrainingArguments, default_data_collator, set_seed, ) +from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding +from transformers.trainer_utils import get_last_checkpoint from flash_attention.flash_attention_gptj_wrapper import FlashAttentionWrapper @@ -38,7 +38,7 @@ class ModelArguments: max_positions: Optional[int] = field( default=8192, - metadata={"help": ("The maximun sequence length of the model.")}, + metadata={"help": ("The maximum sequence length of the model.")}, ) @@ -67,7 +67,7 @@ def main(): max_positions = model_args.max_positions tokenizer.model_max_length = max_positions for each in model.gpt_neox.layers: - original_emb = each.attention.rotary_emb + # original_emb = each.attention.rotary_emb each.attention.rotary_emb = RotaryEmbedding( each.attention.rotary_ndims, max_positions, 10000 )