serve.py
import os
import time

import langchain
import pinecone
import ray
import torch
from fastapi import FastAPI
from ray import serve
from starlette.requests import Request
from typing import List, Optional, Any

from langchain.cache import InMemoryCache, RedisSemanticCache
from langchain.chains import RetrievalQA
from langchain.chains.mapreduce import MapReduceChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import HuggingFaceEmbeddings, SentenceTransformerEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.llms.utils import enforce_stop_tokens
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Pinecone
from transformers import pipeline as hf_pipeline
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

import demoConfig
from llama_local_pipelines import LlamaPipeline
from local_embeddings import LocalHuggingFaceEmbeddings

# Hugging Face access token for the gated Llama 2 weights
token = demoConfig.hf_token

# Llama 2 chat prompt that injects the retrieved context and the user question
prompt_template = """<s>[INST] <<SYS>>
Answer the following query based on the CONTEXT
given. If you do not know the answer and the CONTEXT doesn't
contain the answer, truthfully say "I don't know".
CONTEXT:
{context}
QUESTION:
{question}
<</SYS>>
[/INST]
"""

PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])


@serve.deployment(ray_actor_options={"num_gpus": 1, "num_cpus": 16})
class LlamaRAGDeployment:
    def __init__(self):
        # Enable the LLM cache for the LangChain QA chain
        langchain.llm_cache = InMemoryCache()
        # WandbTracer.init({"project": "retrieval_demo"})

        # Pinecone API key from app.pinecone.io
        pinecone_api_key = demoConfig.pinecone_api_key
        # Pinecone environment - found next to the API key in the console
        env = demoConfig.pinecone_env
        index_name = demoConfig.pinecone_index_name
        pinecone.init(api_key=pinecone_api_key, environment=env)

        # Load the data from Pinecone. No change from Part 1.
        st = time.time()
        self.embeddings = LocalHuggingFaceEmbeddings("all-MiniLM-L6-v2")
        text_field = "text"
        index = pinecone.Index(index_name)
        self.vectorstore = Pinecone(index, self.embeddings.embed_query, text_field)
        et = time.time() - st
        print(f"Loading Pinecone database took {et} seconds.")

        # Load Llama 2 with 4-bit quantization
        st = time.time()
        global token
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,  # enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from bitsandbytes
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,  # computation dtype, which may differ from the input dtype
            bnb_4bit_use_double_quant=True,  # nested quantization: the constants from the first quantization are quantized again
        )
        self.llm = LlamaPipeline.from_model(
            model="meta-llama/Llama-2-13b-chat-hf",
            task="text-generation",
            token=token,
            model_kwargs={
                "device_map": "auto",
                "quantization_config": bnb_config,
            },
        )
        et = time.time() - st
        print(f"Loading HF model took {et} seconds.")

        # "stuff" chain: concatenate the retrieved documents into the prompt's context slot
        self.chain = load_qa_chain(llm=self.llm, chain_type="stuff", prompt=PROMPT)

    def qa(self, query):
        # Retrieve the top-3 most similar chunks from Pinecone
        st = time.time()
        context = self.vectorstore.similarity_search(query, k=3)
        print(f"Results from db are: {context}")
        et = time.time() - st
        print(f"Vector retrieval took: {et} seconds.")

        # Run the stuff QA chain over the retrieved documents
        result = self.chain({"input_documents": context, "question": query})
        print(f"Result is: {result}")
        return result["output_text"]

    async def __call__(self, request: Request) -> str:
        return self.qa(request.query_params["query"])


deployment = LlamaRAGDeployment.bind()
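
# Usage sketch (not part of the original file; assumes Ray Serve 2.x with the
# default HTTP proxy on port 8000). One way to start and query this deployment:
#
#   serve run serve:deployment
#   curl "http://localhost:8000/?query=What+does+the+demo+index+contain%3F"
#
# __call__ reads the question from request.query_params["query"] and returns
# the generated answer text.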