forked from jvhs0706/zkllm-ccs2024
llama-skip-connection.py — 28 lines (20 loc) · 1023 Bytes
import argparse
import os
import subprocess
import sys

import numpy as np
import torch
# Command-line interface for the skip-connection proof driver.
# All three paths are mandatory string arguments.
parser = argparse.ArgumentParser(description='LLaMa-2 Skip Connection')
_cli_options = [
    ('--block_input_file', 'Input of the block.'),
    ('--block_output_file', 'Output of the block.'),
    ('--output_file', 'Output of the skip connection.'),
]
for _flag, _help in _cli_options:
    parser.add_argument(_flag, required=True, type=str, help=_help)
from transformers import AutoTokenizer, AutoModelForCausalLM
import fileio_utils
if __name__ == '__main__':
    # Build the native skip-connection binary first; abort if the build fails.
    compilation_error = subprocess.run(['make', 'skip-connection']).returncode
    if compilation_error:
        print("Error compiling skip-connection")
        sys.exit(1)

    args = parser.parse_args()

    # Both tensors produced by the transformer block must already exist on disk.
    if not os.path.isfile(args.block_input_file) or not os.path.isfile(args.block_output_file):
        print("Input or output file does not exist.")
        sys.exit(1)

    # Invoke the binary with an argv list (no shell) so paths containing
    # spaces or shell metacharacters cannot break or inject commands, and
    # propagate its exit status instead of silently discarding it.
    result = subprocess.run(
        ['./skip-connection', args.block_input_file, args.block_output_file, args.output_file]
    )
    sys.exit(result.returncode)