Looking at the examples, inference is only shown for a single batch. How can I run inference on multiple inputs in one batch? Right now the multi-batch results don't look correct.
from transformers import AutoModelForCausalLM, AutoTokenizer
import pdb
import torch
tokenizer = AutoTokenizer.from_pretrained("CodeShell-7B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("CodeShell-7B", torch_dtype=torch.float16, trust_remote_code=True).cuda()
examples = [
    "import math\ndef print_hello():",
    "import math\ndef quick_sort():",
    "import math\ndef test_quick_sort():",
    "import math\ndef test_print_hello():",
    "import math\ndef test_merge_sort():",
    "import math\ndef two_sum():",
    "import math\ndef preoder_transverse():",
    "import math\ndef merge_sort():",
]
inputs = tokenizer(examples, return_tensors='pt', padding=True)['input_ids'].cuda()
outputs = model.generate(inputs, max_new_tokens=128)
for output in outputs:
    print("=====================> ", tokenizer.decode(output))
The test code is shown above, and the results are a bit strange (a possible cause and a hedged fix sketch follow the output):
=====================> import math
def print_hello():<|endoftext|><|endoftext|><|endoftext|><fim_prefix><fim_suffix> }
}
}<fim_middle>using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace _04.Longest_Increasing_Subsequence
{
class Program
{
static void Main(string[] args)
{
int[] nums = Console.ReadLine().Split(' ').Select(int.Parse).ToArray();
int[] len = new int[nums.Length];
int[] prev = new int[nums.Length];
int maxLen = 0;
=====================> import math
def quick_sort():<|endoftext|><|endoftext|><|endoftext|><fim_prefix><fim_suffix> }
}
}<fim_middle>using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
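A likely cause, assuming the CodeShell tokenizer behaves like other decoder-only Hugging Face tokenizers: the batch is right-padded and only input_ids are passed to generate(), so the model continues generating after padding tokens and never receives an attention_mask. Below is a minimal sketch of batched generation with left padding and an explicit attention mask; the local checkpoint path "CodeShell-7B" and the choice of reusing EOS as the padding token are assumptions, not confirmed details of this repo.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Assumption: a local checkpoint directory named "CodeShell-7B", as in the report above.
tokenizer = AutoTokenizer.from_pretrained("CodeShell-7B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "CodeShell-7B", torch_dtype=torch.float16, trust_remote_code=True
).cuda()

# Decoder-only models should be padded on the left so generation continues
# from real prompt tokens rather than from padding.
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    # Assumption: reusing the EOS token as the padding token is acceptable here.
    tokenizer.pad_token = tokenizer.eos_token

examples = [
    "import math\ndef print_hello():",
    "import math\ndef quick_sort():",
]

# Keep the full tokenizer output so the attention_mask is passed to generate()
# and padded positions are ignored.
batch = tokenizer(examples, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    max_new_tokens=128,
    pad_token_id=tokenizer.pad_token_id,
)
for output in outputs:
    print("=====================> ", tokenizer.decode(output, skip_special_tokens=True))

Left padding keeps each prompt adjacent to its generated continuation, and skip_special_tokens=True drops the <|endoftext|> padding tokens from the printed text. This is a sketch under the assumptions stated above, not a confirmed fix from the maintainers.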