-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllmConnect.py
More file actions
98 lines (75 loc) · 2.83 KB
/
llmConnect.py
File metadata and controls
98 lines (75 loc) · 2.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from openai import OpenAI
import os
from dotenv import load_dotenv
import transformers
import torch
import datetime
import glob
print(datetime.datetime.now().strftime("%H:%M"))
print('program start')
# Environment: read OPENAI_API_KEY from a local .env file.
load_dotenv()
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
print('set up chatGPT')
# CodeLlama-7B-Instruct: load tokenizer and fp16 model once; device_map="auto"
# lets accelerate place the weights on available GPU(s)/CPU.
cl7B = "codellama/CodeLlama-7b-Instruct-hf"
tokenizer7B = transformers.AutoTokenizer.from_pretrained(cl7B)
model7B = transformers.AutoModelForCausalLM.from_pretrained(
    cl7B,
    torch_dtype=torch.float16,
    device_map="auto",
)
# NOTE(review): the original also built transformers.pipeline("text-generation", ...)
# from the same checkpoint, which loaded the 7B weights a *second* time and was
# never used below — removed to roughly halve startup time and memory.
print('set up codellama7b')
# Shared review instruction sent to both models.
prompt = "Please review the following code and identify and list any errors. If errors are present, please present a corrected version of the code as well. Keep explanations short."
# Iterating over files in a folder. os.path.join builds the pattern with the
# platform's separator (the original hard-coded Windows backslashes).
files = glob.glob(os.path.join('.', 'QuixBugs', 'python_programs', '*.py'))  # Change path and/or .py to .java when checking java files
for index, filePath in enumerate(files):
    # Skip the QuixBugs unit-test files; only the buggy programs are reviewed.
    if filePath.endswith('_test.py'):
        continue
    print(datetime.datetime.now().strftime("%H:%M"))
    print(f"Processing file {index+1}/{len(files)}: {filePath}")
    # Explicit encoding avoids platform-dependent decode errors on read.
    with open(filePath, 'r', encoding='utf-8') as file:
        code = file.read()
    # gpt: ask GPT-4 for a review, append the answer to a per-program .txt file.
    gptResponse = client.chat.completions.create(
        model='gpt-4',
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": code}
        ]
    )
    gptCorrectionPath = './llmCorrection/python/chatGPT/' + os.path.basename(filePath).replace('.py', '.txt') # Change .py to .java when checking java files
    os.makedirs(os.path.dirname(gptCorrectionPath), exist_ok=True)
    with open(gptCorrectionPath, 'a', encoding="utf-8") as gptFile:
        gptFile.write(gptResponse.choices[0].message.content + "\n----- another analysis -----\n")
    print('Iterated over gpt', end='')
    # CodeLlama 7B: same prompt + code as a single user turn.
    chat = [
        {"role": "user", "content": prompt + "\n\n\n" + code}
    ]
    # Send inputs to wherever accelerate placed the model; the original
    # hard-coded "cuda", which breaks when device_map="auto" offloads to CPU.
    inputs = tokenizer7B.apply_chat_template(chat, return_tensors="pt").to(model7B.device)
    cl7BSeq = model7B.generate(
        input_ids=inputs,
        max_new_tokens=1500,
        eos_token_id=tokenizer7B.eos_token_id,
        num_return_sequences=1,
        top_p=0.95,
        temperature=0.3,
        top_k=10,
        do_sample=True
    )
    cl7BCorrectionPath = './llmCorrection/python/codeLlama7B/' + os.path.basename(filePath).replace('.py', '.txt') # Change .py to .java when checking java files
    os.makedirs(os.path.dirname(cl7BCorrectionPath), exist_ok=True)
    with open(cl7BCorrectionPath, 'a', encoding="utf-8") as cl7BFile:
        # Bug fix: the original did `for seq in cl7BSeq[0]`, i.e. iterated a
        # 1-D tensor of token ids, decoding every token *individually* and
        # joining with spaces — mangling the saved text. Decode the whole
        # generated sequence in one call instead.
        cl7BFile.write(tokenizer7B.decode(cl7BSeq[0].to("cpu"), skip_special_tokens=True))
        cl7BFile.write("\n----- another analysis -----\n")
    print('and codeLlama7B')
# Final timestamp plus completion notice.
finished_at = datetime.datetime.now()
print(finished_at.strftime("%H:%M"))
print('program finished')