"
+ if resp.status_code != 200:
+ raise Exception("Could not get Google Bard")
+ SNlM0e = re.search(r"SNlM0e\":\"(.*?)\"", resp.text).group(1)
+ return SNlM0e
+
+ async def ask(self, message: Message) -> Response:
+ """
+ Send a message to Google Bard and return the response.
+ :param message: The message to send to Google Bard.
+ :return: A Response object containing the reply from Google Bard.
+ """
+ if message.state.conversation_id == "":
+ message.state.req_id = int("".join(random.choices(string.digits, k=4)))
+ # url params
+ params = {
+ # "bl": "boq_assistant-bard-web-server_20230315.04_p2",
+ # This is a newer API version
+ "bl": "boq_assistant-bard-web-server_20230507.20_p2",
+ "_reqid": str(message.state.req_id),
+ "rt": "c",
+ }
+
+ # Build data["f.req"]: the message structure is JSON-stringified twice (inner list, then the outer wrapper).
+ message_struct = [
+ [message.content],
+ None,
+ [
+ message.state.conversation_id,
+ message.state.response_id,
+ message.state.choice_id,
+ ],
+ ]
+ data = {
+ "f.req": json.dumps([None, json.dumps(message_struct)]),
+ "at": self.SNlM0e,
+ }
+
+ # do the request!
+ resp = await self.session.post(
+ "https://bard.google.com/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate",
+ params=params,
+ data=data,
+ timeout=60,
+ )
+
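+ # The body is not plain JSON: the payload we need sits on the fourth line of the
+ # response, and the chat data is nested at index [0][2].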
+ chat_data = json.loads(resp.content.splitlines()[3])[0][2]
+ if not chat_data:
+ return Response(
+ content=f"Google Bard encountered an error: {resp.content}.",
+ factualityQueries=[],
+ textQuery="",
+ choices=[],
+ state=message.state,
+ )
+ json_chat_data = json.loads(chat_data)
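+ # The payload is positional: [0][0] is the reply text, [1] holds the
+ # conversation/response ids, [2] the text query, [3] the factuality
+ # queries, and [4] the list of answer choices.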
+ conversation = ConversationState(
+ conversation_id=json_chat_data[1][0],
+ response_id=json_chat_data[1][1],
+ choice_id=json_chat_data[4][0][0],
+ req_id=message.state.req_id + 100000,
+ )
+ return Response(
+ content=json_chat_data[0][0],
+ factualityQueries=json_chat_data[3],
+ textQuery=json_chat_data[2][0] if json_chat_data[2] is not None else "",
+ choices=[{"id": i[0], "content": i[1]} for i in json_chat_data[4]],
+ state=conversation,
+ )
+
+
+app = FastAPI()
+chatbot = None
+
+
+@app.on_event("startup")
+async def startup_event():
+ global chatbot
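+ # bard_cookie.json is expected to contain the browser cookie,
+ # e.g. {"__Secure-1PSID": "..."}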
+ cookie = json.load(open("bard_cookie.json"))
+ chatbot = Chatbot(cookie["__Secure-1PSID"])
+ chatbot.SNlM0e = await chatbot._get_snlm0e()
+
+
+@app.post("/chat", response_model=Response)
+async def chat(message: Message):
+ response = await chatbot.ask(message)
+ return response
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser("Google Bard worker")
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=18900)
+ parser.add_argument("--reload", action="store_true")
+ args = parser.parse_args()
+ uvicorn.run(
+ "bard_worker:app", host=args.host, port=args.port, log_level="info",
+ reload=args.reload
+ )
diff --git a/graphgpt/serve/cacheflow_worker.py b/graphgpt/serve/cacheflow_worker.py
new file mode 100644
index 0000000..de3ed3b
--- /dev/null
+++ b/graphgpt/serve/cacheflow_worker.py
@@ -0,0 +1,346 @@
+"""
+A model worker that executes the model with Cacheflow.
+
+Install Cacheflow first. Then, assuming the controller is live:
+1. ray start --head
+2. python3 -m fastchat.serve.cacheflow_worker --model-path path_to_vicuna
+
+Launch the Gradio web server:
+3. python3 -m fastchat.serve.gradio_web_server --concurrency-count 10000
+"""
+import argparse
+import asyncio
+import json
+import threading
+import time
+import uuid
+from typing import List, Dict
+
+import requests
+import torch
+import uvicorn
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+from transformers import AutoTokenizer
+
+from cacheflow.master.server import Server, initialize_ray_cluster
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import Sequence, SequenceGroup
+from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
+from fastchat.constants import WORKER_HEART_BEAT_INTERVAL
+from fastchat.utils import build_logger, pretty_print_semaphore
+
+GB = 1 << 30
+TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
+
+worker_id = str(uuid.uuid4())[:6]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+global_counter = 0
+seed = torch.cuda.current_device()
+
+
+def heart_beat_worker(controller):
+ while True:
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
+ controller.send_heart_beat()
+
+
+class CacheFlowWorker:
+ def __init__(
+ self,
+ controller_addr,
+ worker_addr,
+ worker_id,
+ no_register,
+ model_path,
+ model_name,
+ block_size,
+ seed,
+ swap_space,
+ max_num_batched_tokens,
+ distributed_init_method,
+ all_stage_devices,
+ ):
+ self.controller_addr = controller_addr
+ self.worker_addr = worker_addr
+ self.worker_id = worker_id
+ if model_path.endswith("/"):
+ model_path = model_path[:-1]
+ self.model_name = model_name or model_path.split("/")[-1]
+
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
+ self.block_size = block_size
+
+ # FIXME(Hao): we need to pass the tokenizer into cacheflow because we need
+ # to detect the stopping criteria "###".
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ self.seq_group_counter = Counter()
+ self.seq_counter = Counter()
+ # FIXME(Hao): hard-coded context length
+ self.context_len = 2048
+ # pipeline_parallel_size = 1,
+ # tensor_parallel_size = 1,
+ # dtype = torch.float16
+ remote_server_class = Server
+ self.server = remote_server_class(
+ model=self.model_name,
+ model_path=model_path,
+ pipeline_parallel_size=1,
+ tensor_parallel_size=1,
+ block_size=block_size,
+ dtype=torch.float16,
+ seed=seed,
+ swap_space=swap_space,
+ max_num_batched_tokens=max_num_batched_tokens,
+ num_nodes=1,
+ num_devices_per_node=4,
+ distributed_init_method=distributed_init_method,
+ all_stage_devices=all_stage_devices,
+ gpu_memory=get_gpu_memory(),
+ cpu_memory=get_cpu_memory(),
+ )
+ self.running_seq_groups: Dict[int, SequenceGroup] = {}
+ self.sequence_group_events: Dict[int, asyncio.Event] = {}
+ self.is_server_running = False
+
+ if not no_register:
+ time.sleep(30) # wait for model loading
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_worker, args=(self,)
+ )
+ self.heart_beat_thread.start()
+
+ def register_to_controller(self):
+ logger.info("Register to controller")
+
+ url = self.controller_addr + "/register_worker"
+ data = {
+ "worker_name": self.worker_addr,
+ "check_heart_beat": True,
+ "worker_status": self.get_status(),
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ def send_heart_beat(self):
+ logger.info(
+ f"Send heart beat. Models: {[self.model_name]}. "
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
+ f"global_counter: {global_counter}"
+ )
+
+ url = self.controller_addr + "/receive_heart_beat"
+
+ while True:
+ try:
+ ret = requests.post(
+ url,
+ json={
+ "worker_name": self.worker_addr,
+ "queue_length": self.get_queue_length(),
+ },
+ timeout=5,
+ )
+ exist = ret.json()["exist"]
+ break
+ except requests.exceptions.RequestException as e:
+ logger.error(f"heart beat error: {e}")
+ time.sleep(5)
+
+ if not exist:
+ self.register_to_controller()
+
+ def get_queue_length(self):
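+ # Estimate pending load from the semaphore: the concurrency limit minus the free
+ # slots, plus the number of waiters. (This peeks at asyncio.Semaphore internals,
+ # so it is a best-effort estimate.)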
+ if (
+ model_semaphore is None
+ or model_semaphore._value is None
+ or model_semaphore._waiters is None
+ ):
+ return 0
+ else:
+ return (
+ args.limit_model_concurrency
+ - model_semaphore._value
+ + len(model_semaphore._waiters)
+ )
+
+ def get_status(self):
+ return {
+ "model_names": [self.model_name],
+ "speed": 1,
+ "queue_length": self.get_queue_length(),
+ }
+
+ async def server_step(self):
+ self.is_server_running = True
+ updated_seq_groups = self.server.step()
+ self.is_server_running = False
+ # Notify the waiting coroutines that new outputs are ready.
+ for seq_group in updated_seq_groups:
+ group_id = seq_group.group_id
+ self.running_seq_groups[group_id] = seq_group
+ self.sequence_group_events[group_id].set()
+
+ async def generate_stream(self, params):
+ tokenizer = self.tokenizer
+ context = params["prompt"]
+ temperature = float(params.get("temperature", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+ stop_str = params.get("stop", None)
+ echo = params.get("echo", True)
+ stop_token_ids = params.get("stop_token_ids", None) or []
+ stop_token_ids.append(tokenizer.eos_token_id)
+
+ input_ids = tokenizer(context).input_ids
+ max_src_len = self.context_len - max_new_tokens - 8
+ input_ids = input_ids[-max_src_len:]
+
+ # make sampling params in cacheflow
+ top_p = max(top_p, 1e-5)
+ if temperature <= 1e-5:
+ top_p = 1.0
+ sampling_params = SamplingParams(
+ n=1,
+ temperature=temperature,
+ top_p=top_p,
+ use_beam_search=False,
+ stop_token_ids=stop_token_ids,
+ max_num_steps=max_new_tokens,
+ num_logprobs=0,
+ context_window_size=None,
+ )
+
+ if stop_str is not None:
+ sampling_params.stop_str = stop_str
+ # We might sample multiple sequences, but for the chatbot n is always 1.
+ seqs: List[Sequence] = []
+ for _ in range(sampling_params.n):
+ seq_id = next(self.seq_counter)
+ seq = Sequence(seq_id, input_ids, block_size=self.block_size)
+ seqs.append(seq)
+
+ arrival_time = time.time()
+ group_id = next(self.seq_group_counter)
+ # logger.info(f"Group {group_id} arrives at {time.time()}")
+ seq_group = SequenceGroup(group_id, seqs, arrival_time)
+ group_event = asyncio.Event()
+ self.running_seq_groups[group_id] = seq_group
+ self.sequence_group_events[group_id] = group_event
+ self.server.add_sequence_groups([(seq_group, sampling_params)])
+ while True:
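+ # Step the engine ourselves if no other request is currently stepping it, then
+ # wait for new output for this sequence group. The short timeout keeps the loop
+ # from deadlocking if the event is missed.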
+ if not self.is_server_running:
+ await self.server_step()
+ try:
+ await asyncio.wait_for(
+ group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK
+ )
+ except asyncio.TimeoutError:
+ pass
+ group_event.clear()
+ seq_group = self.running_seq_groups[group_id]
+ all_outputs = []
+ for seq in seq_group.seqs:
+ token_ids = seq.get_token_ids()
+ if not echo:
+ token_ids = token_ids[len(input_ids) :]
+ output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
+ if stop_str is not None:
+ if output.endswith(stop_str):
+ output = output[: -len(stop_str)]
+ all_outputs.append(output)
+ assert len(seq_group.seqs) == 1
+ ret = {
+ "text": all_outputs[0],
+ "error_code": 0,
+ }
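+ # Each chunk is a JSON object terminated by a NUL byte; readers split the
+ # stream on b"\0".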
+ yield (json.dumps(ret) + "\0").encode("utf-8")
+ if seq_group.is_finished():
+ del self.running_seq_groups[group_id]
+ del self.sequence_group_events[group_id]
+ break
+
+
+app = FastAPI()
+model_semaphore = None
+
+
+def release_model_semaphore():
+ model_semaphore.release()
+
+
+@app.post("/worker_generate_stream")
+async def generate_stream(request: Request):
+ global model_semaphore, global_counter
+ global_counter += 1
+ params = await request.json()
+
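+ # Throttle concurrent generations with a semaphore; it is released by a
+ # background task after the StreamingResponse has been fully sent.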
+ if model_semaphore is None:
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+ await model_semaphore.acquire()
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(release_model_semaphore)
+ # return StreamingResponse(generator, background=background_tasks)
+ return StreamingResponse(
+ worker.generate_stream(params), background=background_tasks
+ )
+
+
+@app.post("/worker_get_status")
+async def get_status(request: Request):
+ return worker.get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ parser.add_argument(
+ "--model-path", type=str, default="/home/haozhang/weights/hf-llama-7b"
+ )
+ parser.add_argument("--model-name", type=str)
+ parser.add_argument("--limit-model-concurrency", type=int, default=1024)
+ parser.add_argument("--stream-interval", type=int, default=2)
+ parser.add_argument("--no-register", action="store_true")
+ # cacheflow specific params
+ parser.add_argument(
+ "--block-size", type=int, default=8, choices=[8, 16], help="token block size"
+ )
+ parser.add_argument(
+ "--swap-space", type=int, default=20, help="CPU swap space size (GiB) per GPU"
+ )
+ parser.add_argument(
+ "--max-num-batched-tokens",
+ type=int,
+ default=2560,
+ help="maximum number of batched tokens",
+ )
+ args = parser.parse_args()
+
+ (
+ num_nodes,
+ num_devices_per_node,
+ distributed_init_method,
+ all_stage_devices,
+ ) = initialize_ray_cluster(pipeline_parallel_size=1, tensor_parallel_size=1)
+
+ worker = CacheFlowWorker(
+ args.controller_address,
+ args.worker_address,
+ worker_id,
+ args.no_register,
+ args.model_path,
+ args.model_name,
+ args.block_size,
+ seed,
+ args.swap_space,
+ args.max_num_batched_tokens,
+ distributed_init_method,
+ all_stage_devices,
+ )
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/graphgpt/serve/cli.py b/graphgpt/serve/cli.py
new file mode 100644
index 0000000..3d97a17
--- /dev/null
+++ b/graphgpt/serve/cli.py
@@ -0,0 +1,200 @@
+"""
+Chat with a model with command line interface.
+
+Usage:
+python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0
+python3 -m fastchat.serve.cli --model ~/model_weights/vicuna-7b
+"""
+import argparse
+import os
+import re
+import sys
+
+from prompt_toolkit import PromptSession
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+from prompt_toolkit.completion import WordCompleter
+from prompt_toolkit.history import InMemoryHistory
+from rich.console import Console
+from rich.markdown import Markdown
+from rich.live import Live
+
+from fastchat.model.model_adapter import add_model_args
+from fastchat.serve.inference import chat_loop, ChatIO
+
+
+class SimpleChatIO(ChatIO):
+ def prompt_for_input(self, role) -> str:
+ return input(f"{role}: ")
+
+ def prompt_for_output(self, role: str):
+ print(f"{role}: ", end="", flush=True)
+
+ def stream_output(self, output_stream):
+ pre = 0
+ for outputs in output_stream:
+ output_text = outputs["text"]
+ output_text = output_text.strip().split(" ")
+ now = len(output_text) - 1
+ if now > pre:
+ print(" ".join(output_text[pre:now]), end=" ", flush=True)
+ pre = now
+ print(" ".join(output_text[pre:]), flush=True)
+ return " ".join(output_text)
+
+
+class RichChatIO(ChatIO):
+ def __init__(self):
+ self._prompt_session = PromptSession(history=InMemoryHistory())
+ self._completer = WordCompleter(
+ words=["!exit", "!reset"], pattern=re.compile("$")
+ )
+ self._console = Console()
+
+ def prompt_for_input(self, role) -> str:
+ self._console.print(f"[bold]{role}:")
+ # TODO(suquark): multiline input has some issues. fix it later.
+ prompt_input = self._prompt_session.prompt(
+ completer=self._completer,
+ multiline=False,
+ auto_suggest=AutoSuggestFromHistory(),
+ key_bindings=None,
+ )
+ self._console.print()
+ return prompt_input
+
+ def prompt_for_output(self, role: str):
+ self._console.print(f"[bold]{role}:")
+
+ def stream_output(self, output_stream):
+ """Stream output from a role."""
+ # TODO(suquark): the console flickers when there is a code block
+ # above it. We need to cut off "live" when a code block is done.
+
+ # Create a Live context for updating the console output
+ with Live(console=self._console, refresh_per_second=4) as live:
+ # Read lines from the stream
+ for outputs in output_stream:
+ if not outputs:
+ continue
+ text = outputs["text"]
+ # Render the accumulated text as Markdown
+ # NOTE: this is a workaround for rendering "non-standard markdown"
+ # with rich. The chatbot's output treats "\n" as a new line for
+ # better compatibility with real-world text, but rendering it as
+ # markdown would break the format, because standard markdown treats
+ # a single "\n" in normal text as a space.
+ # Our workaround is to add two spaces at the end of each line.
+ # This is not perfect, as it introduces trailing spaces (only) in
+ # code blocks, but it works well for console output, where trailing
+ # spaces generally do not matter.
+ lines = []
+ for line in text.splitlines():
+ lines.append(line)
+ if line.startswith("```"):
+ # Code block marker - do not add trailing spaces, as it would
+ # break the syntax highlighting
+ lines.append("\n")
+ else:
+ lines.append(" \n")
+ markdown = Markdown("".join(lines))
+ # Update the Live console output
+ live.update(markdown)
+ self._console.print()
+ return text
+
+
+class ProgrammaticChatIO(ChatIO):
+ def prompt_for_input(self, role) -> str:
+ print(f"[!OP:{role}]: ", end="", flush=True)
+ contents = ""
+ # `end_sequence` is a randomly-generated, 16-digit number
+ # that signals the end of a message. It is unlikely to occur in
+ # message content.
+ end_sequence = "9745805894023423"
+ while True:
+ if len(contents) >= 16:
+ last_chars = contents[-16:]
+ if last_chars == end_sequence:
+ break
+ try:
+ char = sys.stdin.read(1)
+ contents = contents + char
+ except EOFError:
+ continue
+ return contents[:-16]
+
+ def prompt_for_output(self, role: str):
+ print(f"[!OP:{role}]: ", end="", flush=True)
+
+ def stream_output(self, output_stream):
+ pre = 0
+ for outputs in output_stream:
+ output_text = outputs["text"]
+ output_text = output_text.strip().split(" ")
+ now = len(output_text) - 1
+ if now > pre:
+ print(" ".join(output_text[pre:now]), end=" ", flush=True)
+ pre = now
+ print(" ".join(output_text[pre:]), flush=True)
+ return " ".join(output_text)
+
+
+def main(args):
+ if args.gpus:
+ if len(args.gpus.split(",")) < args.num_gpus:
+ raise ValueError(
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
+ )
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+
+ if args.style == "simple":
+ chatio = SimpleChatIO()
+ elif args.style == "rich":
+ chatio = RichChatIO()
+ elif args.style == "programmatic":
+ chatio = ProgrammaticChatIO()
+ else:
+ raise ValueError(f"Invalid style for console: {args.style}")
+ try:
+ chat_loop(
+ args.model_path,
+ args.device,
+ args.num_gpus,
+ args.max_gpu_memory,
+ args.load_8bit,
+ args.cpu_offloading,
+ args.conv_template,
+ args.temperature,
+ args.repetition_penalty,
+ args.max_new_tokens,
+ chatio,
+ args.debug,
+ )
+ except KeyboardInterrupt:
+ print("exit...")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ add_model_args(parser)
+ parser.add_argument(
+ "--conv-template", type=str, default=None, help="Conversation prompt template."
+ )
+ parser.add_argument("--temperature", type=float, default=0.7)
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument(
+ "--style",
+ type=str,
+ default="simple",
+ choices=["simple", "rich", "programmatic"],
+ help="Display style.",
+ )
+ parser.add_argument(
+ "--debug",
+ action="store_true",
+ help="Print useful debug information (e.g., prompts)",
+ )
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/graphgpt/serve/controller.py b/graphgpt/serve/controller.py
new file mode 100644
index 0000000..3b1ecf6
--- /dev/null
+++ b/graphgpt/serve/controller.py
@@ -0,0 +1,361 @@
+"""
+A controller manages distributed workers.
+It sends worker addresses to clients.
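+Workers register themselves via /register_worker and send periodic heart
+beats; clients can call /list_models and /get_worker_address, or let the
+controller proxy requests through its /worker_* endpoints.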
+"""
+import argparse
+import asyncio
+import dataclasses
+from enum import Enum, auto
+import json
+import logging
+import time
+from typing import List, Union
+import threading
+
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+import numpy as np
+import requests
+import uvicorn
+
+from fastchat.constants import (CONTROLLER_HEART_BEAT_EXPIRATION, ErrorCode,
+ SERVER_ERROR_MSG)
+from fastchat.utils import build_logger
+
+
+logger = build_logger("controller", "controller.log")
+
+
+class DispatchMethod(Enum):
+ LOTTERY = auto()
+ SHORTEST_QUEUE = auto()
+
+ @classmethod
+ def from_str(cls, name):
+ if name == "lottery":
+ return cls.LOTTERY
+ elif name == "shortest_queue":
+ return cls.SHORTEST_QUEUE
+ else:
+ raise ValueError(f"Invalid dispatch method")
+
+
+@dataclasses.dataclass
+class WorkerInfo:
+ model_names: List[str]
+ speed: int
+ queue_length: int
+ check_heart_beat: bool
+ last_heart_beat: float
+
+
+def heart_beat_controller(controller):
+ while True:
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
+ controller.remove_stale_workers_by_expiration()
+
+
+class Controller:
+ def __init__(self, dispatch_method: str):
+ # Dict[str -> WorkerInfo]
+ self.worker_info = {}
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
+
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_controller, args=(self,)
+ )
+ self.heart_beat_thread.start()
+
+ logger.info("Init controller")
+
+ def register_worker(
+ self, worker_name: str, check_heart_beat: bool, worker_status: dict
+ ):
+ if worker_name not in self.worker_info:
+ logger.info(f"Register a new worker: {worker_name}")
+ else:
+ logger.info(f"Register an existing worker: {worker_name}")
+
+ if not worker_status:
+ worker_status = self.get_worker_status(worker_name)
+ if not worker_status:
+ return False
+
+ self.worker_info[worker_name] = WorkerInfo(
+ worker_status["model_names"],
+ worker_status["speed"],
+ worker_status["queue_length"],
+ check_heart_beat,
+ time.time(),
+ )
+
+ logger.info(f"Register done: {worker_name}, {worker_status}")
+ return True
+
+ def get_worker_status(self, worker_name: str):
+ try:
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
+ except requests.exceptions.RequestException as e:
+ logger.error(f"Get status fails: {worker_name}, {e}")
+ return None
+
+ if r.status_code != 200:
+ logger.error(f"Get status fails: {worker_name}, {r}")
+ return None
+
+ return r.json()
+
+ def remove_worker(self, worker_name: str):
+ del self.worker_info[worker_name]
+
+ def refresh_all_workers(self):
+ old_info = dict(self.worker_info)
+ self.worker_info = {}
+
+ for w_name, w_info in old_info.items():
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
+ logger.info(f"Remove stale worker: {w_name}")
+
+ def list_models(self):
+ model_names = set()
+
+ for w_name, w_info in self.worker_info.items():
+ model_names.update(w_info.model_names)
+
+ return list(model_names)
+
+ def get_worker_address(self, model_name: str):
+ if self.dispatch_method == DispatchMethod.LOTTERY:
+ worker_names = []
+ worker_speeds = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_speeds.append(w_info.speed)
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ if True: # Directly return address
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
+ worker_name = worker_names[pt]
+ return worker_name
+
+ # Check status before returning
+ while True:
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
+ worker_name = worker_names[pt]
+
+ if self.get_worker_status(worker_name):
+ break
+ else:
+ self.remove_worker(worker_name)
+ worker_speeds[pt] = 0
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ continue
+ return worker_name
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
+ worker_names = []
+ worker_qlen = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_qlen.append(w_info.queue_length / w_info.speed)
+ if len(worker_names) == 0:
+ return ""
+ min_index = np.argmin(worker_qlen)
+ w_name = worker_names[min_index]
+ self.worker_info[w_name].queue_length += 1
+ logger.info(
+ f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
+ )
+ return w_name
+ else:
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
+
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
+ if worker_name not in self.worker_info:
+ logger.info(f"Receive unknown heart beat. {worker_name}")
+ return False
+
+ self.worker_info[worker_name].queue_length = queue_length
+ self.worker_info[worker_name].last_heart_beat = time.time()
+ logger.info(f"Receive heart beat. {worker_name}")
+ return True
+
+ def remove_stale_workers_by_expiration(self):
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
+ to_delete = []
+ for worker_name, w_info in self.worker_info.items():
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
+ to_delete.append(worker_name)
+
+ for worker_name in to_delete:
+ self.remove_worker(worker_name)
+
+ def handle_no_worker(self, params):
+ logger.info(f"no worker: {params['model']}")
+ ret = {
+ "text": SERVER_ERROR_MSG,
+ "error_code": ErrorCode.CONTROLLER_NO_WORKER,
+ }
+ return json.dumps(ret).encode() + b"\0"
+
+ def handle_worker_timeout(self, worker_address):
+ logger.info(f"worker timeout: {worker_address}")
+ ret = {
+ "text": SERVER_ERROR_MSG,
+ "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT,
+ }
+ return json.dumps(ret).encode() + b"\0"
+
+ def worker_api_generate_stream(self, params):
+ worker_addr = self.get_worker_address(params["model"])
+ if not worker_addr:
+ yield self.handle_no_worker(params)
+ return
+
+ try:
+ response = requests.post(
+ worker_addr + "/worker_generate_stream",
+ json=params,
+ stream=True,
+ timeout=15,
+ )
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ yield chunk + b"\0"
+ except requests.exceptions.RequestException as e:
+ yield self.handle_worker_timeout(worker_addr)
+
+ def worker_api_generate_completion(self, params):
+ worker_addr = self.get_worker_address(params["model"])
+ if not worker_addr:
+ return self.handle_no_worker(params)
+
+ try:
+ response = requests.post(
+ worker_addr + "/worker_generate_completion",
+ json=params,
+ timeout=15,
+ )
+ return response.json()
+ except requests.exceptions.RequestException as e:
+ return self.handle_worker_timeout(worker_addr)
+
+ def worker_api_embeddings(self, params):
+ worker_addr = self.get_worker_address(params["model"])
+ if not worker_addr:
+ return self.handle_no_worker(params)
+
+ try:
+ response = requests.post(
+ worker_addr + "/worker_get_embeddings",
+ json=params,
+ timeout=15,
+ )
+ return response.json()
+ except requests.exceptions.RequestException as e:
+ return self.handle_worker_timeout(worker_addr)
+
+ # Let the controller act as a worker to achieve hierarchical
+ # management. This can be used to connect isolated sub networks.
+ def worker_api_get_status(self):
+ model_names = set()
+ speed = 0
+ queue_length = 0
+
+ for w_name in self.worker_info:
+ worker_status = self.get_worker_status(w_name)
+ if worker_status is not None:
+ model_names.update(worker_status["model_names"])
+ speed += worker_status["speed"]
+ queue_length += worker_status["queue_length"]
+
+ return {
+ "model_names": list(model_names),
+ "speed": speed,
+ "queue_length": queue_length,
+ }
+
+
+app = FastAPI()
+
+
+@app.post("/register_worker")
+async def register_worker(request: Request):
+ data = await request.json()
+ controller.register_worker(
+ data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
+ )
+
+
+@app.post("/refresh_all_workers")
+async def refresh_all_workers():
+ models = controller.refresh_all_workers()
+
+
+@app.post("/list_models")
+async def list_models():
+ models = controller.list_models()
+ return {"models": models}
+
+
+@app.post("/get_worker_address")
+async def get_worker_address(request: Request):
+ data = await request.json()
+ addr = controller.get_worker_address(data["model"])
+ return {"address": addr}
+
+
+@app.post("/receive_heart_beat")
+async def receive_heart_beat(request: Request):
+ data = await request.json()
+ exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
+ return {"exist": exist}
+
+
+@app.post("/worker_generate_stream")
+async def worker_api_generate_stream(request: Request):
+ params = await request.json()
+ generator = controller.worker_api_generate_stream(params)
+ return StreamingResponse(generator)
+
+
+@app.post("/worker_generate_completion")
+async def worker_api_generate_completion(request: Request):
+ params = await request.json()
+ output = controller.worker_api_generate_completion(params)
+ return output
+
+
+@app.post("/worker_get_embeddings")
+async def worker_api_embeddings(request: Request):
+ params = await request.json()
+ output = controller.worker_api_embeddings(params)
+ return output
+
+
+@app.post("/worker_get_status")
+async def worker_api_get_status(request: Request):
+ return controller.worker_api_get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21001)
+ parser.add_argument(
+ "--dispatch-method",
+ type=str,
+ choices=["lottery", "shortest_queue"],
+ default="shortest_queue",
+ )
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ controller = Controller(args.dispatch_method)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/graphgpt/serve/controller_graph.py b/graphgpt/serve/controller_graph.py
new file mode 100644
index 0000000..f5bfb9c
--- /dev/null
+++ b/graphgpt/serve/controller_graph.py
@@ -0,0 +1,298 @@
+"""
+A controller manages distributed workers.
+It sends worker addresses to clients.
+"""
+import argparse
+import asyncio
+import dataclasses
+from enum import Enum, auto
+import json
+import logging
+import time
+from typing import List, Union
+import threading
+
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+import numpy as np
+import requests
+import uvicorn
+
+from graphgpt.constants import CONTROLLER_HEART_BEAT_EXPIRATION
+from graphgpt.utils import build_logger, server_error_msg
+
+
+logger = build_logger("controller", "controller.log")
+
+
+class DispatchMethod(Enum):
+ LOTTERY = auto()
+ SHORTEST_QUEUE = auto()
+
+ @classmethod
+ def from_str(cls, name):
+ if name == "lottery":
+ return cls.LOTTERY
+ elif name == "shortest_queue":
+ return cls.SHORTEST_QUEUE
+ else:
+ raise ValueError(f"Invalid dispatch method")
+
+
+@dataclasses.dataclass
+class WorkerInfo:
+ model_names: List[str]
+ speed: int
+ queue_length: int
+ check_heart_beat: bool
+ last_heart_beat: float
+
+
+def heart_beat_controller(controller):
+ while True:
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
+ controller.remove_stale_workers_by_expiration()
+
+
+class Controller:
+ def __init__(self, dispatch_method: str):
+ # Dict[str -> WorkerInfo]
+ self.worker_info = {}
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
+
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_controller, args=(self,))
+ self.heart_beat_thread.start()
+
+ logger.info("Init controller")
+
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
+ worker_status: dict):
+ if worker_name not in self.worker_info:
+ logger.info(f"Register a new worker: {worker_name}")
+ else:
+ logger.info(f"Register an existing worker: {worker_name}")
+
+ if not worker_status:
+ worker_status = self.get_worker_status(worker_name)
+ if not worker_status:
+ return False
+
+ self.worker_info[worker_name] = WorkerInfo(
+ worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
+ check_heart_beat, time.time())
+
+ logger.info(f"Register done: {worker_name}, {worker_status}")
+ return True
+
+ def get_worker_status(self, worker_name: str):
+ try:
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
+ except requests.exceptions.RequestException as e:
+ logger.error(f"Get status fails: {worker_name}, {e}")
+ return None
+
+ if r.status_code != 200:
+ logger.error(f"Get status fails: {worker_name}, {r}")
+ return None
+
+ return r.json()
+
+ def remove_worker(self, worker_name: str):
+ del self.worker_info[worker_name]
+
+ def refresh_all_workers(self):
+ old_info = dict(self.worker_info)
+ self.worker_info = {}
+
+ for w_name, w_info in old_info.items():
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
+ logger.info(f"Remove stale worker: {w_name}")
+
+ def list_models(self):
+ model_names = set()
+
+ for w_name, w_info in self.worker_info.items():
+ model_names.update(w_info.model_names)
+
+ return list(model_names)
+
+ def get_worker_address(self, model_name: str):
+ if self.dispatch_method == DispatchMethod.LOTTERY:
+ worker_names = []
+ worker_speeds = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_speeds.append(w_info.speed)
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ if True: # Directly return address
+ pt = np.random.choice(np.arange(len(worker_names)),
+ p=worker_speeds)
+ worker_name = worker_names[pt]
+ return worker_name
+
+ # Check status before returning
+ while True:
+ pt = np.random.choice(np.arange(len(worker_names)),
+ p=worker_speeds)
+ worker_name = worker_names[pt]
+
+ if self.get_worker_status(worker_name):
+ break
+ else:
+ self.remove_worker(worker_name)
+ worker_speeds[pt] = 0
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ continue
+ return worker_name
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
+ worker_names = []
+ worker_qlen = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_qlen.append(w_info.queue_length / w_info.speed)
+ if len(worker_names) == 0:
+ return ""
+ min_index = np.argmin(worker_qlen)
+ w_name = worker_names[min_index]
+ self.worker_info[w_name].queue_length += 1
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
+ return w_name
+ else:
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
+
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
+ if worker_name not in self.worker_info:
+ logger.info(f"Receive unknown heart beat. {worker_name}")
+ return False
+
+ self.worker_info[worker_name].queue_length = queue_length
+ self.worker_info[worker_name].last_heart_beat = time.time()
+ logger.info(f"Receive heart beat. {worker_name}")
+ return True
+
+ def remove_stale_workers_by_expiration(self):
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
+ to_delete = []
+ for worker_name, w_info in self.worker_info.items():
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
+ to_delete.append(worker_name)
+
+ for worker_name in to_delete:
+ self.remove_worker(worker_name)
+
+ def worker_api_generate_stream(self, params):
+ worker_addr = self.get_worker_address(params["model"])
+ if not worker_addr:
+ logger.info(f"no worker: {params['model']}")
+ ret = {
+ "text": server_error_msg,
+ "error_code": 2,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ return
+
+ try:
+ response = requests.post(worker_addr + "/worker_generate_stream",
+ json=params, stream=True, timeout=5)
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ yield chunk + b"\0"
+ except requests.exceptions.RequestException as e:
+ logger.info(f"worker timeout: {worker_addr}")
+ ret = {
+ "text": server_error_msg,
+ "error_code": 3,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+
+ # Let the controller act as a worker to achieve hierarchical
+ # management. This can be used to connect isolated sub networks.
+ def worker_api_get_status(self):
+ model_names = set()
+ speed = 0
+ queue_length = 0
+
+ for w_name in self.worker_info:
+ worker_status = self.get_worker_status(w_name)
+ if worker_status is not None:
+ model_names.update(worker_status["model_names"])
+ speed += worker_status["speed"]
+ queue_length += worker_status["queue_length"]
+
+ return {
+ "model_names": list(model_names),
+ "speed": speed,
+ "queue_length": queue_length,
+ }
+
+
+app = FastAPI()
+
+
+@app.post("/register_worker")
+async def register_worker(request: Request):
+ data = await request.json()
+ controller.register_worker(
+ data["worker_name"], data["check_heart_beat"],
+ data.get("worker_status", None))
+
+
+@app.post("/refresh_all_workers")
+async def refresh_all_workers():
+ models = controller.refresh_all_workers()
+
+
+@app.post("/list_models")
+async def list_models():
+ models = controller.list_models()
+ return {"models": models}
+
+
+@app.post("/get_worker_address")
+async def get_worker_address(request: Request):
+ data = await request.json()
+ addr = controller.get_worker_address(data["model"])
+ return {"address": addr}
+
+
+@app.post("/receive_heart_beat")
+async def receive_heart_beat(request: Request):
+ data = await request.json()
+ exist = controller.receive_heart_beat(
+ data["worker_name"], data["queue_length"])
+ return {"exist": exist}
+
+
+@app.post("/worker_generate_stream")
+async def worker_api_generate_stream(request: Request):
+ params = await request.json()
+ generator = controller.worker_api_generate_stream(params)
+ return StreamingResponse(generator)
+
+
+@app.post("/worker_get_status")
+async def worker_api_get_status(request: Request):
+ return controller.worker_api_get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21001)
+ parser.add_argument("--dispatch-method", type=str, choices=[
+ "lottery", "shortest_queue"], default="shortest_queue")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ controller = Controller(args.dispatch_method)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
\ No newline at end of file
diff --git a/graphgpt/serve/examples/extreme_ironing.jpg b/graphgpt/serve/examples/extreme_ironing.jpg
new file mode 100644
index 0000000..638b078
Binary files /dev/null and b/graphgpt/serve/examples/extreme_ironing.jpg differ
diff --git a/graphgpt/serve/examples/waterview.jpg b/graphgpt/serve/examples/waterview.jpg
new file mode 100644
index 0000000..6f44eba
Binary files /dev/null and b/graphgpt/serve/examples/waterview.jpg differ
diff --git a/graphgpt/serve/gateway/README.md b/graphgpt/serve/gateway/README.md
new file mode 100644
index 0000000..b3afaf1
--- /dev/null
+++ b/graphgpt/serve/gateway/README.md
@@ -0,0 +1,57 @@
+# fastchat Nginx Gateway
+
+## Purpose of the Gateway
+
+The Nginx gateway serves the following purposes:
+
+1. Protects Gradio servers by acting as a firewall.
+2. Facilitates dynamic mounting and unmounting of Gradio servers.
+3. Provides load balancing for Gradio servers.
+4. Offers additional security features, such as a total connection limit.
+5. Reduces the attack surface by exposing only a single public port for serving.
+
+## Deployment and Updating of the Gateway
+
+### Installing Nginx
+
+On Debian-based distributions (e.g., Ubuntu):
+
+```bash
+sudo apt update
+sudo apt install nginx
+```
+On Red Hat-based distributions (e.g., CentOS, Fedora):
+
+```bash
+sudo yum install epel-release
+sudo yum install nginx
+```
+
+### Deployment
+
+Copy `nginx.conf` to `/etc/nginx/nginx.conf` (requires sudo).
+
+Replace the port number 7860 in `server localhost:7860` with the port on which your Gradio web server is deployed.
+
+Modify `upstream websocket` to configure Gradio servers behind the gateway.
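+
+For example, with two Gradio web servers running locally on ports 7860 and 7861 (adjust the hosts and ports to your deployment), the upstream block might look like:
+
+```nginx
+upstream websocket {
+ ip_hash; # keep a client pinned to the same backend
+ server localhost:7860; # first Gradio web server
+ server localhost:7861; # second Gradio web server
+}
+```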
+
+Finally, reload Nginx (see the Updating section below).
+
+
+### HTTPS Deployment with a Public Domain URL
+
+Make sure you have obtained an HTTPS certificate and its corresponding private key.
+
+Fill in the paths to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields.
+
+If you serve the chatbot from your own domain, replace the chat.lmsys.org URL with your own domain.
+
+### Updating
+
+Every time `/etc/nginx/nginx.conf` is modified, you need to reload the Nginx service:
+
+```bash
+sudo nginx -t # check `/etc/nginx/nginx.conf`
+sudo systemctl reload nginx # reload the Nginx service to load the new config
+sudo systemctl status nginx # check the status of the Nginx service. It should be active (running).
+```
diff --git a/graphgpt/serve/gateway/nginx.conf b/graphgpt/serve/gateway/nginx.conf
new file mode 100644
index 0000000..b88ca8c
--- /dev/null
+++ b/graphgpt/serve/gateway/nginx.conf
@@ -0,0 +1,97 @@
+user www-data;
+worker_processes auto;
+pid /run/nginx.pid;
+include /etc/nginx/modules-enabled/*.conf;
+
+events {
+ worker_connections 1024; # maximum number of connections that a worker process can handle concurrently
+ # multi_accept on; # enabling multi_accept lets a worker process accept all pending connections at once, which can help under high load
+
+}
+
+http {
+ ##
+ # Basic Settings
+ ##
+
+ sendfile on; # enable sendfile for performance optimization
+ tcp_nopush on; # enable TCP no-pushing
+ tcp_nodelay on; # enable TCP no-delay
+ keepalive_timeout 65; # sets the timeout for keep-alive connections
+ types_hash_max_size 2048; # maximum size of the types hash table
+ # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security
+
+ # server_names_hash_bucket_size 64;
+ # server_name_in_redirect off;
+
+ include /etc/nginx/mime.types; # include MIME types file
+ default_type application/octet-stream; # default MIME type for unknown file types
+
+ ##
+ # SSL Settings
+ ##
+
+ ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use
+ ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers
+
+ ##
+ # Logging Settings
+ ##
+
+ access_log /var/log/nginx/access.log; # path to access log file
+ error_log /var/log/nginx/error.log; # path to error log file
+
+ ##
+ # Gzip Settings
+ ##
+ gzip on; # enable Gzip compression
+
+ ##
+ # Virtual Host Configs
+ ##
+
+ include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory
+ include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files
+
+ # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/
+ map $http_upgrade $connection_upgrade {
+ default upgrade;
+ '' close;
+ }
+
+ upstream websocket {
+ ip_hash; # load balancing by IP to guarantee session persistence
+ server localhost:7860; # The port should be the gradio web server port
+ # server localhost:7861; # extra gradio server if more than one
+ }
+
+ limit_conn_status 429;
+ limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP
+ limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server
+
+ server {
+ listen 443 ssl; # the listening port of our server
+ ssl_certificate [PATH_TO_SSL_CERT];
+ ssl_certificate_key [PATH_TO_PRIVATE_KEY];
+ server_name chat.lmsys.org; # replace the url with your own domain url
+ limit_conn perserver 1024; # connections per server
+ location / {
+ proxy_pass http://websocket; # proxy all requests to the defined upstream server
+ limit_conn perip 5; # connections per IP
+ proxy_set_header Host $host; # set the Host header for the upstream server
+ proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header
+ proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication
+ proxy_set_header Upgrade $http_upgrade;
+ proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication
+ }
+ }
+
+ # the following block routes all HTTP traffic to HTTPS via nginx
+ server {
+ listen 80;
+ server_name chat.lmsys.org;
+ return 301 https://chat.lmsys.org$request_uri;
+ }
+
+}
diff --git a/graphgpt/serve/gradio_block_arena_anony.py b/graphgpt/serve/gradio_block_arena_anony.py
new file mode 100644
index 0000000..217d9f1
--- /dev/null
+++ b/graphgpt/serve/gradio_block_arena_anony.py
@@ -0,0 +1,506 @@
+"""
+Chatbot Arena (battle) tab.
+Users chat with two anonymous models.
+"""
+
+import json
+import time
+
+import gradio as gr
+import numpy as np
+
+from fastchat.constants import (
+ MODERATION_MSG,
+ CONVERSATION_LIMIT_MSG,
+ INPUT_CHAR_LEN_LIMIT,
+ CONVERSATION_LEN_LIMIT,
+)
+from fastchat.model.model_adapter import get_conversation_template
+from fastchat.serve.gradio_patch import Chatbot as grChatbot
+from fastchat.serve.gradio_web_server import (
+ State,
+ http_bot,
+ get_conv_log_filename,
+ no_change_btn,
+ enable_btn,
+ disable_btn,
+ learn_more_md,
+)
+from fastchat.utils import (
+ build_logger,
+ violates_moderation,
+)
+
+logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
+
+num_models = 2
+enable_moderation = False
+anony_names = ["", ""]
+models = []
+
+
+def set_global_vars_anony(enable_moderation_):
+ global enable_moderation
+ enable_moderation = enable_moderation_
+
+
+def load_demo_side_by_side_anony(models_, url_params):
+ global models
+ models = models_
+
+ states = (None,) * num_models
+ selector_updates = (
+ gr.Markdown.update(visible=True),
+ gr.Markdown.update(visible=True),
+ )
+
+ return (
+ states
+ + selector_updates
+ + (gr.Chatbot.update(visible=True),) * num_models
+ + (
+ gr.Textbox.update(visible=True),
+ gr.Box.update(visible=True),
+ gr.Row.update(visible=True),
+ gr.Row.update(visible=True),
+ gr.Accordion.update(visible=True),
+ )
+ )
+
+
+def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(time.time(), 4),
+ "type": vote_type,
+ "models": [x for x in model_selectors],
+ "states": [x.dict() for x in states],
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+ if ":" not in model_selectors[0]:
+ for i in range(15):
+ names = (
+ "### Model A: " + states[0].model_name,
+ "### Model B: " + states[1].model_name,
+ )
+ yield names + ("",) + (disable_btn,) * 4
+ time.sleep(0.2)
+ else:
+ names = (
+ "### Model A: " + states[0].model_name,
+ "### Model B: " + states[1].model_name,
+ )
+ yield names + ("",) + (disable_btn,) * 4
+
+
+def leftvote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"leftvote (anony). ip: {request.client.host}")
+ for x in vote_last_response(
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
+ ):
+ yield x
+
+
+def rightvote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"rightvote (anony). ip: {request.client.host}")
+ for x in vote_last_response(
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
+ ):
+ yield x
+
+
+def tievote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"tievote (anony). ip: {request.client.host}")
+ for x in vote_last_response(
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
+ ):
+ yield x
+
+
+def bothbad_vote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"bothbad_vote (anony). ip: {request.client.host}")
+ for x in vote_last_response(
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
+ ):
+ yield x
+
+
+def regenerate(state0, state1, request: gr.Request):
+ logger.info(f"regenerate (anony). ip: {request.client.host}")
+ states = [state0, state1]
+ for i in range(num_models):
+ states[i].conv.update_last_message(None)
+ return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
+
+
+def clear_history(request: gr.Request):
+ logger.info(f"clear_history (anony). ip: {request.client.host}")
+ return (
+ [None] * num_models
+ + [None] * num_models
+ + anony_names
+ + [""]
+ + [disable_btn] * 6
+ )
+
+
+def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
+ logger.info(f"share (anony). ip: {request.client.host}")
+ if state0 is not None and state1 is not None:
+ vote_last_response(
+ [state0, state1], "share", [model_selector0, model_selector1], request
+ )
+
+
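+# Relative sampling weights for the anonymous battles: the two models are drawn
+# with probability proportional to these values (models not listed default to 1.0).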
+DEFAULT_WEIGHTS = {
+ "gpt-4": 1.5,
+ "gpt-3.5-turbo": 1.5,
+ "claude-v1": 1.5,
+ "claude-instant-v1": 1.5,
+ "bard": 1.5,
+ "vicuna-13b": 1.5,
+ "koala-13b": 1.5,
+ "vicuna-7b": 1.2,
+ "mpt-7b-chat": 1.2,
+ "oasst-pythia-12b": 1.2,
+ "RWKV-4-Raven-14B": 1.2,
+ "fastchat-t5-3b": 1,
+ "alpaca-13b": 1,
+ "chatglm-6b": 1,
+ "stablelm-tuned-alpha-7b": 0.5,
+ "dolly-v2-12b": 0.5,
+ "llama-13b": 0.1,
+}
+
+
+def add_text(
+ state0, state1, model_selector0, model_selector1, text, request: gr.Request
+):
+ logger.info(f"add_text (anony). ip: {request.client.host}. len: {len(text)}")
+ states = [state0, state1]
+ model_selectors = [model_selector0, model_selector1]
+
+ if states[0] is None:
+ assert states[1] is None
+ weights = [DEFAULT_WEIGHTS.get(m, 1.0) for m in models]
+ if len(models) > 1:
+ weights = weights / np.sum(weights)
+ model_left, model_right = np.random.choice(
+ models, size=(2,), p=weights, replace=False
+ )
+ else:
+ model_left = model_right = models[0]
+
+ states = [
+ State(model_left),
+ State(model_right),
+ ]
+
+ if len(text) <= 0:
+ for i in range(num_models):
+ states[i].skip_next = True
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [""]
+ + [
+ no_change_btn,
+ ]
+ * 6
+ )
+
+ if enable_moderation:
+ flagged = violates_moderation(text)
+ if flagged:
+ logger.info(
+ f"violate moderation (anony). ip: {request.client.host}. text: {text}"
+ )
+ for i in range(num_models):
+ states[i].skip_next = True
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [MODERATION_MSG]
+ + [
+ no_change_btn,
+ ]
+ * 6
+ )
+
+ conv = states[0].conv
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_LEN_LIMIT:
+ logger.info(
+ f"hit conversation length limit. ip: {request.client.host}. text: {text}"
+ )
+ for i in range(num_models):
+ states[i].skip_next = True
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [CONVERSATION_LIMIT_MSG]
+ + [
+ no_change_btn,
+ ]
+ * 6
+ )
+
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
+ for i in range(num_models):
+ states[i].conv.append_message(states[i].conv.roles[0], text)
+ states[i].conv.append_message(states[i].conv.roles[1], None)
+ states[i].skip_next = False
+
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [""]
+ + [
+ disable_btn,
+ ]
+ * 6
+ )
+
+
+def http_bot_all(
+ state0,
+ state1,
+ temperature,
+ top_p,
+ max_new_tokens,
+ request: gr.Request,
+):
+ logger.info(f"http_bot_all (anony). ip: {request.client.host}")
+
+ if state0.skip_next:
+ # This generate call is skipped due to invalid inputs
+ yield (
+ state0,
+ state1,
+ state0.to_gradio_chatbot(),
+ state1.to_gradio_chatbot(),
+ ) + (no_change_btn,) * 6
+ return
+
+ states = [state0, state1]
+ gen = []
+ for i in range(num_models):
+ gen.append(
+ http_bot(
+ states[i],
+ temperature,
+ top_p,
+ max_new_tokens,
+ request,
+ )
+ )
+
+ chatbots = [None] * num_models
+ while True:
+ stop = True
+ for i in range(num_models):
+ try:
+ ret = next(gen[i])
+ states[i], chatbots[i] = ret[0], ret[1]
+ stop = False
+ except StopIteration:
+ pass
+ yield states + chatbots + [disable_btn] * 6
+ if stop:
+ break
+
+ for i in range(10):
+ if i % 2 == 0:
+ yield states + chatbots + [disable_btn] * 4 + [enable_btn] * 2
+ else:
+ yield states + chatbots + [enable_btn] * 6
+ time.sleep(0.2)
+
+
+def build_side_by_side_ui_anony(models):
+ notice_markdown = """
+# ⚔️ Chatbot Arena ⚔️
+### Rules
+- Chat with two anonymous models side-by-side and vote for which one is better!
+- You can do multiple rounds of conversations before voting.
+- The names of the models will be revealed after your vote. Conversations with identity keywords (e.g., ChatGPT, Bard, Vicuna) or any votes after the names are revealed will not count towards the leaderboard.
+- Click "Clear history" to start a new round.
+- [[Blog](https://lmsys.org/blog/2023-05-03-arena/)] [[GitHub]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/KjdtsE9V)
+
+### Terms of use
+By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. **The service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) license.** The demo works better on desktop devices with a wide screen.
+
+### Battle
+Please scroll down and start chatting. You can view a leaderboard of participating models in the fourth tab above labeled 'Leaderboard' or by clicking [here](?leaderboard). The models include both closed-source models (e.g., ChatGPT) and open-source models (e.g., Vicuna).
+"""
+
+ states = [gr.State() for _ in range(num_models)]
+ model_selectors = [None] * num_models
+ chatbots = [None] * num_models
+
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
+
+ with gr.Box(elem_id="share-region-anony"):
+ with gr.Row():
+ for i in range(num_models):
+ with gr.Column():
+ model_selectors[i] = gr.Markdown(anony_names[i])
+
+ with gr.Row():
+ for i in range(num_models):
+ label = "Model A" if i == 0 else "Model B"
+ with gr.Column():
+ chatbots[i] = grChatbot(
+ label=label, elem_id=f"chatbot", visible=False
+ ).style(height=550)
+
+ with gr.Box() as button_row:
+ with gr.Row():
+ leftvote_btn = gr.Button(value="👈 A is better", interactive=False)
+ rightvote_btn = gr.Button(value="👉 B is better", interactive=False)
+ tie_btn = gr.Button(value="🤝 Tie", interactive=False)
+ bothbad_btn = gr.Button(value="👎 Both are bad", interactive=False)
+
+ with gr.Row():
+ with gr.Column(scale=20):
+ textbox = gr.Textbox(
+ show_label=False,
+ placeholder="Enter text and press ENTER",
+ visible=False,
+ ).style(container=False)
+ with gr.Column(scale=1, min_width=50):
+ send_btn = gr.Button(value="Send", visible=False)
+
+ with gr.Row() as button_row2:
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
+ share_btn = gr.Button(value="📷 Share")
+
+ with gr.Accordion("Parameters", open=False, visible=True) as parameter_row:
+ temperature = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.7,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+ top_p = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Top P",
+ )
+ max_output_tokens = gr.Slider(
+ minimum=16,
+ maximum=1024,
+ value=512,
+ step=64,
+ interactive=True,
+ label="Max output tokens",
+ )
+
+ gr.Markdown(learn_more_md)
+
+ # Register listeners
+ btn_list = [
+ leftvote_btn,
+ rightvote_btn,
+ tie_btn,
+ bothbad_btn,
+ regenerate_btn,
+ clear_btn,
+ ]
+ leftvote_btn.click(
+ leftvote_last_response,
+ states + model_selectors,
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ rightvote_btn.click(
+ rightvote_last_response,
+ states + model_selectors,
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ tie_btn.click(
+ tievote_last_response,
+ states + model_selectors,
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ bothbad_btn.click(
+ bothbad_vote_last_response,
+ states + model_selectors,
+ model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ regenerate_btn.click(
+ regenerate, states, states + chatbots + [textbox] + btn_list
+ ).then(
+ http_bot_all,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ )
+ clear_btn.click(
+ clear_history, None, states + chatbots + model_selectors + [textbox] + btn_list
+ )
+
+ share_js = """
+function (a, b, c, d) {
+ const captureElement = document.querySelector('#share-region-anony');
+ html2canvas(captureElement)
+ .then(canvas => {
+ canvas.style.display = 'none'
+ document.body.appendChild(canvas)
+ return canvas
+ })
+ .then(canvas => {
+ const image = canvas.toDataURL('image/png')
+ const a = document.createElement('a')
+ a.setAttribute('download', 'chatbot-arena.png')
+ a.setAttribute('href', image)
+ a.click()
+ canvas.remove()
+ });
+ return [a, b, c, d];
+}
+"""
+ share_btn.click(share_click, states + model_selectors, [], _js=share_js)
+
+ textbox.submit(
+ add_text,
+ states + model_selectors + [textbox],
+ states + chatbots + [textbox] + btn_list,
+ ).then(
+ http_bot_all,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ )
+ send_btn.click(
+ add_text,
+ states + model_selectors + [textbox],
+ states + chatbots + [textbox] + btn_list,
+ ).then(
+ http_bot_all,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ )
+
+ return (
+ states,
+ model_selectors,
+ chatbots,
+ textbox,
+ send_btn,
+ button_row,
+ button_row2,
+ parameter_row,
+ )
diff --git a/graphgpt/serve/gradio_block_arena_named.py b/graphgpt/serve/gradio_block_arena_named.py
new file mode 100644
index 0000000..fc525eb
--- /dev/null
+++ b/graphgpt/serve/gradio_block_arena_named.py
@@ -0,0 +1,468 @@
+"""
+Chatbot Arena (side-by-side) tab.
+Users chat with two chosen models.
+"""
+
+import json
+import time
+
+import gradio as gr
+import numpy as np
+
+from fastchat.constants import (
+ MODERATION_MSG,
+ CONVERSATION_LIMIT_MSG,
+ INPUT_CHAR_LEN_LIMIT,
+ CONVERSATION_LEN_LIMIT,
+)
+from fastchat.model.model_adapter import get_conversation_template
+from fastchat.serve.gradio_patch import Chatbot as grChatbot
+from fastchat.serve.gradio_web_server import (
+ State,
+ http_bot,
+ get_conv_log_filename,
+ get_model_description_md,
+ no_change_btn,
+ enable_btn,
+ disable_btn,
+ learn_more_md,
+)
+from fastchat.utils import (
+ build_logger,
+ violates_moderation,
+)
+
+
+logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
+
+num_models = 2
+enable_moderation = False
+
+
+def set_global_vars_named(enable_moderation_):
+ global enable_moderation
+ enable_moderation = enable_moderation_
+
+
+def load_demo_side_by_side_named(models, url_params):
+ states = (None,) * num_models
+
+ model_left = models[0] if len(models) > 0 else ""
+ if len(models) > 1:
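+        # Sample the right-hand model from models[1:] with decaying weights
+        # (8, 4, 2, 1, then 1s), normalized into a probability distribution.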
+ weights = ([8, 4, 2, 1] + [1] * 32)[: len(models) - 1]
+ weights = weights / np.sum(weights)
+ model_right = np.random.choice(models[1:], p=weights)
+ else:
+ model_right = model_left
+
+ selector_updates = (
+ gr.Dropdown.update(model_left, visible=True),
+ gr.Dropdown.update(model_right, visible=True),
+ )
+
+ return (
+ states
+ + selector_updates
+ + (gr.Chatbot.update(visible=True),) * num_models
+ + (
+ gr.Textbox.update(visible=True),
+ gr.Box.update(visible=True),
+ gr.Row.update(visible=True),
+ gr.Row.update(visible=True),
+ gr.Accordion.update(visible=True),
+ )
+ )
+
+
+def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(time.time(), 4),
+ "type": vote_type,
+ "models": [x for x in model_selectors],
+ "states": [x.dict() for x in states],
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+def leftvote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"leftvote (named). ip: {request.client.host}")
+ vote_last_response(
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
+ )
+ return ("",) + (disable_btn,) * 4
+
+
+def rightvote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"rightvote (named). ip: {request.client.host}")
+ vote_last_response(
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
+ )
+ return ("",) + (disable_btn,) * 4
+
+
+def tievote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"tievote (named). ip: {request.client.host}")
+ vote_last_response(
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
+ )
+ return ("",) + (disable_btn,) * 4
+
+
+def bothbad_vote_last_response(
+ state0, state1, model_selector0, model_selector1, request: gr.Request
+):
+ logger.info(f"bothbad_vote (named). ip: {request.client.host}")
+ vote_last_response(
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
+ )
+ return ("",) + (disable_btn,) * 4
+
+
+def regenerate(state0, state1, request: gr.Request):
+ logger.info(f"regenerate (named). ip: {request.client.host}")
+ states = [state0, state1]
+ for i in range(num_models):
+ states[i].conv.update_last_message(None)
+ return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
+
+
+def clear_history(request: gr.Request):
+ logger.info(f"clear_history (named). ip: {request.client.host}")
+ return [None] * num_models + [None] * num_models + [""] + [disable_btn] * 6
+
+
+def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
+ logger.info(f"share (named). ip: {request.client.host}")
+ if state0 is not None and state1 is not None:
+ vote_last_response(
+ [state0, state1], "share", [model_selector0, model_selector1], request
+ )
+
+
+def add_text(
+ state0, state1, model_selector0, model_selector1, text, request: gr.Request
+):
+ logger.info(f"add_text (named). ip: {request.client.host}. len: {len(text)}")
+ states = [state0, state1]
+ model_selectors = [model_selector0, model_selector1]
+
+ for i in range(num_models):
+ if states[i] is None:
+ states[i] = State(model_selectors[i])
+
+ if len(text) <= 0:
+ for i in range(num_models):
+ states[i].skip_next = True
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [""]
+ + [
+ no_change_btn,
+ ]
+ * 6
+ )
+
+ if enable_moderation:
+ flagged = violates_moderation(text)
+ if flagged:
+ logger.info(
+ f"violate moderation (named). ip: {request.client.host}. text: {text}"
+ )
+ for i in range(num_models):
+ states[i].skip_next = True
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [MODERATION_MSG]
+ + [
+ no_change_btn,
+ ]
+ * 6
+ )
+
+ conv = states[0].conv
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_LEN_LIMIT:
+ logger.info(
+ f"hit conversation length limit. ip: {request.client.host}. text: {text}"
+ )
+ for i in range(num_models):
+ states[i].skip_next = True
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [CONVERSATION_LIMIT_MSG]
+ + [
+ no_change_btn,
+ ]
+ * 6
+ )
+
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
+ for i in range(num_models):
+ states[i].conv.append_message(states[i].conv.roles[0], text)
+ states[i].conv.append_message(states[i].conv.roles[1], None)
+ states[i].skip_next = False
+
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [""]
+ + [
+ disable_btn,
+ ]
+ * 6
+ )
+
+
+def http_bot_all(
+ state0,
+ state1,
+ temperature,
+ top_p,
+ max_new_tokens,
+ request: gr.Request,
+):
+ logger.info(f"http_bot_all (named). ip: {request.client.host}")
+
+ if state0.skip_next:
+ # This generate call is skipped due to invalid inputs
+ yield (
+ state0,
+ state1,
+ state0.to_gradio_chatbot(),
+ state1.to_gradio_chatbot(),
+ ) + (no_change_btn,) * 6
+ return
+
+ states = [state0, state1]
+ gen = []
+ for i in range(num_models):
+ gen.append(
+ http_bot(
+ states[i],
+ temperature,
+ top_p,
+ max_new_tokens,
+ request,
+ )
+ )
+
+ chatbots = [None] * num_models
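+    # Drain both streaming generators in lockstep so the two chat panels
+    # update together; stop once both are exhausted.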
+ while True:
+ stop = True
+ for i in range(num_models):
+ try:
+ ret = next(gen[i])
+ states[i], chatbots[i] = ret[0], ret[1]
+ stop = False
+ except StopIteration:
+ pass
+ yield states + chatbots + [disable_btn] * 6
+ if stop:
+ break
+
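+    # Blink the four vote buttons a few times (keeping Regenerate/Clear
+    # enabled) to signal that both responses have finished streaming.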
+ for i in range(10):
+ if i % 2 == 0:
+ yield states + chatbots + [disable_btn] * 4 + [enable_btn] * 2
+ else:
+ yield states + chatbots + [enable_btn] * 6
+ time.sleep(0.2)
+
+
+def build_side_by_side_ui_named(models):
+ notice_markdown = """
+# ⚔️ Chatbot Arena ⚔️
+### Rules
+- Chat with two models side-by-side and vote for which one is better!
+- You pick the models you want to chat with.
+- You can do multiple rounds of conversations before voting.
+- Click "Clear history" to start a new round.
+- [[Blog](https://lmsys.org/blog/2023-05-03-arena/)] [[GitHub]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/KjdtsE9V)
+
+### Terms of use
+By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. **The service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) license.** The demo works better on desktop devices with a wide screen.
+
+### Choose two models to chat with (view [leaderboard](?leaderboard))
+"""
+
+ states = [gr.State() for _ in range(num_models)]
+ model_selectors = [None] * num_models
+ chatbots = [None] * num_models
+
+ model_description_md = get_model_description_md(models)
+ notice = gr.Markdown(
+ notice_markdown + model_description_md, elem_id="notice_markdown"
+ )
+
+ with gr.Box(elem_id="share-region-named"):
+ with gr.Row():
+ for i in range(num_models):
+ with gr.Column():
+ model_selectors[i] = gr.Dropdown(
+ choices=models,
+ value=models[i] if len(models) > i else "",
+ interactive=True,
+ show_label=False,
+ ).style(container=False)
+
+ with gr.Row():
+ for i in range(num_models):
+ label = "Model A" if i == 0 else "Model B"
+ with gr.Column():
+ chatbots[i] = grChatbot(
+ label=label, elem_id=f"chatbot", visible=False
+ ).style(height=550)
+
+ with gr.Box() as button_row:
+ with gr.Row():
+ leftvote_btn = gr.Button(value="👈 A is better", interactive=False)
+ rightvote_btn = gr.Button(value="👉 B is better", interactive=False)
+ tie_btn = gr.Button(value="🤝 Tie", interactive=False)
+ bothbad_btn = gr.Button(value="👎 Both are bad", interactive=False)
+
+ with gr.Row():
+ with gr.Column(scale=20):
+ textbox = gr.Textbox(
+ show_label=False,
+ placeholder="Enter text and press ENTER",
+ visible=False,
+ ).style(container=False)
+ with gr.Column(scale=1, min_width=50):
+ send_btn = gr.Button(value="Send", visible=False)
+
+ with gr.Row() as button_row2:
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
+ share_btn = gr.Button(value="📷 Share")
+
+ with gr.Accordion("Parameters", open=False, visible=True) as parameter_row:
+ temperature = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.7,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+ top_p = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Top P",
+ )
+ max_output_tokens = gr.Slider(
+ minimum=16,
+ maximum=1024,
+ value=512,
+ step=64,
+ interactive=True,
+ label="Max output tokens",
+ )
+
+ gr.Markdown(learn_more_md)
+
+ # Register listeners
+ btn_list = [
+ leftvote_btn,
+ rightvote_btn,
+ tie_btn,
+ bothbad_btn,
+ regenerate_btn,
+ clear_btn,
+ ]
+ leftvote_btn.click(
+ leftvote_last_response,
+ states + model_selectors,
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ rightvote_btn.click(
+ rightvote_last_response,
+ states + model_selectors,
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ tie_btn.click(
+ tievote_last_response,
+ states + model_selectors,
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ bothbad_btn.click(
+ bothbad_vote_last_response,
+ states + model_selectors,
+ [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ )
+ regenerate_btn.click(
+ regenerate, states, states + chatbots + [textbox] + btn_list
+ ).then(
+ http_bot_all,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ )
+ clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list)
+
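+    # Client-side JS: screenshot the named side-by-side region with
+    # html2canvas and download it as a PNG when Share is clicked.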
+ share_js = """
+function (a, b, c, d) {
+ const captureElement = document.querySelector('#share-region-named');
+ html2canvas(captureElement)
+ .then(canvas => {
+ canvas.style.display = 'none'
+ document.body.appendChild(canvas)
+ return canvas
+ })
+ .then(canvas => {
+ const image = canvas.toDataURL('image/png')
+ const a = document.createElement('a')
+ a.setAttribute('download', 'chatbot-arena.png')
+ a.setAttribute('href', image)
+ a.click()
+ canvas.remove()
+ });
+ return [a, b, c, d];
+}
+"""
+ share_btn.click(share_click, states + model_selectors, [], _js=share_js)
+
+ for i in range(num_models):
+ model_selectors[i].change(
+ clear_history, None, states + chatbots + [textbox] + btn_list
+ )
+
+ textbox.submit(
+ add_text,
+ states + model_selectors + [textbox],
+ states + chatbots + [textbox] + btn_list,
+ ).then(
+ http_bot_all,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ )
+ send_btn.click(
+ add_text,
+ states + model_selectors + [textbox],
+ states + chatbots + [textbox] + btn_list,
+ ).then(
+ http_bot_all,
+ states + [temperature, top_p, max_output_tokens],
+ states + chatbots + btn_list,
+ )
+
+ return (
+ states,
+ model_selectors,
+ chatbots,
+ textbox,
+ send_btn,
+ button_row,
+ button_row2,
+ parameter_row,
+ )
diff --git a/graphgpt/serve/gradio_css.py b/graphgpt/serve/gradio_css.py
new file mode 100644
index 0000000..d64a8cf
--- /dev/null
+++ b/graphgpt/serve/gradio_css.py
@@ -0,0 +1,77 @@
+code_highlight_css = """
+#chatbot .hll { background-color: #ffffcc }
+#chatbot .c { color: #408080; font-style: italic }
+#chatbot .err { border: 1px solid #FF0000 }
+#chatbot .k { color: #008000; font-weight: bold }
+#chatbot .o { color: #666666 }
+#chatbot .ch { color: #408080; font-style: italic }
+#chatbot .cm { color: #408080; font-style: italic }
+#chatbot .cp { color: #BC7A00 }
+#chatbot .cpf { color: #408080; font-style: italic }
+#chatbot .c1 { color: #408080; font-style: italic }
+#chatbot .cs { color: #408080; font-style: italic }
+#chatbot .gd { color: #A00000 }
+#chatbot .ge { font-style: italic }
+#chatbot .gr { color: #FF0000 }
+#chatbot .gh { color: #000080; font-weight: bold }
+#chatbot .gi { color: #00A000 }
+#chatbot .go { color: #888888 }
+#chatbot .gp { color: #000080; font-weight: bold }
+#chatbot .gs { font-weight: bold }
+#chatbot .gu { color: #800080; font-weight: bold }
+#chatbot .gt { color: #0044DD }
+#chatbot .kc { color: #008000; font-weight: bold }
+#chatbot .kd { color: #008000; font-weight: bold }
+#chatbot .kn { color: #008000; font-weight: bold }
+#chatbot .kp { color: #008000 }
+#chatbot .kr { color: #008000; font-weight: bold }
+#chatbot .kt { color: #B00040 }
+#chatbot .m { color: #666666 }
+#chatbot .s { color: #BA2121 }
+#chatbot .na { color: #7D9029 }
+#chatbot .nb { color: #008000 }
+#chatbot .nc { color: #0000FF; font-weight: bold }
+#chatbot .no { color: #880000 }
+#chatbot .nd { color: #AA22FF }
+#chatbot .ni { color: #999999; font-weight: bold }
+#chatbot .ne { color: #D2413A; font-weight: bold }
+#chatbot .nf { color: #0000FF }
+#chatbot .nl { color: #A0A000 }
+#chatbot .nn { color: #0000FF; font-weight: bold }
+#chatbot .nt { color: #008000; font-weight: bold }
+#chatbot .nv { color: #19177C }
+#chatbot .ow { color: #AA22FF; font-weight: bold }
+#chatbot .w { color: #bbbbbb }
+#chatbot .mb { color: #666666 }
+#chatbot .mf { color: #666666 }
+#chatbot .mh { color: #666666 }
+#chatbot .mi { color: #666666 }
+#chatbot .mo { color: #666666 }
+#chatbot .sa { color: #BA2121 }
+#chatbot .sb { color: #BA2121 }
+#chatbot .sc { color: #BA2121 }
+#chatbot .dl { color: #BA2121 }
+#chatbot .sd { color: #BA2121; font-style: italic }
+#chatbot .s2 { color: #BA2121 }
+#chatbot .se { color: #BB6622; font-weight: bold }
+#chatbot .sh { color: #BA2121 }
+#chatbot .si { color: #BB6688; font-weight: bold }
+#chatbot .sx { color: #008000 }
+#chatbot .sr { color: #BB6688 }
+#chatbot .s1 { color: #BA2121 }
+#chatbot .ss { color: #19177C }
+#chatbot .bp { color: #008000 }
+#chatbot .fm { color: #0000FF }
+#chatbot .vc { color: #19177C }
+#chatbot .vg { color: #19177C }
+#chatbot .vi { color: #19177C }
+#chatbot .vm { color: #19177C }
+#chatbot .il { color: #666666 }
+"""
+# .highlight { background: #f8f8f8; }
+
+table_css = """
+table {
+ line-height: 0em
+}
+"""
diff --git a/graphgpt/serve/gradio_patch.py b/graphgpt/serve/gradio_patch.py
new file mode 100644
index 0000000..af8731d
--- /dev/null
+++ b/graphgpt/serve/gradio_patch.py
@@ -0,0 +1,168 @@
+"""
+Adapted from https://github.com/gradio-app/gradio/blob/main/gradio/components.py
+Fix a markdown render problem.
+"""
+from __future__ import annotations
+
+from gradio.components import *
+from markdown2 import Markdown
+import nh3
+
+
+class _Keywords(Enum):
+    NO_VALUE = "NO_VALUE"  # Used as a sentinel to determine if nothing is provided as an argument for `value` in `Component.update()`
+ FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state)
+
+
+@document("style")
+class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
+ """
+ Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images.
+ Preprocessing: this component does *not* accept input.
+ Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed.
+
+ Demos: chatbot_simple, chatbot_multimodal
+ """
+
+ def __init__(
+ self,
+ value: List[Tuple[str | None, str | None]] | Callable | None = None,
+ color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style()
+ *,
+ label: str | None = None,
+ every: float | None = None,
+ show_label: bool = True,
+ visible: bool = True,
+ elem_id: str | None = None,
+ elem_classes: List[str] | str | None = None,
+ **kwargs,
+ ):
+ """
+ Parameters:
+ value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component.
+ label: component name in interface.
+ every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
+ show_label: if True, will display label.
+ visible: If False, component will be hidden.
+ elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
+ elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
+ """
+ if color_map is not None:
+ warnings.warn(
+ "The 'color_map' parameter has been deprecated.",
+ )
+ # self.md = utils.get_markdown_parser()
+ self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
+ self.select: EventListenerMethod
+ """
+ Event listener for when the user selects message from Chatbot.
+ Uses event data gradio.SelectData to carry `value` referring to text of selected message, and `index` tuple to refer to [message, participant] index.
+ See EventData documentation on how to use this event data.
+ """
+
+ IOComponent.__init__(
+ self,
+ label=label,
+ every=every,
+ show_label=show_label,
+ visible=visible,
+ elem_id=elem_id,
+ elem_classes=elem_classes,
+ value=value,
+ **kwargs,
+ )
+
+ def get_config(self):
+ return {
+ "value": self.value,
+ "selectable": self.selectable,
+ **IOComponent.get_config(self),
+ }
+
+ @staticmethod
+ def update(
+ value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
+ label: str | None = None,
+ show_label: bool | None = None,
+ visible: bool | None = None,
+ ):
+ updated_config = {
+ "label": label,
+ "show_label": show_label,
+ "visible": visible,
+ "value": value,
+ "__type__": "update",
+ }
+ return updated_config
+
+ def _process_chat_messages(
+ self, chat_message: str | Tuple | List | Dict | None
+ ) -> str | Dict | None:
+ if chat_message is None:
+ return None
+ elif isinstance(chat_message, (tuple, list)):
+ mime_type = processing_utils.get_mimetype(chat_message[0])
+ return {
+ "name": chat_message[0],
+ "mime_type": mime_type,
+ "alt_text": chat_message[1] if len(chat_message) > 1 else None,
+ "data": None, # These last two fields are filled in by the frontend
+ "is_file": True,
+ }
+ elif isinstance(
+ chat_message, dict
+ ): # This happens for previously processed messages
+ return chat_message
+ elif isinstance(chat_message, str):
+ # return self.md.render(chat_message)
+ return str(self.md.convert(chat_message))
+ else:
+ raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
+
+ def postprocess(
+ self,
+ y: List[
+ Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None]
+ ],
+ ) -> List[Tuple[str | Dict | None, str | Dict | None]]:
+ """
+ Parameters:
+ y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
+ Returns:
+ List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information.
+ """
+ if y is None:
+ return []
+ processed_messages = []
+ for message_pair in y:
+ assert isinstance(
+ message_pair, (tuple, list)
+ ), f"Expected a list of lists or list of tuples. Received: {message_pair}"
+ assert (
+ len(message_pair) == 2
+ ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
+ processed_messages.append(
+ (
+ # self._process_chat_messages(message_pair[0]),
+                    # Sanitize the user message with nh3 and show it preformatted
+                    # so raw HTML/markdown in user input is not rendered.
+                    '<pre style="font-family: var(--font)">'
+                    + nh3.clean(message_pair[0])
+                    + "</pre>",
+ self._process_chat_messages(message_pair[1]),
+ )
+ )
+ return processed_messages
+
+ def style(self, height: int | None = None, **kwargs):
+ """
+ This method can be used to change the appearance of the Chatbot component.
+ """
+ if height is not None:
+ self._style["height"] = height
+ if kwargs.get("color_map") is not None:
+ warnings.warn("The 'color_map' parameter has been deprecated.")
+
+ Component.style(
+ self,
+ **kwargs,
+ )
+ return self
diff --git a/graphgpt/serve/gradio_web_server.py b/graphgpt/serve/gradio_web_server.py
new file mode 100644
index 0000000..4e0c0d3
--- /dev/null
+++ b/graphgpt/serve/gradio_web_server.py
@@ -0,0 +1,714 @@
+"""
+The gradio demo server for chatting with a single model.
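+
+Typical launch, assuming a FastChat controller and at least one model worker
+are already running (see the worker modules in this package):
+    python3 -m fastchat.serve.gradio_web_server --controller-url http://localhost:21001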
+"""
+
+import argparse
+from collections import defaultdict
+import datetime
+import json
+import os
+import random
+import time
+import uuid
+
+import gradio as gr
+import requests
+
+from fastchat.conversation import SeparatorStyle
+from fastchat.constants import (
+ LOGDIR,
+ WORKER_API_TIMEOUT,
+ ErrorCode,
+ MODERATION_MSG,
+ CONVERSATION_LIMIT_MSG,
+ SERVER_ERROR_MSG,
+ INPUT_CHAR_LEN_LIMIT,
+ CONVERSATION_LEN_LIMIT,
+)
+from fastchat.model.model_adapter import get_conversation_template
+from fastchat.model.model_registry import model_info
+from fastchat.serve.api_provider import (
+ anthropic_api_stream_iter,
+ bard_api_stream_iter,
+ openai_api_stream_iter,
+ palm_api_stream_iter,
+ init_palm_chat,
+)
+from fastchat.serve.gradio_patch import Chatbot as grChatbot
+from fastchat.serve.gradio_css import code_highlight_css
+from fastchat.utils import (
+ build_logger,
+ violates_moderation,
+ get_window_url_params_js,
+)
+
+
+logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+headers = {"User-Agent": "fastchat Client"}
+
+no_change_btn = gr.Button.update()
+enable_btn = gr.Button.update(interactive=True)
+disable_btn = gr.Button.update(interactive=False)
+
+controller_url = None
+enable_moderation = False
+
+learn_more_md = """
+### License
+The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
+"""
+
+
+class State:
+ def __init__(self, model_name):
+ self.conv = get_conversation_template(model_name)
+ self.conv_id = uuid.uuid4().hex
+ self.skip_next = False
+ self.model_name = model_name
+
+ if model_name == "bard":
+ self.bard_session_state = {
+ "conversation_id": "",
+ "response_id": "",
+ "choice_id": "",
+ "req_id": 0,
+ }
+ # According to release note, "chat-bison@001" is PaLM 2 for chat.
+ # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023
+ self.palm_chat = init_palm_chat("chat-bison@001")
+
+ def to_gradio_chatbot(self):
+ return self.conv.to_gradio_chatbot()
+
+ def dict(self):
+ base = self.conv.dict()
+ base.update(
+ {
+ "conv_id": self.conv_id,
+ "model_name": self.model_name,
+ }
+ )
+ return base
+
+
+def set_global_vars(controller_url_, enable_moderation_):
+ global controller_url, enable_moderation
+ controller_url = controller_url_
+ enable_moderation = enable_moderation_
+
+
+def get_conv_log_filename():
+ t = datetime.datetime.now()
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+ return name
+
+
+def get_model_list(controller_url):
+ ret = requests.post(controller_url + "/refresh_all_workers")
+ assert ret.status_code == 200
+ ret = requests.post(controller_url + "/list_models")
+ models = ret.json()["models"]
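+    # Models found in the registry get a synthetic "___NN" key so they are
+    # listed first, in registry order; others fall back to sorting by name.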
+ priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}
+ models.sort(key=lambda x: priority.get(x, x))
+ logger.info(f"Models: {models}")
+ return models
+
+
+def load_demo_refresh_model_list(url_params):
+ models = get_model_list(controller_url)
+ selected_model = models[0] if len(models) > 0 else ""
+ if "model" in url_params:
+ model = url_params["model"]
+ if model in models:
+ selected_model = model
+
+ dropdown_update = gr.Dropdown.update(
+ choices=models, value=selected_model, visible=True
+ )
+
+ state = None
+ return (
+ state,
+ dropdown_update,
+ gr.Chatbot.update(visible=True),
+ gr.Textbox.update(visible=True),
+ gr.Button.update(visible=True),
+ gr.Row.update(visible=True),
+ gr.Accordion.update(visible=True),
+ )
+
+
+def load_demo_reload_model(url_params, request: gr.Request):
+ logger.info(
+ f"load_demo_reload_model. ip: {request.client.host}. params: {url_params}"
+ )
+ return load_demo_refresh_model_list(url_params)
+
+
+def load_demo_single(models, url_params):
+ dropdown_update = gr.Dropdown.update(visible=True)
+ if "model" in url_params:
+ model = url_params["model"]
+ if model in models:
+ dropdown_update = gr.Dropdown.update(value=model, visible=True)
+
+ state = None
+ return (
+ state,
+ dropdown_update,
+ gr.Chatbot.update(visible=True),
+ gr.Textbox.update(visible=True),
+ gr.Button.update(visible=True),
+ gr.Row.update(visible=True),
+ gr.Accordion.update(visible=True),
+ )
+
+
+def load_demo(url_params, request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+ return load_demo_single(models, url_params)
+
+
+def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(time.time(), 4),
+ "type": vote_type,
+ "model": model_selector,
+ "state": state.dict(),
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+def upvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"upvote. ip: {request.client.host}")
+ vote_last_response(state, "upvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def downvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"downvote. ip: {request.client.host}")
+ vote_last_response(state, "downvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def flag_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"flag. ip: {request.client.host}")
+ vote_last_response(state, "flag", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def regenerate(state, request: gr.Request):
+ logger.info(f"regenerate. ip: {request.client.host}")
+ state.conv.update_last_message(None)
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+ logger.info(f"clear_history. ip: {request.client.host}")
+ state = None
+ return (state, [], "") + (disable_btn,) * 5
+
+
+def add_text(state, model_selector, text, request: gr.Request):
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+
+ if state is None:
+ state = State(model_selector)
+
+ if len(text) <= 0:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
+
+ if enable_moderation:
+ flagged = violates_moderation(text)
+ if flagged:
+ logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), MODERATION_MSG) + (
+ no_change_btn,
+ ) * 5
+
+ conv = state.conv
+ if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_LEN_LIMIT:
+ logger.info(
+ f"hit conversation length limit. ip: {request.client.host}. text: {text}"
+ )
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + (
+ no_change_btn,
+ ) * 5
+
+ text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
+ conv.append_message(conv.roles[0], text)
+ conv.append_message(conv.roles[1], None)
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+def post_process_code(code):
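+    # Inside fenced code blocks, un-escape "\_" back to "_" (some models emit
+    # markdown-escaped underscores, which corrupts code output).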
+ sep = "\n```"
+ if sep in code:
+ blocks = code.split(sep)
+ if len(blocks) % 2 == 1:
+ for i in range(1, len(blocks), 2):
+ blocks[i] = blocks[i].replace("\\_", "_")
+ code = sep.join(blocks)
+ return code
+
+
+def model_worker_stream_iter(
+ conv,
+ model_name,
+ worker_addr,
+ prompt,
+ temperature,
+ repetition_penalty,
+ top_p,
+ max_new_tokens,
+):
+ # Make requests
+ gen_params = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": temperature,
+ "repetition_penalty": repetition_penalty,
+ "top_p": top_p,
+ "max_new_tokens": max_new_tokens,
+ "stop": conv.stop_str,
+ "stop_token_ids": conv.stop_token_ids,
+ "echo": False,
+ }
+ logger.info(f"==== request ====\n{gen_params}")
+
+ # Stream output
+ response = requests.post(
+ worker_addr + "/worker_generate_stream",
+ headers=headers,
+ json=gen_params,
+ stream=True,
+ timeout=WORKER_API_TIMEOUT,
+ )
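+    # The worker streams NUL-delimited JSON chunks; parse and yield each one.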
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode())
+ yield data
+
+
+def http_bot(state, temperature, top_p, max_new_tokens, request: gr.Request):
+ logger.info(f"http_bot. ip: {request.client.host}")
+ start_tstamp = time.time()
+ temperature = float(temperature)
+ top_p = float(top_p)
+ max_new_tokens = int(max_new_tokens)
+
+ if state.skip_next:
+ # This generate call is skipped due to invalid inputs
+ state.skip_next = False
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+ return
+
+ conv, model_name = state.conv, state.model_name
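+    # Dispatch by model: OpenAI and Anthropic models go through their APIs,
+    # "bard" uses the PaLM chat API, and everything else is served by a local
+    # model worker whose address is obtained from the controller.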
+ if model_name == "gpt-3.5-turbo" or model_name == "gpt-4":
+ prompt = conv.to_openai_api_messages()
+ stream_iter = openai_api_stream_iter(
+ model_name, prompt, temperature, top_p, max_new_tokens
+ )
+ elif model_name in ["claude-v1", "claude-instant-v1"]:
+ prompt = conv.get_prompt()
+ stream_iter = anthropic_api_stream_iter(
+ model_name, prompt, temperature, top_p, max_new_tokens
+ )
+ elif model_name == "bard":
+ # stream_iter = bard_api_stream_iter(state)
+ stream_iter = palm_api_stream_iter(
+ state.palm_chat, conv.messages[-2][1], temperature, top_p, max_new_tokens
+ )
+ else:
+ # Query worker address
+ ret = requests.post(
+ controller_url + "/get_worker_address", json={"model": model_name}
+ )
+ worker_addr = ret.json()["address"]
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+ # No available worker
+ if worker_addr == "":
+ conv.update_last_message(SERVER_ERROR_MSG)
+ yield (
+ state,
+ state.to_gradio_chatbot(),
+ disable_btn,
+ disable_btn,
+ disable_btn,
+ enable_btn,
+ enable_btn,
+ )
+ return
+
+ # Construct prompt
+ if "chatglm" in model_name:
+ prompt = list(list(x) for x in conv.messages[conv.offset :])
+ else:
+ prompt = conv.get_prompt()
+
+ # Construct repetition_penalty
+ if "t5" in model_name:
+ repetition_penalty = 1.2
+ else:
+ repetition_penalty = 1.0
+ stream_iter = model_worker_stream_iter(
+ conv,
+ model_name,
+ worker_addr,
+ prompt,
+ temperature,
+ repetition_penalty,
+ top_p,
+ max_new_tokens,
+ )
+
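+    # Show a "▌" typing cursor in the chat while the response streams in;
+    # it is removed once streaming finishes.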
+ conv.update_last_message("▌")
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+ try:
+ for data in stream_iter:
+ if data["error_code"] == 0:
+ output = data["text"].strip()
+ if "vicuna" in model_name:
+ output = post_process_code(output)
+ conv.update_last_message(output + "▌")
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+ else:
+ output = data["text"] + f"\n\n(error_code: {data['error_code']})"
+ conv.update_last_message(output)
+ yield (state, state.to_gradio_chatbot()) + (
+ disable_btn,
+ disable_btn,
+ disable_btn,
+ enable_btn,
+ enable_btn,
+ )
+ return
+ time.sleep(0.02)
+ except requests.exceptions.RequestException as e:
+ conv.update_last_message(
+ f"{SERVER_ERROR_MSG}\n\n"
+ f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
+ )
+ yield (state, state.to_gradio_chatbot()) + (
+ disable_btn,
+ disable_btn,
+ disable_btn,
+ enable_btn,
+ enable_btn,
+ )
+ return
+ except Exception as e:
+ conv.update_last_message(
+ f"{SERVER_ERROR_MSG}\n\n"
+ f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
+ )
+ yield (state, state.to_gradio_chatbot()) + (
+ disable_btn,
+ disable_btn,
+ disable_btn,
+ enable_btn,
+ enable_btn,
+ )
+ return
+
+ # Delete "▌"
+ conv.update_last_message(conv.messages[-1][-1][:-1])
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+ finish_tstamp = time.time()
+ logger.info(f"{output}")
+
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(finish_tstamp, 4),
+ "type": "chat",
+ "model": model_name,
+ "gen_params": {
+ "temperature": temperature,
+ "top_p": top_p,
+ "max_new_tokens": max_new_tokens,
+ },
+ "start": round(start_tstamp, 4),
+ "finish": round(finish_tstamp, 4),
+ "state": state.dict(),
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+block_css = (
+ code_highlight_css
+ + """
+pre {
+ white-space: pre-wrap; /* Since CSS 2.1 */
+ white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
+ white-space: -pre-wrap; /* Opera 4-6 */
+ white-space: -o-pre-wrap; /* Opera 7 */
+ word-wrap: break-word; /* Internet Explorer 5.5+ */
+}
+#notice_markdown th {
+ display: none;
+}
+"""
+)
+
+
+def get_model_description_md(models):
+ model_description_md = """
+| | | |
+| ---- | ---- | ---- |
+"""
+ ct = 0
+ visited = set()
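+    # Lay the descriptions out as a three-column markdown table; models that
+    # share a simple_name with one already shown are listed only once.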
+ for i, name in enumerate(models):
+ if name in model_info:
+ minfo = model_info[name]
+ if minfo.simple_name in visited:
+ continue
+ visited.add(minfo.simple_name)
+ one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
+ else:
+ visited.add(name)
+ one_model_md = (
+ f"[{name}](): Add the description at fastchat/model/model_registry.py"
+ )
+
+ if ct % 3 == 0:
+ model_description_md += "|"
+ model_description_md += f" {one_model_md} |"
+ if ct % 3 == 2:
+ model_description_md += "\n"
+ ct += 1
+ return model_description_md
+
+
+def build_single_model_ui(models):
+ notice_markdown = """
+# 🏔️ Chat with Open Large Language Models
+- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog post]](https://lmsys.org/blog/2023-03-30-vicuna/)
+- Koala: A Dialogue Model for Academic Research. [[Blog post]](https://bair.berkeley.edu/blog/2023/04/03/koala/)
+- [[GitHub]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/KjdtsE9V)
+
+### Terms of use
+By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. **The service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) license.**
+
+### Choose a model to chat with
+"""
+
+ state = gr.State()
+ model_description_md = get_model_description_md(models)
+ gr.Markdown(notice_markdown + model_description_md, elem_id="notice_markdown")
+
+ with gr.Row(elem_id="model_selector_row"):
+ model_selector = gr.Dropdown(
+ choices=models,
+ value=models[0] if len(models) > 0 else "",
+ interactive=True,
+ show_label=False,
+ ).style(container=False)
+
+ chatbot = grChatbot(
+ elem_id="chatbot", label="Scroll down and start chatting", visible=False
+ ).style(height=550)
+ with gr.Row():
+ with gr.Column(scale=20):
+ textbox = gr.Textbox(
+ show_label=False,
+ placeholder="Enter text and press ENTER",
+ visible=False,
+ ).style(container=False)
+ with gr.Column(scale=1, min_width=50):
+ send_btn = gr.Button(value="Send", visible=False)
+
+ with gr.Row(visible=False) as button_row:
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
+
+ with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
+ temperature = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.7,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+ top_p = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Top P",
+ )
+ max_output_tokens = gr.Slider(
+ minimum=16,
+ maximum=1024,
+ value=512,
+ step=64,
+ interactive=True,
+ label="Max output tokens",
+ )
+
+ gr.Markdown(learn_more_md)
+
+ # Register listeners
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+ upvote_btn.click(
+ upvote_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn],
+ )
+ downvote_btn.click(
+ downvote_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn],
+ )
+ flag_btn.click(
+ flag_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn],
+ )
+ regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
+ http_bot,
+ [state, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list,
+ )
+ clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+ model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+ textbox.submit(
+ add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
+ ).then(
+ http_bot,
+ [state, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list,
+ )
+ send_btn.click(
+ add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
+ ).then(
+ http_bot,
+ [state, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list,
+ )
+
+ return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row
+
+
+def build_demo(models):
+ with gr.Blocks(
+ title="Chat with Open Large Language Models",
+ theme=gr.themes.Base(),
+ css=block_css,
+ ) as demo:
+ url_params = gr.JSON(visible=False)
+
+ (
+ state,
+ model_selector,
+ chatbot,
+ textbox,
+ send_btn,
+ button_row,
+ parameter_row,
+ ) = build_single_model_ui(models)
+
+ if args.model_list_mode == "once":
+ demo.load(
+ load_demo,
+ [url_params],
+ [
+ state,
+ model_selector,
+ chatbot,
+ textbox,
+ send_btn,
+ button_row,
+ parameter_row,
+ ],
+ _js=get_window_url_params_js,
+ )
+ elif args.model_list_mode == "reload":
+ demo.load(
+ load_demo_reload_model,
+ [url_params],
+ [
+ state,
+ model_selector,
+ chatbot,
+ textbox,
+ send_btn,
+ button_row,
+ parameter_row,
+ ],
+ _js=get_window_url_params_js,
+ )
+ else:
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+ parser.add_argument("--concurrency-count", type=int, default=10)
+ parser.add_argument(
+ "--model-list-mode",
+ type=str,
+ default="once",
+ choices=["once", "reload"],
+ help="Whether to load the model list once or reload the model list every time.",
+ )
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument(
+ "--moderate", action="store_true", help="Enable content moderation"
+ )
+ parser.add_argument(
+ "--add-chatgpt",
+ action="store_true",
+ help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)",
+ )
+ parser.add_argument(
+ "--add-claude",
+ action="store_true",
+ help="Add Anthropic's Claude models (claude-v1, claude-instant-v1)",
+ )
+ parser.add_argument(
+ "--add-bard",
+ action="store_true",
+ help="Add Google's Bard model (PaLM 2 for Chat: chat-bison@001)",
+ )
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ set_global_vars(args.controller_url, args.moderate)
+ models = get_model_list(args.controller_url)
+
+ if args.add_chatgpt:
+ models = ["gpt-3.5-turbo", "gpt-4"] + models
+ if args.add_claude:
+ models = ["claude-v1", "claude-instant-v1"] + models
+ if args.add_bard:
+ models = ["bard"] + models
+
+ demo = build_demo(models)
+ demo.queue(
+ concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+ ).launch(
+ server_name=args.host, server_port=args.port, share=args.share, max_threads=200
+ )
diff --git a/graphgpt/serve/gradio_web_server_graph.py b/graphgpt/serve/gradio_web_server_graph.py
new file mode 100644
index 0000000..e3e1004
--- /dev/null
+++ b/graphgpt/serve/gradio_web_server_graph.py
@@ -0,0 +1,420 @@
+import argparse
+import datetime
+import json
+import os
+import time
+
+import gradio as gr
+import requests
+
+from graphgpt.conversation import (default_conversation, conv_templates,
+ SeparatorStyle)
+from graphgpt.constants import LOGDIR
+from graphgpt.utils import (build_logger, server_error_msg,
+ violates_moderation, moderation_msg)
+import hashlib
+
+
+logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+headers = {"User-Agent": "GraphGPT Client"}
+
+no_change_btn = gr.Button.update()
+enable_btn = gr.Button.update(interactive=True)
+disable_btn = gr.Button.update(interactive=False)
+
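+# Lexicographic sort keys that float these models to the top of the model list.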
+priority = {
+ "vicuna-13b": "aaaaaaa",
+ "koala-13b": "aaaaaab",
+}
+
+
+def get_conv_log_filename():
+ t = datetime.datetime.now()
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+ return name
+
+
+def get_model_list():
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
+ assert ret.status_code == 200
+ ret = requests.post(args.controller_url + "/list_models")
+ models = ret.json()["models"]
+ models.sort(key=lambda x: priority.get(x, x))
+ logger.info(f"Models: {models}")
+ return models
+
+
+get_window_url_params = """
+function() {
+ const params = new URLSearchParams(window.location.search);
+ url_params = Object.fromEntries(params);
+ console.log(url_params);
+ return url_params;
+ }
+"""
+
+
+def load_demo(url_params, request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+
+ dropdown_update = gr.Dropdown.update(visible=True)
+ if "model" in url_params:
+ model = url_params["model"]
+ if model in models:
+ dropdown_update = gr.Dropdown.update(
+ value=model, visible=True)
+
+ state = default_conversation.copy()
+ return state, dropdown_update
+
+
+def load_demo_refresh_model_list(request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}")
+ models = get_model_list()
+ state = default_conversation.copy()
+ dropdown_update = gr.Dropdown.update(
+ choices=models,
+ value=models[0] if len(models) > 0 else ""
+ )
+ return state, dropdown_update
+
+
+def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(time.time(), 4),
+ "type": vote_type,
+ "model": model_selector,
+ "state": state.dict(),
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+def upvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"upvote. ip: {request.client.host}")
+ vote_last_response(state, "upvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def downvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"downvote. ip: {request.client.host}")
+ vote_last_response(state, "downvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def flag_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"flag. ip: {request.client.host}")
+ vote_last_response(state, "flag", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def regenerate(state, image_process_mode, request: gr.Request):
+ logger.info(f"regenerate. ip: {request.client.host}")
+ state.messages[-1][-1] = None
+ prev_human_msg = state.messages[-2]
+ if type(prev_human_msg[1]) in (tuple, list):
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+ state.skip_next = False
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+ logger.info(f"clear_history. ip: {request.client.host}")
+ state = default_conversation.copy()
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def add_text(state, text, image, image_process_mode, request: gr.Request):
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+ if len(text) <= 0 and image is None:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
+ if args.moderate:
+ flagged = violates_moderation(text)
+ if flagged:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
+ no_change_btn,) * 5
+
+ text = text[:1536] # Hard cut-off
+ if image is not None:
+ text = text[:1200] # Hard cut-off for images
+        # Ensure the prompt carries the image placeholder token (LLaVA-style "<image>").
+        if '<image>' not in text:
+            # text = '<image>' + text
+            text = text + '\n<image>'
+ text = (text, image, image_process_mode)
+ if len(state.get_images(return_pil=True)) > 0:
+ state = default_conversation.copy()
+ state.append_message(state.roles[0], text)
+ state.append_message(state.roles[1], None)
+ state.skip_next = False
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
+ logger.info(f"http_bot. ip: {request.client.host}")
+ start_tstamp = time.time()
+ model_name = model_selector
+
+ if state.skip_next:
+ # This generate call is skipped due to invalid inputs
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+ return
+
+ if len(state.messages) == state.offset + 2:
+ # First round of conversation
+ if "llava" in model_name.lower():
+ if 'llama-2' in model_name.lower():
+ template_name = "llava_llama_2"
+ elif "v1" in model_name.lower():
+ if 'mmtag' in model_name.lower():
+ template_name = "v1_mmtag"
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+ template_name = "v1_mmtag"
+ else:
+ template_name = "llava_v1"
+ elif "mpt" in model_name.lower():
+ template_name = "mpt"
+ else:
+ if 'mmtag' in model_name.lower():
+ template_name = "v0_mmtag"
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+ template_name = "v0_mmtag"
+ else:
+ template_name = "llava_v0"
+ elif "mpt" in model_name:
+ template_name = "mpt_text"
+ elif "llama-2" in model_name:
+ template_name = "llama_2"
+ else:
+ template_name = "vicuna_v1"
+ new_state = conv_templates[template_name].copy()
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
+ new_state.append_message(new_state.roles[1], None)
+ state = new_state
+
+ # Query worker address
+ controller_url = args.controller_url
+ ret = requests.post(controller_url + "/get_worker_address",
+ json={"model": model_name})
+ worker_addr = ret.json()["address"]
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+ # No available worker
+ if worker_addr == "":
+ state.messages[-1][-1] = server_error_msg
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+
+ # Construct prompt
+ prompt = state.get_prompt()
+
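+    # Save each attached image (named by its MD5 hash) under
+    # LOGDIR/serve_images/<date>/ so the conversation log can refer to it by hash.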
+ all_images = state.get_images(return_pil=True)
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+ for image, hash in zip(all_images, all_image_hash):
+ t = datetime.datetime.now()
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
+ if not os.path.isfile(filename):
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ image.save(filename)
+
+ # Make requests
+ pload = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": float(temperature),
+ "top_p": float(top_p),
+ "max_new_tokens": min(int(max_new_tokens), 1536),
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
+ }
+ logger.info(f"==== request ====\n{pload}")
+
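+    # The log above records only an image-count summary; attach the real
+    # image payload before sending the request to the worker.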
+ pload['images'] = state.get_images()
+
+ state.messages[-1][-1] = "▌"
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+ try:
+ # Stream output
+ response = requests.post(worker_addr + "/worker_generate_stream",
+ headers=headers, json=pload, stream=True, timeout=10)
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode())
+ if data["error_code"] == 0:
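+                    # The worker echoes the prompt; keep only the newly generated text.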
+ output = data["text"][len(prompt):].strip()
+ state.messages[-1][-1] = output + "▌"
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+ else:
+ output = data["text"] + f" (error_code: {data['error_code']})"
+ state.messages[-1][-1] = output
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+ time.sleep(0.03)
+ except requests.exceptions.RequestException as e:
+ state.messages[-1][-1] = server_error_msg
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+
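+    # Strip the trailing "▌" cursor now that streaming has finished.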
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+ finish_tstamp = time.time()
+ logger.info(f"{output}")
+
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(finish_tstamp, 4),
+ "type": "chat",
+ "model": model_name,
+ "start": round(start_tstamp, 4),
+ "finish": round(start_tstamp, 4),
+ "state": state.dict(),
+ "images": all_image_hash,
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+title_markdown = ("""
+# 🌋 LLaVA: Large Language and Vision Assistant
+[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
+""")
+
+tos_markdown = ("""
+### Terms of use
+By using this service, users are required to agree to the following terms:
+The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
+Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
+For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
+""")
+
+
+learn_more_markdown = ("""
+### License
+The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
+""")
+
+block_css = """
+
+#buttons button {
+ min-width: min(120px,100%);
+}
+
+"""
+
+def build_demo(embed_mode):
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
+ with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
+ state = gr.State()
+
+ if not embed_mode:
+ gr.Markdown(title_markdown)
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ with gr.Row(elem_id="model_selector_row"):
+ model_selector = gr.Dropdown(
+ choices=models,
+ value=models[0] if len(models) > 0 else "",
+ interactive=True,
+ show_label=False,
+ container=False)
+
+ imagebox = gr.Image(type="pil")
+ image_process_mode = gr.Radio(
+ ["Crop", "Resize", "Pad", "Default"],
+ value="Default",
+ label="Preprocess for non-square image", visible=False)
+
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
+ gr.Examples(examples=[
+ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
+ [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
+ ], inputs=[imagebox, textbox])
+
+ with gr.Accordion("Parameters", open=False) as parameter_row:
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
+
+ with gr.Column(scale=8):
+ chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
+ with gr.Row():
+ with gr.Column(scale=8):
+ textbox.render()
+ with gr.Column(scale=1, min_width=50):
+ submit_btn = gr.Button(value="Send", variant="primary")
+ with gr.Row(elem_id="buttons") as button_row:
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+
+ if not embed_mode:
+ gr.Markdown(tos_markdown)
+ gr.Markdown(learn_more_markdown)
+ url_params = gr.JSON(visible=False)
+
+ # Register listeners
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+ upvote_btn.click(upvote_last_response,
+ [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
+ downvote_btn.click(downvote_last_response,
+ [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
+ flag_btn.click(flag_last_response,
+ [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
+ regenerate_btn.click(regenerate, [state, image_process_mode],
+ [state, chatbot, textbox, imagebox] + btn_list).then(
+ http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list)
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
+
+ textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
+ ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list)
+ submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
+ ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list)
+
+ if args.model_list_mode == "once":
+ demo.load(load_demo, [url_params], [state, model_selector],
+ _js=get_window_url_params)
+ elif args.model_list_mode == "reload":
+ demo.load(load_demo_refresh_model_list, None, [state, model_selector])
+ else:
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+ parser.add_argument("--concurrency-count", type=int, default=10)
+ parser.add_argument("--model-list-mode", type=str, default="once",
+ choices=["once", "reload"])
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument("--moderate", action="store_true")
+ parser.add_argument("--embed", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ models = get_model_list()
+
+ logger.info(args)
+ demo = build_demo(args.embed)
+ demo.queue(
+ concurrency_count=args.concurrency_count,
+ api_open=False
+ ).launch(
+ server_name=args.host,
+ server_port=args.port,
+ share=args.share
+ )
\ No newline at end of file
diff --git a/graphgpt/serve/gradio_web_server_multi.py b/graphgpt/serve/gradio_web_server_multi.py
new file mode 100644
index 0000000..3be9243
--- /dev/null
+++ b/graphgpt/serve/gradio_web_server_multi.py
@@ -0,0 +1,219 @@
+"""
+The gradio demo server with multiple tabs.
+It supports chatting with a single model or chatting with two models side-by-side.
+"""
+
+import argparse
+import pickle
+
+import gradio as gr
+
+from fastchat.serve.gradio_block_arena_anony import (
+ build_side_by_side_ui_anony,
+ load_demo_side_by_side_anony,
+ set_global_vars_anony,
+)
+from fastchat.serve.gradio_block_arena_named import (
+ build_side_by_side_ui_named,
+ load_demo_side_by_side_named,
+ set_global_vars_named,
+)
+from fastchat.serve.gradio_patch import Chatbot as grChatbot
+from fastchat.serve.gradio_web_server import (
+ set_global_vars,
+ block_css,
+ build_single_model_ui,
+ get_model_list,
+ load_demo_single,
+)
+from fastchat.serve.monitor.monitor import build_leaderboard_tab
+from fastchat.utils import build_logger, get_window_url_params_js
+
+
+logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
+
+
+def load_demo(url_params, request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+ selected = 0
+ if "arena" in url_params:
+ selected = 1
+ elif "compare" in url_params:
+ selected = 2
+ elif "leaderboard" in url_params:
+ selected = 3
+ single_updates = load_demo_single(models, url_params)
+
+ models_anony = models
+ if args.anony_only_for_proprietary_model:
+ # Only enable these models in anony battles.
+ if args.add_chatgpt:
+ models_anony = ["gpt-4", "gpt-3.5-turbo"] + models_anony
+ if args.add_claude:
+ models_anony = ["claude-v1", "claude-instant-v1"] + models_anony
+ if args.add_bard:
+ models_anony = ["bard"] + models_anony
+
+ side_by_side_anony_updates = load_demo_side_by_side_anony(models_anony, url_params)
+ side_by_side_named_updates = load_demo_side_by_side_named(models, url_params)
+ return (
+ (gr.Tabs.update(selected=selected),)
+ + single_updates
+ + side_by_side_anony_updates
+ + side_by_side_named_updates
+ )
+
+
+def build_demo(models, elo_results_file):
+ with gr.Blocks(
+ title="Chat with Open Large Language Models",
+ theme=gr.themes.Base(),
+ css=block_css,
+ ) as demo:
+ with gr.Tabs() as tabs:
+ with gr.Tab("Single Model", id=0):
+ (
+ a_state,
+ a_model_selector,
+ a_chatbot,
+ a_textbox,
+ a_send_btn,
+ a_button_row,
+ a_parameter_row,
+ ) = build_single_model_ui(models)
+ a_list = [
+ a_state,
+ a_model_selector,
+ a_chatbot,
+ a_textbox,
+ a_send_btn,
+ a_button_row,
+ a_parameter_row,
+ ]
+
+ with gr.Tab("Chatbot Arena (battle)", id=1):
+ (
+ b_states,
+ b_model_selectors,
+ b_chatbots,
+ b_textbox,
+ b_send_btn,
+ b_button_row,
+ b_button_row2,
+ b_parameter_row,
+ ) = build_side_by_side_ui_anony(models)
+ b_list = (
+ b_states
+ + b_model_selectors
+ + b_chatbots
+ + [
+ b_textbox,
+ b_send_btn,
+ b_button_row,
+ b_button_row2,
+ b_parameter_row,
+ ]
+ )
+
+ with gr.Tab("Chatbot Arena (side-by-side)", id=2):
+ (
+ c_states,
+ c_model_selectors,
+ c_chatbots,
+ c_textbox,
+ c_send_btn,
+ c_button_row,
+ c_button_row2,
+ c_parameter_row,
+ ) = build_side_by_side_ui_named(models)
+ c_list = (
+ c_states
+ + c_model_selectors
+ + c_chatbots
+ + [
+ c_textbox,
+ c_send_btn,
+ c_button_row,
+ c_button_row2,
+ c_parameter_row,
+ ]
+ )
+
+ if elo_results_file:
+ with gr.Tab("Leaderboard", id=3):
+ build_leaderboard_tab(elo_results_file)
+
+ url_params = gr.JSON(visible=False)
+
+ if args.model_list_mode == "once":
+ demo.load(
+ load_demo,
+ [url_params],
+ [tabs] + a_list + b_list + c_list,
+ _js=get_window_url_params_js,
+ )
+ else:
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+ parser.add_argument("--concurrency-count", type=int, default=10)
+ parser.add_argument(
+ "--model-list-mode",
+ type=str,
+ default="once",
+ choices=["once"],
+ )
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument(
+ "--moderate", action="store_true", help="Enable content moderation"
+ )
+ parser.add_argument(
+ "--add-chatgpt",
+ action="store_true",
+ help="Add OpenAI ChatGPT models (gpt-3.5-turbo, gpt-4)",
+ )
+ parser.add_argument(
+ "--add-claude",
+ action="store_true",
+ help="Add Anthropic's Claude models (claude-v1, claude-instant-v1)",
+ )
+ parser.add_argument(
+ "--add-bard",
+ action="store_true",
+ help="Add Google's Bard model (PaLM 2 for Chat: chat-bison@001)",
+ )
+ parser.add_argument(
+ "--anony-only-for-proprietary-model",
+ action="store_true",
+ help="Only add ChatGPT, Claude, Bard under anony battle tab",
+ )
+ parser.add_argument("--elo-results-file", type=str)
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ set_global_vars(args.controller_url, args.moderate)
+ set_global_vars_named(args.moderate)
+ set_global_vars_anony(args.moderate)
+ models = get_model_list(args.controller_url)
+
+ if not args.anony_only_for_proprietary_model:
+ if args.add_chatgpt:
+ models = ["gpt-3.5-turbo", "gpt-4"] + models
+ if args.add_claude:
+ models = ["claude-v1", "claude-instant-v1"] + models
+ if args.add_bard:
+ models = ["bard"] + models
+
+ demo = build_demo(models, args.elo_results_file)
+ demo.queue(
+ concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+ ).launch(
+ server_name=args.host, server_port=args.port, share=args.share, max_threads=200
+ )
diff --git a/graphgpt/serve/huggingface_api.py b/graphgpt/serve/huggingface_api.py
new file mode 100644
index 0000000..9223f2e
--- /dev/null
+++ b/graphgpt/serve/huggingface_api.py
@@ -0,0 +1,69 @@
+"""
+Use FastChat with Hugging Face generation APIs.
+
+Usage:
+python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0
+python3 -m fastchat.serve.huggingface_api --model ~/model_weights/vicuna-7b/
+"""
+import argparse
+import json
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from fastchat.model import load_model, get_conversation_template, add_model_args
+
+
+@torch.inference_mode()
+def main(args):
+ model, tokenizer = load_model(
+ args.model_path,
+ args.device,
+ args.num_gpus,
+ args.max_gpu_memory,
+ args.load_8bit,
+ args.cpu_offloading,
+ debug=args.debug,
+ )
+
+ msg = args.message
+
+ conv = get_conversation_template(args.model_path)
+ conv.append_message(conv.roles[0], msg)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ input_ids = tokenizer([prompt]).input_ids
+
+ if "t5" in args.model_path and args.repetition_penalty == 1.0:
+ args.repetition_penalty = 1.2
+ output_ids = model.generate(
+ torch.as_tensor(input_ids).cuda(),
+ do_sample=True,
+ temperature=args.temperature,
+ repetition_penalty=args.repetition_penalty,
+ max_new_tokens=args.max_new_tokens,
+ )
+ if model.config.is_encoder_decoder:
+ output_ids = output_ids[0]
+ else:
+ output_ids = output_ids[0][len(input_ids[0]) :]
+ outputs = tokenizer.decode(
+ output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
+ )
+
+ print(f"{conv.roles[0]}: {msg}")
+ print(f"{conv.roles[1]}: {outputs}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ add_model_args(parser)
+ parser.add_argument("--temperature", type=float, default=0.7)
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--message", type=str, default="Hello! Who are you?")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/graphgpt/serve/inference.py b/graphgpt/serve/inference.py
new file mode 100644
index 0000000..489b379
--- /dev/null
+++ b/graphgpt/serve/inference.py
@@ -0,0 +1,314 @@
+"""Inference for FastChat models."""
+import abc
+import gc
+import math
+from typing import Iterable, Optional
+import sys
+import warnings
+
+import psutil
+import torch
+from transformers import (
+ AutoTokenizer,
+ AutoModelForCausalLM,
+ LlamaTokenizer,
+ LlamaForCausalLM,
+ AutoModel,
+ AutoModelForSeq2SeqLM,
+ T5Tokenizer,
+ AutoConfig,
+)
+from transformers.generation.logits_process import (
+ LogitsProcessorList,
+ RepetitionPenaltyLogitsProcessor,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+)
+
+from fastchat.conversation import get_conv_template, SeparatorStyle
+from fastchat.model.model_adapter import load_model, get_conversation_template
+from fastchat.model.chatglm_model import chatglm_generate_stream
+
+
+def prepare_logits_processor(
+ temperature: float, repetition_penalty: float, top_p: float, top_k: int
+) -> LogitsProcessorList:
+ processor_list = LogitsProcessorList()
+    # TemperatureLogitsWarper doesn't accept 0.0, and 1.0 makes it a no-op, so we skip both cases.
+ if temperature >= 1e-5 and temperature != 1.0:
+ processor_list.append(TemperatureLogitsWarper(temperature))
+ if repetition_penalty > 1.0:
+ processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
+ if 1e-8 <= top_p < 1.0:
+ processor_list.append(TopPLogitsWarper(top_p))
+ if top_k > 0:
+ processor_list.append(TopKLogitsWarper(top_k))
+ return processor_list
+
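
A quick sketch of how this helper is consumed further down in `generate_stream`: the returned `LogitsProcessorList` is called on `(input_ids, scores)` and hands back warped scores. This is illustrative only; it assumes the module is importable as `graphgpt.serve.inference` (with its fastchat dependencies installed), and the shapes and sampling parameters are arbitrary.

```python
import torch
from graphgpt.serve.inference import prepare_logits_processor

# Illustrative only: batch of 1, 32000-token vocabulary.
processors = prepare_logits_processor(
    temperature=0.7, repetition_penalty=1.2, top_p=0.9, top_k=50
)
dummy_ids = torch.tensor([[1, 2, 3]])        # token ids generated so far
dummy_logits = torch.randn(1, 32000)         # raw logits for the next token
warped = processors(dummy_ids, dummy_logits) # same shape; filtered tokens pushed to -inf
print(warped.shape)
```
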
+
+def partial_stop(output, stop_str):
+ for i in range(0, min(len(output), len(stop_str))):
+ if stop_str.startswith(output[-i:]):
+ return True
+ return False
+
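
For clarity, `partial_stop` reports whether the current output ends with a prefix of the stop string; this is what lets the loop below hold back a chunk instead of leaking half of a stop sequence. A toy check, assuming the function above is in scope (the strings are made up):

```python
# The tail "### Hu" is a prefix of "### Human:", so this chunk should be held back.
assert partial_stop("I am a chatbot.\n### Hu", "### Human:")
# No suffix of this output matches a prefix of the stop string.
assert not partial_stop("I am a chatbot.", "### Human:")
```
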
+
+@torch.inference_mode()
+def generate_stream(
+ model, tokenizer, params, device, context_len=2048, stream_interval=2
+):
+ prompt = params["prompt"]
+ len_prompt = len(prompt)
+ temperature = float(params.get("temperature", 1.0))
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ top_k = int(params.get("top_k", -1)) # -1 means disable
+ max_new_tokens = int(params.get("max_new_tokens", 256))
+ stop_str = params.get("stop", None)
+ echo = bool(params.get("echo", True))
+ stop_token_ids = params.get("stop_token_ids", None) or []
+ stop_token_ids.append(tokenizer.eos_token_id)
+
+ logits_processor = prepare_logits_processor(
+ temperature, repetition_penalty, top_p, top_k
+ )
+
+ input_ids = tokenizer(prompt).input_ids
+ input_echo_len = len(input_ids)
+ output_ids = list(input_ids)
+
+ if model.config.is_encoder_decoder:
+ max_src_len = context_len
+ else:
+ max_src_len = context_len - max_new_tokens - 8
+
+ input_ids = input_ids[-max_src_len:]
+
+ if model.config.is_encoder_decoder:
+ encoder_output = model.encoder(
+ input_ids=torch.as_tensor([input_ids], device=device)
+ )[0]
+ start_ids = torch.as_tensor(
+ [[model.generation_config.decoder_start_token_id]],
+ dtype=torch.int64,
+ device=device,
+ )
+
+ past_key_values = out = None
+ for i in range(max_new_tokens):
+ if i == 0:
+ if model.config.is_encoder_decoder:
+ out = model.decoder(
+ input_ids=start_ids,
+ encoder_hidden_states=encoder_output,
+ use_cache=True,
+ )
+ logits = model.lm_head(out[0])
+ else:
+ out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
+ logits = out.logits
+ past_key_values = out.past_key_values
+ else:
+ if model.config.is_encoder_decoder:
+ out = model.decoder(
+ input_ids=torch.as_tensor([[token]], device=device),
+ encoder_hidden_states=encoder_output,
+ use_cache=True,
+ past_key_values=past_key_values,
+ )
+
+ logits = model.lm_head(out[0])
+ else:
+ out = model(
+ input_ids=torch.as_tensor([[token]], device=device),
+ use_cache=True,
+ past_key_values=past_key_values,
+ )
+ logits = out.logits
+ past_key_values = out.past_key_values
+
+ if logits_processor:
+ if repetition_penalty > 1.0:
+ tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
+ else:
+ tmp_output_ids = None
+ last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
+ else:
+ last_token_logits = logits[0, -1, :]
+
+ if device == "mps":
+            # Move to CPU to work around bugs in the mps backend.
+ last_token_logits = last_token_logits.float().to("cpu")
+
+ if temperature < 1e-5 or top_p < 1e-8: # greedy
+ token = int(torch.argmax(last_token_logits))
+ else:
+ probs = torch.softmax(last_token_logits, dim=-1)
+ token = int(torch.multinomial(probs, num_samples=1))
+
+ output_ids.append(token)
+
+ if token in stop_token_ids:
+ stopped = True
+ else:
+ stopped = False
+
+ if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
+ if echo:
+ tmp_output_ids = output_ids
+ rfind_start = len_prompt
+ else:
+ tmp_output_ids = output_ids[input_echo_len:]
+ rfind_start = 0
+
+ output = tokenizer.decode(
+ tmp_output_ids,
+ skip_special_tokens=True,
+ spaces_between_special_tokens=False,
+ )
+
+ partially_stopped = False
+ if stop_str:
+ if isinstance(stop_str, str):
+ pos = output.rfind(stop_str, rfind_start)
+ if pos != -1:
+ output = output[:pos]
+ stopped = True
+ else:
+ partially_stopped = partial_stop(output, stop_str)
+ elif isinstance(stop_str, Iterable):
+ for each_stop in stop_str:
+ pos = output.rfind(each_stop, rfind_start)
+ if pos != -1:
+ output = output[:pos]
+ stopped = True
+ break
+ else:
+ partially_stopped = partial_stop(output, each_stop)
+ if partially_stopped:
+ break
+ else:
+ raise ValueError("Invalid stop field type.")
+
+ # prevent yielding partial stop sequence
+ if not partially_stopped:
+ yield {
+ "text": output,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": None,
+ }
+
+ if stopped:
+ break
+
+ # finish stream event, which contains finish reason
+ if i == max_new_tokens - 1:
+ finish_reason = "length"
+ elif stopped:
+ finish_reason = "stop"
+ else:
+ finish_reason = None
+
+ yield {
+ "text": output,
+ "usage": {
+ "prompt_tokens": input_echo_len,
+ "completion_tokens": i,
+ "total_tokens": input_echo_len + i,
+ },
+ "finish_reason": finish_reason,
+ }
+
+ # clean
+ del past_key_values, out
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
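
A minimal consumer sketch for the generator above, assuming a model and tokenizer have already been obtained via `load_model` and live on the given device (prompt, parameters, and device are placeholders, not values from this repo):

```python
# Hypothetical usage: each yielded dict carries the decoded text so far plus usage counters.
params = {
    "prompt": "USER: What is a graph neural network? ASSISTANT:",
    "temperature": 0.7,
    "max_new_tokens": 128,
    "echo": False,
}
final_text = ""
for chunk in generate_stream(model, tokenizer, params, device="cuda"):
    final_text = chunk["text"]
print(final_text)
```
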
+class ChatIO(abc.ABC):
+ @abc.abstractmethod
+ def prompt_for_input(self, role: str) -> str:
+ """Prompt for input from a role."""
+
+ @abc.abstractmethod
+ def prompt_for_output(self, role: str):
+ """Prompt for output from a role."""
+
+ @abc.abstractmethod
+ def stream_output(self, output_stream):
+ """Stream output."""
+
+
+def chat_loop(
+ model_path: str,
+ device: str,
+ num_gpus: int,
+ max_gpu_memory: str,
+ load_8bit: bool,
+ cpu_offloading: bool,
+ conv_template: Optional[str],
+ temperature: float,
+ repetition_penalty: float,
+ max_new_tokens: int,
+ chatio: ChatIO,
+ debug: bool,
+):
+ # Model
+ model, tokenizer = load_model(
+ model_path, device, num_gpus, max_gpu_memory, load_8bit, cpu_offloading, debug
+ )
+ is_chatglm = "chatglm" in str(type(model)).lower()
+ is_fastchat_t5 = "t5" in str(type(model)).lower()
+
+ # Hardcode T5 repetition penalty to be 1.2
+ if is_fastchat_t5 and repetition_penalty == 1.0:
+ repetition_penalty = 1.2
+
+ # Chat
+ if conv_template:
+ conv = get_conv_template(conv_template)
+ else:
+ conv = get_conversation_template(model_path)
+
+ while True:
+ try:
+ inp = chatio.prompt_for_input(conv.roles[0])
+ except EOFError:
+ inp = ""
+ if not inp:
+ print("exit...")
+ break
+
+ conv.append_message(conv.roles[0], inp)
+ conv.append_message(conv.roles[1], None)
+
+ if is_chatglm:
+ generate_stream_func = chatglm_generate_stream
+ prompt = conv.messages[conv.offset :]
+ else:
+ generate_stream_func = generate_stream
+ prompt = conv.get_prompt()
+
+ gen_params = {
+ "model": model_path,
+ "prompt": prompt,
+ "temperature": temperature,
+ "repetition_penalty": repetition_penalty,
+ "max_new_tokens": max_new_tokens,
+ "stop": conv.stop_str,
+ "stop_token_ids": conv.stop_token_ids,
+ "echo": False,
+ }
+
+ chatio.prompt_for_output(conv.roles[1])
+ output_stream = generate_stream_func(model, tokenizer, gen_params, device)
+ outputs = chatio.stream_output(output_stream)
+ conv.update_last_message(outputs.strip())
+
+ if debug:
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
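
Since `chat_loop` stores whatever `stream_output` returns as the assistant turn, a concrete `ChatIO` has to consume the stream and hand back the final text. A minimal terminal-based sketch under that assumption (not part of this diff):

```python
class SimpleChatIO(ChatIO):
    def prompt_for_input(self, role: str) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        # Consume the generator and return the final decoded text,
        # which chat_loop writes back into the conversation.
        text = ""
        for chunk in output_stream:
            text = chunk["text"]
        print(text)
        return text
```
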
diff --git a/graphgpt/serve/model_worker.py b/graphgpt/serve/model_worker.py
new file mode 100644
index 0000000..8f9c2fa
--- /dev/null
+++ b/graphgpt/serve/model_worker.py
@@ -0,0 +1,427 @@
+"""
+A model worker executes the model.
+"""
+import argparse
+import asyncio
+import dataclasses
+import logging
+import json
+import os
+import time
+from typing import List, Union
+import threading
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse, JSONResponse
+import requests
+
+try:
+ from transformers import (
+ AutoTokenizer,
+ AutoModelForCausalLM,
+ LlamaTokenizer,
+ AutoModel,
+ )
+except ImportError:
+ from transformers import (
+ AutoTokenizer,
+ AutoModelForCausalLM,
+ LLaMATokenizer,
+ AutoModel,
+ )
+import torch
+import torch.nn.functional as F
+import uvicorn
+
+from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
+from fastchat.model.model_adapter import load_model, add_model_args
+from fastchat.model.chatglm_model import chatglm_generate_stream
+from fastchat.serve.inference import generate_stream
+from fastchat.utils import build_logger, pretty_print_semaphore
+
+GB = 1 << 30
+
+worker_id = str(uuid.uuid4())[:6]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+global_counter = 0
+
+model_semaphore = None
+
+
+def heart_beat_worker(controller):
+ while True:
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
+ controller.send_heart_beat()
+
+
+class ModelWorker:
+ def __init__(
+ self,
+ controller_addr,
+ worker_addr,
+ worker_id,
+ no_register,
+ model_path,
+ model_name,
+ device,
+ num_gpus,
+ max_gpu_memory,
+ load_8bit=False,
+ cpu_offloading=False,
+ ):
+ self.controller_addr = controller_addr
+ self.worker_addr = worker_addr
+ self.worker_id = worker_id
+ if model_path.endswith("/"):
+ model_path = model_path[:-1]
+ self.model_name = model_name or model_path.split("/")[-1]
+ self.device = device
+
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
+ self.model, self.tokenizer = load_model(
+ model_path, device, num_gpus, max_gpu_memory, load_8bit, cpu_offloading
+ )
+        if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ if hasattr(self.model.config, "max_sequence_length"):
+ self.context_len = self.model.config.max_sequence_length
+ elif hasattr(self.model.config, "max_position_embeddings"):
+ self.context_len = self.model.config.max_position_embeddings
+ else:
+ self.context_len = 2048
+
+ # generate_stream
+ is_chatglm = "chatglm" in str(type(self.model)).lower()
+ if is_chatglm:
+ self.generate_stream_func = chatglm_generate_stream
+ else:
+ self.generate_stream_func = generate_stream
+
+ if not no_register:
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_worker, args=(self,)
+ )
+ self.heart_beat_thread.start()
+
+ def register_to_controller(self):
+ logger.info("Register to controller")
+
+ url = self.controller_addr + "/register_worker"
+ data = {
+ "worker_name": self.worker_addr,
+ "check_heart_beat": True,
+ "worker_status": self.get_status(),
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ def send_heart_beat(self):
+ logger.info(
+ f"Send heart beat. Models: {[self.model_name]}. "
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
+ f"global_counter: {global_counter}"
+ )
+
+ url = self.controller_addr + "/receive_heart_beat"
+
+ while True:
+ try:
+ ret = requests.post(
+ url,
+ json={
+ "worker_name": self.worker_addr,
+ "queue_length": self.get_queue_length(),
+ },
+ timeout=5,
+ )
+ exist = ret.json()["exist"]
+ break
+ except requests.exceptions.RequestException as e:
+ logger.error(f"heart beat error: {e}")
+ time.sleep(5)
+
+ if not exist:
+ self.register_to_controller()
+
+ def get_queue_length(self):
+ if (
+ model_semaphore is None
+ or model_semaphore._value is None
+ or model_semaphore._waiters is None
+ ):
+ return 0
+ else:
+ return (
+ args.limit_model_concurrency
+ - model_semaphore._value
+ + len(model_semaphore._waiters)
+ )
+
+ def get_status(self):
+ return {
+ "model_names": [self.model_name],
+ "speed": 1,
+ "queue_length": self.get_queue_length(),
+ }
+
+ def count_token(self, params):
+ prompt = params["prompt"]
+ input_ids = self.tokenizer(prompt).input_ids
+ input_echo_len = len(input_ids)
+
+ ret = {
+ "count": input_echo_len,
+ "error_code": 0,
+ }
+ return ret
+
+ def generate_stream_gate(self, params):
+ try:
+ for output in self.generate_stream_func(
+ self.model,
+ self.tokenizer,
+ params,
+ self.device,
+ self.context_len,
+ args.stream_interval,
+ ):
+ ret = {
+ "text": output["text"],
+ "error_code": 0,
+ }
+ if "usage" in output:
+ ret["usage"] = output["usage"]
+ if "finish_reason" in output:
+ ret["finish_reason"] = output["finish_reason"]
+ if "logprobs" in output:
+ ret["logprobs"] = output["logprobs"]
+ yield json.dumps(ret).encode() + b"\0"
+ except torch.cuda.OutOfMemoryError as e:
+ ret = {
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+ "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except (ValueError, RuntimeError) as e:
+ ret = {
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+ "error_code": ErrorCode.INTERNAL_ERROR,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+ def generate_gate(self, params):
+ try:
+ ret = {"text": "", "error_code": 0}
+ for output in self.generate_stream_func(
+ self.model,
+ self.tokenizer,
+ params,
+ self.device,
+ self.context_len,
+ args.stream_interval,
+ ):
+ ret["text"] = output["text"]
+ if "usage" in output:
+ ret["usage"] = output["usage"]
+ if "finish_reason" in output:
+ ret["finish_reason"] = output["finish_reason"]
+ if "logprobs" in output:
+ ret["logprobs"] = output["logprobs"]
+ except torch.cuda.OutOfMemoryError as e:
+ ret = {
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+ "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
+ }
+ except (ValueError, RuntimeError) as e:
+ ret = {
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+ "error_code": ErrorCode.INTERNAL_ERROR,
+ }
+ return ret
+
+ @torch.inference_mode()
+ def get_embeddings(self, params):
+ try:
+ tokenizer = self.tokenizer
+            is_llama = "llama" in str(type(self.model))  # vicuna supports batch inference
+ is_chatglm = "chatglm" in str(type(self.model))
+ is_t5 = "t5" in str(type(self.model))
+ if is_llama:
+ encoding = tokenizer.batch_encode_plus(
+ params["input"], padding=True, return_tensors="pt"
+ )
+ input_ids = encoding["input_ids"].to(self.device)
+ attention_mask = encoding["attention_mask"].to(self.device)
+ model_output = self.model(
+ input_ids, attention_mask, output_hidden_states=True
+ )
+ data = model_output.hidden_states[-1]
+ mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
+ masked_embeddings = data * mask
+ sum_embeddings = torch.sum(masked_embeddings, dim=1)
+ seq_length = torch.sum(mask, dim=1)
+ embedding = sum_embeddings / seq_length
+ normalized_embeddings = F.normalize(embedding, p=2, dim=1)
+ ret = {
+ "embedding": normalized_embeddings.tolist(),
+ "token_num": torch.sum(attention_mask).item(),
+ }
+ else:
+ embedding = []
+ token_num = 0
+ for text in params["input"]:
+ input_ids = tokenizer.encode(text, return_tensors="pt").to(
+ self.device
+ )
+ if is_t5:
+ model_output = self.model(input_ids, decoder_input_ids=input_ids)
+ else:
+ model_output = self.model(input_ids, output_hidden_states=True)
+ if is_chatglm:
+ data = (model_output.hidden_states[-1].transpose(0, 1))[0]
+ elif is_t5:
+ data = model_output.encoder_last_hidden_state[0]
+ else:
+ data = model_output.hidden_states[-1][0]
+ data = F.normalize(torch.mean(data, dim=0), p=2, dim=0)
+ embedding.append(data.tolist())
+ token_num += len(input_ids[0])
+ ret = {
+ "embedding": embedding,
+ "token_num": token_num,
+ }
+ except torch.cuda.OutOfMemoryError as e:
+ ret = {
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+ "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
+ }
+ except (ValueError, RuntimeError) as e:
+ ret = {
+ "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+ "error_code": ErrorCode.INTERNAL_ERROR,
+ }
+ return ret
+
+
+app = FastAPI()
+
+
+def release_model_semaphore():
+ model_semaphore.release()
+
+
+def acquire_model_semaphore():
+ global model_semaphore, global_counter
+ global_counter += 1
+ if model_semaphore is None:
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+ return model_semaphore.acquire()
+
+
+def create_background_tasks():
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(release_model_semaphore)
+ return background_tasks
+
+
+@app.post("/worker_generate_stream")
+async def api_generate_stream(request: Request):
+ params = await request.json()
+ await acquire_model_semaphore()
+ generator = worker.generate_stream_gate(params)
+ background_tasks = create_background_tasks()
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_generate")
+async def api_generate(request: Request):
+ params = await request.json()
+ await acquire_model_semaphore()
+ output = worker.generate_gate(params)
+ release_model_semaphore()
+ return JSONResponse(output)
+
+
+@app.post("/worker_generate_completion_stream")
+async def api_generate_completion_stream(request: Request):
+ params = await request.json()
+ await acquire_model_semaphore()
+ generator = worker.generate_stream_gate(params)
+ background_tasks = create_background_tasks()
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_generate_completion")
+async def api_generate_completion(request: Request):
+ params = await request.json()
+ await acquire_model_semaphore()
+ completion = worker.generate_gate(params)
+ background_tasks = create_background_tasks()
+ return JSONResponse(content=completion, background=background_tasks)
+
+
+@app.post("/worker_get_embeddings")
+async def api_get_embeddings(request: Request):
+ params = await request.json()
+ await acquire_model_semaphore()
+ embedding = worker.get_embeddings(params)
+ background_tasks = create_background_tasks()
+ return JSONResponse(content=embedding, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def api_get_status(request: Request):
+ return worker.get_status()
+
+
+@app.post("/count_token")
+async def count_token(request: Request):
+ params = await request.json()
+ return worker.count_token(params)
+
+
+@app.post("/model_details")
+async def model_details(request: Request):
+ return {"context_length": worker.context_len}
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ add_model_args(parser)
+ parser.add_argument("--model-name", type=str, help="Optional display name")
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
+ parser.add_argument("--stream-interval", type=int, default=2)
+ parser.add_argument("--no-register", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ if args.gpus:
+ if len(args.gpus.split(",")) < args.num_gpus:
+ raise ValueError(
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
+ )
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+
+ worker = ModelWorker(
+ args.controller_address,
+ args.worker_address,
+ worker_id,
+ args.no_register,
+ args.model_path,
+ args.model_name,
+ args.device,
+ args.num_gpus,
+ args.max_gpu_memory,
+ args.load_8bit,
+ args.cpu_offloading,
+ )
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
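
For reference, a small client sketch against the endpoints registered above. The worker address and prompt are placeholders; `/worker_generate` returns a single JSON object, while `/worker_generate_stream` returns the null-byte-delimited JSON chunks produced by `generate_stream_gate`.

```python
import json
import requests

worker = "http://localhost:21002"  # placeholder worker address

# Non-streaming generation: a single JSON object with "text" and "error_code".
out = requests.post(f"{worker}/worker_generate",
                    json={"prompt": "Hello!", "max_new_tokens": 32}).json()
print(out["text"])

# Streaming generation: chunks are separated by b"\0".
resp = requests.post(f"{worker}/worker_generate_stream",
                     json={"prompt": "Hello!", "max_new_tokens": 32}, stream=True)
for chunk in resp.iter_lines(delimiter=b"\0"):
    if chunk:
        print(json.loads(chunk.decode())["text"], end="\r")
print()
```
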
diff --git a/graphgpt/serve/model_worker_graph.py b/graphgpt/serve/model_worker_graph.py
new file mode 100644
index 0000000..5e4b237
--- /dev/null
+++ b/graphgpt/serve/model_worker_graph.py
@@ -0,0 +1,285 @@
+"""
+A model worker executes the model.
+"""
+import argparse
+import asyncio
+import json
+import time
+import threading
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+import requests
+import torch
+import uvicorn
+from functools import partial
+
+from graphgpt.constants import WORKER_HEART_BEAT_INTERVAL
+from graphgpt.utils import (build_logger, server_error_msg,
+ pretty_print_semaphore)
+from graphgpt.model.builder import load_pretrained_model
+from graphgpt.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
+from graphgpt.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from transformers import TextIteratorStreamer
+from threading import Thread
+
+
+GB = 1 << 30
+
+worker_id = str(uuid.uuid4())[:6]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+global_counter = 0
+
+model_semaphore = None
+
+
+def heart_beat_worker(controller):
+
+ while True:
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
+ controller.send_heart_beat()
+
+
+class ModelWorker:
+ def __init__(self, controller_addr, worker_addr,
+ worker_id, no_register,
+ model_path, model_base, model_name,
+ load_8bit, load_4bit, device):
+ self.controller_addr = controller_addr
+ self.worker_addr = worker_addr
+ self.worker_id = worker_id
+ if model_path.endswith("/"):
+ model_path = model_path[:-1]
+ if model_name is None:
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith('checkpoint-'):
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
+ else:
+ self.model_name = model_paths[-1]
+ else:
+ self.model_name = model_name
+
+ self.device = device
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
+ model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
+ self.is_multimodal = 'llava' in self.model_name.lower()
+
+ if not no_register:
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_worker, args=(self,))
+ self.heart_beat_thread.start()
+
+ def register_to_controller(self):
+ logger.info("Register to controller")
+
+ url = self.controller_addr + "/register_worker"
+ data = {
+ "worker_name": self.worker_addr,
+ "check_heart_beat": True,
+ "worker_status": self.get_status()
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ def send_heart_beat(self):
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
+ f"global_counter: {global_counter}")
+
+ url = self.controller_addr + "/receive_heart_beat"
+
+ while True:
+ try:
+ ret = requests.post(url, json={
+ "worker_name": self.worker_addr,
+ "queue_length": self.get_queue_length()}, timeout=5)
+ exist = ret.json()["exist"]
+ break
+ except requests.exceptions.RequestException as e:
+ logger.error(f"heart beat error: {e}")
+ time.sleep(5)
+
+ if not exist:
+ self.register_to_controller()
+
+ def get_queue_length(self):
+ if model_semaphore is None:
+ return 0
+ else:
+ return args.limit_model_concurrency - model_semaphore._value + (len(
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+
+ def get_status(self):
+ return {
+ "model_names": [self.model_name],
+ "speed": 1,
+ "queue_length": self.get_queue_length(),
+ }
+
+ @torch.inference_mode()
+ def generate_stream(self, params):
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
+
+ prompt = params["prompt"]
+ ori_prompt = prompt
+ images = params.get("images", None)
+ num_image_tokens = 0
+ if images is not None and len(images) > 0 and self.is_multimodal:
+ if len(images) > 0:
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+ raise ValueError("Number of images does not match number of tokens in prompt")
+
+ images = [load_image_from_base64(image) for image in images]
+ images = process_images(images, image_processor, model.config)
+
+ if type(images) is list:
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
+ else:
+ images = images.to(self.model.device, dtype=torch.float16)
+
+ replace_token = DEFAULT_IMAGE_TOKEN
+ if getattr(self.model.config, 'mm_use_im_start_end', False):
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
+ else:
+ images = None
+ image_args = {"images": images}
+ else:
+ images = None
+ image_args = {}
+
+ temperature = float(params.get("temperature", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+ stop_str = params.get("stop", None)
+ do_sample = True if temperature > 0.001 else False
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+ keywords = [stop_str]
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
+
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
+
+ if max_new_tokens < 1:
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
+ return
+
+ thread = Thread(target=model.generate, kwargs=dict(
+ inputs=input_ids,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ max_new_tokens=max_new_tokens,
+ streamer=streamer,
+ stopping_criteria=[stopping_criteria],
+ use_cache=True,
+ **image_args
+ ))
+ thread.start()
+
+ generated_text = ori_prompt
+ for new_text in streamer:
+ generated_text += new_text
+ if generated_text.endswith(stop_str):
+ generated_text = generated_text[:-len(stop_str)]
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+
+ def generate_stream_gate(self, params):
+ try:
+ for x in self.generate_stream(params):
+ yield x
+ except ValueError as e:
+ print("Caught ValueError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except torch.cuda.CudaError as e:
+ print("Caught torch.cuda.CudaError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except Exception as e:
+ print("Caught Unknown Error", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+
+app = FastAPI()
+
+
+def release_model_semaphore(fn=None):
+ model_semaphore.release()
+ if fn is not None:
+ fn()
+
+
+@app.post("/worker_generate_stream")
+async def generate_stream(request: Request):
+ global model_semaphore, global_counter
+ global_counter += 1
+ params = await request.json()
+
+ if model_semaphore is None:
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+ await model_semaphore.acquire()
+ worker.send_heart_beat()
+ generator = worker.generate_stream_gate(params)
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def get_status(request: Request):
+ return worker.get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str,
+ default="http://localhost:21002")
+ parser.add_argument("--controller-address", type=str,
+ default="http://localhost:21001")
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--model-name", type=str)
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
+ parser.add_argument("--stream-interval", type=int, default=1)
+ parser.add_argument("--no-register", action="store_true")
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ if args.multi_modal:
+        logger.warning("Multimodal mode is detected automatically from the model name; make sure `llava` is included in the model path.")
+
+ worker = ModelWorker(args.controller_address,
+ args.worker_address,
+ worker_id,
+ args.no_register,
+ args.model_path,
+ args.model_base,
+ args.model_name,
+ args.load_8bit,
+ args.load_4bit,
+ args.device)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
\ No newline at end of file
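
A sketch of querying this multimodal worker. It assumes `DEFAULT_IMAGE_TOKEN` renders as `<image>` (as in upstream LLaVA), that the loaded model name contains `llava` so multimodal handling is active, and that a stop string matching the model's conversation template is supplied, since `generate_stream` above requires one; the image path, address, and stop string are placeholders.

```python
import base64
import json
import requests

with open("example.jpg", "rb") as f:  # placeholder image file
    image_b64 = base64.b64encode(f.read()).decode()

payload = {
    "prompt": "<image>\nWhat is shown in this picture?",  # one <image> per entry in "images"
    "images": [image_b64],
    "temperature": 0.2,
    "top_p": 0.7,
    "max_new_tokens": 256,
    "stop": "</s>",  # placeholder; use the conversation template's real separator
}
resp = requests.post("http://localhost:21002/worker_generate_stream",
                     json=payload, stream=True)
for chunk in resp.iter_lines(delimiter=b"\0"):
    if chunk:
        print(json.loads(chunk.decode())["text"], end="\r")
print()
```
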
diff --git a/graphgpt/serve/monitor/basic_stats.py b/graphgpt/serve/monitor/basic_stats.py
new file mode 100644
index 0000000..c910e12
--- /dev/null
+++ b/graphgpt/serve/monitor/basic_stats.py
@@ -0,0 +1,198 @@
+import argparse
+import code
+import datetime
+import json
+import os
+from pytz import timezone
+import time
+
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+from tqdm import tqdm
+
+
+def get_log_files(max_num_files=None):
+ dates = []
+ for month in [4, 5]:
+ for day in range(1, 32):
+ dates.append(f"2023-{month:02d}-{day:02d}")
+
+ num_servers = 12
+ filenames = []
+ for d in dates:
+ for i in range(num_servers):
+ name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
+ if os.path.exists(name):
+ filenames.append(name)
+ max_num_files = max_num_files or len(filenames)
+ filenames = filenames[-max_num_files:]
+ return filenames
+
+
+def load_log_files(log_files):
+ data = []
+ for filename in tqdm(log_files, desc="read files"):
+ for retry in range(5):
+ try:
+ lines = open(filename).readlines()
+ break
+ except FileNotFoundError:
+ time.sleep(2)
+
+ for l in lines:
+ row = json.loads(l)
+
+ data.append(
+ dict(
+ type=row["type"],
+ tstamp=row["tstamp"],
+ model=row.get("model", ""),
+ models=row.get("models", ["", ""]),
+ )
+ )
+
+ return data
+
+
+def get_anony_vote_df(df):
+ anony_vote_df = df[
+ df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"])
+ ]
+ anony_vote_df = anony_vote_df[
+ anony_vote_df["models"].apply(lambda x: x[0] == "")
+ ]
+ return anony_vote_df
+
+
+def merge_counts(series, on, names):
+ ret = pd.merge(series[0], series[1], on=on)
+ for i in range(2, len(series)):
+ ret = pd.merge(ret, series[i], on=on)
+ ret = ret.reset_index()
+ old_names = list(ret.columns)[-len(series) :]
+ rename = {old_name: new_name for old_name, new_name in zip(old_names, names)}
+ ret = ret.rename(columns=rename)
+ return ret
+
+
+def report_basic_stats(log_files):
+ df_all = load_log_files(log_files)
+ df_all = pd.DataFrame(df_all)
+ now_t = df_all["tstamp"].max()
+ df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)]
+ df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)]
+ anony_vote_df_all = get_anony_vote_df(df_all)
+
+ # Chat trends
+ chat_dates = [
+ datetime.datetime.fromtimestamp(
+ x, tz=timezone("US/Pacific")
+ ).strftime("%Y-%m-%d")
+ for x in df_all[df_all["type"] == "chat"]["tstamp"]
+ ]
+ chat_dates_counts = pd.value_counts(chat_dates)
+ vote_dates = [
+ datetime.datetime.fromtimestamp(
+ x, tz=timezone("US/Pacific")
+ ).strftime("%Y-%m-%d")
+ for x in anony_vote_df_all["tstamp"]
+ ]
+ vote_dates_counts = pd.value_counts(vote_dates)
+ chat_dates_bar = go.Figure(data=[
+ go.Bar(name="Anony. Vote", x=vote_dates_counts.index, y=vote_dates_counts,
+ text=[f"{val:.0f}" for val in vote_dates_counts], textposition="auto"),
+ go.Bar(name="Chat", x=chat_dates_counts.index, y=chat_dates_counts,
+ text=[f"{val:.0f}" for val in chat_dates_counts], textposition="auto"),
+ ])
+ chat_dates_bar.update_layout(
+ barmode="stack",
+ xaxis_title="Dates",
+ yaxis_title="Count",
+ height=300,
+ width=1200,
+ )
+
+ # Model call counts
+ model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts()
+ model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts()
+ model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts()
+ model_hist = merge_counts(
+ [model_hist_all, model_hist_1_day, model_hist_1_hour],
+ on="model",
+ names=["All", "Last Day", "Last Hour"],
+ )
+ model_hist_md = model_hist.to_markdown(index=False, tablefmt="github")
+
+ # Action counts
+ action_hist_all = df_all["type"].value_counts()
+ action_hist_1_day = df_1_day["type"].value_counts()
+ action_hist_1_hour = df_1_hour["type"].value_counts()
+ action_hist = merge_counts(
+ [action_hist_all, action_hist_1_day, action_hist_1_hour],
+ on="type",
+ names=["All", "Last Day", "Last Hour"],
+ )
+ action_hist_md = action_hist.to_markdown(index=False, tablefmt="github")
+
+ # Anony vote counts
+ anony_vote_hist_all = anony_vote_df_all["type"].value_counts()
+ anony_vote_df_1_day = get_anony_vote_df(df_1_day)
+ anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts()
+ anony_vote_df_1_hour = get_anony_vote_df(df_1_hour)
+ anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts()
+ anony_vote_hist = merge_counts(
+ [anony_vote_hist_all, anony_vote_hist_1_day, anony_vote_hist_1_hour],
+ on="type",
+ names=["All", "Last Day", "Last Hour"],
+ )
+ anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github")
+
+ # Last 24 hours
+ chat_1_day = df_1_day[df_1_day["type"] == "chat"]
+ num_chats_last_24_hours = []
+ base = df_1_day["tstamp"].min()
+ for i in range(24, 0, -1):
+ left = base + (i - 1) * 3600
+ right = base + i * 3600
+ num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum()
+ num_chats_last_24_hours.append(num)
+ times = [
+ datetime.datetime.fromtimestamp(
+ base + i * 3600, tz=timezone("US/Pacific")
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
+ for i in range(24, 0, -1)
+ ]
+ last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours})
+ last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github")
+
+ # Last update datetime
+ last_updated_tstamp = now_t
+ last_updated_datetime = datetime.datetime.fromtimestamp(
+ last_updated_tstamp, tz=timezone("US/Pacific")
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
+
+ # code.interact(local=locals())
+
+ return {
+ "chat_dates_bar": chat_dates_bar,
+ "model_hist_md": model_hist_md,
+ "action_hist_md": action_hist_md,
+ "anony_vote_hist_md": anony_vote_hist_md,
+ "num_chats_last_24_hours": last_24_hours_md,
+ "last_updated_datetime": last_updated_datetime,
+ }
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--max-num-files", type=int)
+ args = parser.parse_args()
+
+ log_files = get_log_files(args.max_num_files)
+ basic_stats = report_basic_stats(log_files)
+
+ print(basic_stats["action_hist_md"] + "\n")
+ print(basic_stats["model_hist_md"] + "\n")
+ print(basic_stats["anony_vote_hist_md"] + "\n")
+ print(basic_stats["num_chats_last_24_hours"] + "\n")
diff --git a/graphgpt/serve/monitor/clean_battle_data.py b/graphgpt/serve/monitor/clean_battle_data.py
new file mode 100644
index 0000000..73c1a48
--- /dev/null
+++ b/graphgpt/serve/monitor/clean_battle_data.py
@@ -0,0 +1,195 @@
+import argparse
+import datetime
+import json
+from pytz import timezone
+import os
+import time
+
+from tqdm import tqdm
+
+from fastchat.serve.monitor.basic_stats import get_log_files
+from fastchat.utils import detect_language
+
+
+VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"]
+IDENTITY_WORDS = [
+ "vicuna",
+ "lmsys",
+ "koala",
+ "uc berkeley",
+ "open assistant",
+ "laion",
+ "chatglm",
+ "chatgpt",
+ "openai",
+ "anthropic",
+ "claude",
+ "bard",
+ "palm",
+ "Lamda",
+ "google",
+ "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**",
+]
+
+
+def get_log_files(max_num_files=None):
+ dates = []
+ for month in [4]:
+ for day in range(24, 32):
+ dates.append(f"2023-{month:02d}-{day:02d}")
+ for month in [5]:
+ for day in range(1, 24):
+ dates.append(f"2023-{month:02d}-{day:02d}")
+ cutoff_date = dates[-1].replace("-", "")
+
+ num_servers = 12
+ filenames = []
+ for d in dates:
+ for i in range(num_servers):
+ name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
+ if os.path.exists(name):
+ filenames.append(name)
+ max_num_files = max_num_files or len(filenames)
+ filenames = filenames[-max_num_files:]
+ return filenames, cutoff_date
+
+
+def remove_html(raw):
+    if raw.startswith("<h3>"):
+        return raw[raw.find(": ") + 2 : -len("</h3>\n")]
+ return raw
+
+
+def clean_battle_data(log_files):
+ data = []
+ for filename in tqdm(log_files, desc="read files"):
+ for retry in range(5):
+ try:
+ lines = open(filename).readlines()
+ break
+ except FileNotFoundError:
+ time.sleep(2)
+
+ for l in lines:
+ row = json.loads(l)
+ if row["type"] in VOTES:
+ data.append(row)
+
+ convert_type = {
+ "leftvote": "model_a",
+ "rightvote": "model_b",
+ "tievote": "tie",
+ "bothbad_vote": "tie (bothbad)",
+ }
+
+ all_models = set()
+ ct_annoy = 0
+ ct_invalid = 0
+ ct_leaked_identity = 0
+ battles = []
+ for row in data:
+ # Resolve model names
+ models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
+ if "model_name" in row["states"][0]:
+ models_hidden = [
+ row["states"][0]["model_name"],
+ row["states"][1]["model_name"],
+ ]
+ if models_hidden[0] is None:
+ models_hidden = models_public
+ else:
+ models_hidden = models_public
+
+ if (models_public[0] == "" and models_public[1] != "") or (
+ models_public[1] == "" and models_public[0] != ""
+ ):
+ ct_invalid += 1
+ continue
+
+ if models_public[0] == "" or models_public[0] == "Model A":
+ anony = True
+ models = models_hidden
+ ct_annoy += 1
+ else:
+ anony = False
+ models = models_public
+ if not models_public == models_hidden:
+ ct_invalid += 1
+ continue
+
+        # Detect language
+ state = row["states"][0]
+ if state["offset"] >= len(state["messages"]):
+ ct_invalid += 1
+ continue
+ lang_code = detect_language(state["messages"][state["offset"]][1])
+ rounds = (len(state["messages"]) - state["offset"]) // 2
+
+ # Drop conversations if the model names are leaked
+ leaked_identity = False
+ messages = ""
+ for i in range(2):
+ state = row["states"][i]
+ for role, msg in state["messages"][state["offset"] :]:
+ if msg:
+ messages += msg.lower()
+ for word in IDENTITY_WORDS:
+ if word in messages:
+ leaked_identity = True
+ break
+
+ if leaked_identity:
+ ct_leaked_identity += 1
+ continue
+
+ # Replace bard with palm
+ models = [m.replace("bard", "palm-2") for m in models]
+
+ # Keep the result
+ battles.append(
+ dict(
+ model_a=models[0],
+ model_b=models[1],
+ win=convert_type[row["type"]],
+ anony=anony,
+ rounds=rounds,
+ language=lang_code,
+ tstamp=row["tstamp"],
+ )
+ )
+
+ all_models.update(models_hidden)
+ battles.sort(key=lambda x: x["tstamp"])
+ last_updated_tstamp = battles[-1]["tstamp"]
+
+ last_updated_datetime = datetime.datetime.fromtimestamp(
+ last_updated_tstamp, tz=timezone("US/Pacific")
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
+
+ print(
+ f"#votes: {len(data)}, #invalid votes: {ct_invalid}, "
+ f"#leaked_identity: {ct_leaked_identity}"
+ )
+ print(f"#battles: {len(battles)}, #annoy: {ct_annoy}")
+ print(f"#models: {len(all_models)}, {all_models}")
+ print(f"last-updated: {last_updated_datetime}")
+
+ return battles
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--max-num-files", type=int)
+ args = parser.parse_args()
+
+ log_files, cutoff_date = get_log_files(args.max_num_files)
+ battles = clean_battle_data(log_files)
+
+ print("Samples:")
+ for i in range(4):
+ print(battles[i])
+
+ output = f"clean_battle_{cutoff_date}.json"
+ with open(output, "w") as fout:
+ json.dump(battles, fout, indent=2)
+ print(f"Write cleaned data to {output}")
diff --git a/graphgpt/serve/monitor/elo_analysis.py b/graphgpt/serve/monitor/elo_analysis.py
new file mode 100644
index 0000000..ea0307c
--- /dev/null
+++ b/graphgpt/serve/monitor/elo_analysis.py
@@ -0,0 +1,283 @@
+import argparse
+from collections import defaultdict
+import datetime
+import json
+import math
+import pickle
+from pytz import timezone
+
+import gdown
+import numpy as np
+import pandas as pd
+import plotly.express as px
+from tqdm import tqdm
+
+from fastchat.model.model_registry import get_model_info
+from fastchat.serve.monitor.basic_stats import get_log_files
+from fastchat.serve.monitor.clean_battle_data import clean_battle_data
+
+
+pd.options.display.float_format = "{:.2f}".format
+
+
+def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000):
+ rating = defaultdict(lambda: INIT_RATING)
+
+ for rd, model_a, model_b, win in battles[
+ ["model_a", "model_b", "win"]
+ ].itertuples():
+ ra = rating[model_a]
+ rb = rating[model_b]
+ ea = 1 / (1 + BASE ** ((rb - ra) / SCALE))
+ eb = 1 / (1 + BASE ** ((ra - rb) / SCALE))
+ if win == "model_a":
+ sa = 1
+ elif win == "model_b":
+ sa = 0
+ elif win == "tie" or win == "tie (bothbad)":
+ sa = 0.5
+ else:
+ raise Exception(f"unexpected vote {win}")
+ rating[model_a] += K * (sa - ea)
+ rating[model_b] += K * (1 - sa - eb)
+
+ return dict(rating)
+
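
To make the update above concrete: with both models at the initial rating, the expected score is 0.5, so a single win moves the winner up by K * (1 - 0.5) = 2 points and the loser down by the same amount. A toy run, assuming `compute_elo` above is in scope (the model names are arbitrary):

```python
import pandas as pd

toy_battles = pd.DataFrame([
    {"model_a": "vicuna-13b", "model_b": "alpaca-13b", "win": "model_a"},
    {"model_a": "vicuna-13b", "model_b": "alpaca-13b", "win": "tie"},
])
print(compute_elo(toy_battles))
# First battle: both start at 1000 with expected score 0.5 -> vicuna 1002, alpaca 998.
# Second battle (a tie): vicuna is now slightly favored, so it gives back a small fraction of a point.
```
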
+
+def get_bootstrap_result(battles, func_compute_elo, num_round=1000):
+ rows = []
+ for i in tqdm(range(num_round), desc="bootstrap"):
+ tmp_battles = battles.sample(frac=1.0, replace=True)
+ # tmp_battles = tmp_battles.sort_values(ascending=True, by=["tstamp"])
+ rows.append(func_compute_elo(tmp_battles))
+ df = pd.DataFrame(rows)
+ return df[df.median().sort_values(ascending=False).index]
+
+
+def get_elo_from_bootstrap(bootstrap_df):
+ return dict(bootstrap_df.quantile(0.5))
+
+
+def compute_pairwise_win_fraction(battles, model_order):
+ # Times each model wins as Model A
+ a_win_ptbl = pd.pivot_table(
+ battles[battles["win"] == "model_a"],
+ index="model_a",
+ columns="model_b",
+ aggfunc="size",
+ fill_value=0,
+ )
+
+ # Table counting times each model wins as Model B
+ b_win_ptbl = pd.pivot_table(
+ battles[battles["win"] == "model_b"],
+ index="model_a",
+ columns="model_b",
+ aggfunc="size",
+ fill_value=0,
+ )
+
+ # Table counting number of A-B pairs
+ num_battles_ptbl = pd.pivot_table(
+ battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
+ )
+
+ # Computing the proportion of wins for each model as A and as B
+ # against all other models
+ row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / (
+ num_battles_ptbl + num_battles_ptbl.T
+ )
+
+ if model_order is None:
+ prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False)
+ model_order = list(prop_wins.keys())
+
+    # Arrange ordering according to proportion of wins
+ row_beats_col = row_beats_col_freq.loc[model_order, model_order]
+ return row_beats_col
+
+
+def visualize_leaderboard_table(rating):
+ models = list(rating.keys())
+ models.sort(key=lambda k: -rating[k])
+
+ emoji_dict = {
+ 1: "🥇",
+ 2: "🥈",
+ 3: "🥉",
+ }
+
+ md = ""
+ md += "| Rank | Model | Elo Rating | Description |\n"
+ md += "| --- | --- | --- | --- |\n"
+ for i, model in enumerate(models):
+ rank = i + 1
+ minfo = get_model_info(model)
+ emoji = emoji_dict.get(rank, "")
+ md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n"
+
+ return md
+
+
+def visualize_pairwise_win_fraction(battles, model_order):
+ row_beats_col = compute_pairwise_win_fraction(battles, model_order)
+ fig = px.imshow(
+ row_beats_col,
+ color_continuous_scale="RdBu",
+ text_auto=".2f",
+ height=600,
+ width=600,
+ )
+ fig.update_layout(
+ xaxis_title="Model B",
+ yaxis_title="Model A",
+ xaxis_side="top",
+ title_y=0.07,
+ title_x=0.5,
+ )
+ fig.update_traces(
+ hovertemplate="Model A: %{y}
Model B: %{x}
Fraction of A Wins: %{z}"
+ )
+
+ return fig
+
+
+def visualize_battle_count(battles, model_order):
+ ptbl = pd.pivot_table(
+ battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
+ )
+ battle_counts = ptbl + ptbl.T
+ fig = px.imshow(
+ battle_counts.loc[model_order, model_order],
+ text_auto=True,
+ height=600,
+ width=600,
+ )
+ fig.update_layout(
+ xaxis_title="Model B",
+ yaxis_title="Model A",
+ xaxis_side="top",
+ title_y=0.07,
+ title_x=0.5,
+ )
+ fig.update_traces(
+ hovertemplate="Model A: %{y}
Model B: %{x}
Count: %{z}"
+ )
+ return fig
+
+
+def visualize_average_win_rate(battles):
+ row_beats_col_freq = compute_pairwise_win_fraction(battles, None)
+ fig = px.bar(
+ row_beats_col_freq.mean(axis=1).sort_values(ascending=False),
+ text_auto=".2f",
+ height=400,
+ width=600,
+ )
+ fig.update_layout(
+ yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False
+ )
+ return fig
+
+
+def visualize_bootstrap_elo_rating(df):
+ bars = (
+ pd.DataFrame(
+ dict(
+ lower=df.quantile(0.025),
+ rating=df.quantile(0.5),
+ upper=df.quantile(0.975),
+ )
+ )
+ .reset_index(names="model")
+ .sort_values("rating", ascending=False)
+ )
+ bars["error_y"] = bars["upper"] - bars["rating"]
+ bars["error_y_minus"] = bars["rating"] - bars["lower"]
+ bars["rating_rounded"] = np.round(bars["rating"], 2)
+ fig = px.scatter(
+ bars,
+ x="model",
+ y="rating",
+ error_y="error_y",
+ error_y_minus="error_y_minus",
+ text="rating_rounded",
+ height=400,
+ width=600,
+ )
+ fig.update_layout(xaxis_title="Model", yaxis_title="Rating")
+ return fig
+
+
+def report_elo_analysis_results(battles_json):
+ battles = pd.DataFrame(battles_json)
+ battles = battles.sort_values(ascending=True, by=["tstamp"])
+ # Only use anonymous votes
+ battles = battles[battles["anony"]].reset_index(drop=True)
+ battles_no_ties = battles[~battles["win"].str.contains("tie")]
+
+ # Online update
+ elo_rating_online = compute_elo(battles)
+
+ # Bootstrap
+ bootstrap_df = get_bootstrap_result(battles, compute_elo)
+ elo_rating_median = get_elo_from_bootstrap(bootstrap_df)
+ elo_rating_median = {k: int(v + 0.5) for k, v in elo_rating_median.items()}
+ model_order = list(elo_rating_online.keys())
+ model_order.sort(key=lambda k: -elo_rating_online[k])
+
+ # Plots
+ leaderboard_table = visualize_leaderboard_table(elo_rating_online)
+ win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order)
+ battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order)
+ average_win_rate_bar = visualize_average_win_rate(battles_no_ties)
+ bootstrap_elo_rating = visualize_bootstrap_elo_rating(bootstrap_df)
+
+ last_updated_tstamp = battles["tstamp"].max()
+ last_updated_datetime = datetime.datetime.fromtimestamp(
+ last_updated_tstamp, tz=timezone("US/Pacific")
+ ).strftime("%Y-%m-%d %H:%M:%S %Z")
+
+ return {
+ "elo_rating_online": elo_rating_online,
+ "elo_rating_median": elo_rating_median,
+ "leaderboard_table": leaderboard_table,
+ "win_fraction_heatmap": win_fraction_heatmap,
+ "battle_count_heatmap": battle_count_heatmap,
+ "average_win_rate_bar": average_win_rate_bar,
+ "bootstrap_elo_rating": bootstrap_elo_rating,
+ "last_updated_datetime": last_updated_datetime,
+ }
+
+
+def pretty_print_elo_rating(rating):
+ model_order = list(rating.keys())
+ model_order.sort(key=lambda k: -rating[k])
+ for i, model in enumerate(model_order):
+ print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--clean-battle-file", type=str)
+ parser.add_argument("--max-num-files", type=int)
+ args = parser.parse_args()
+
+ if args.clean_battle_file:
+ # Read data from a cleaned battle files
+ battles = pd.read_json(args.clean_battle_file)
+ else:
+ # Read data from all log files
+ log_files = get_log_files(args.max_num_files)
+ battles = clean_battle_data(log_files)
+
+ results = report_elo_analysis_results(battles)
+
+ print("# Online")
+ pretty_print_elo_rating(results["elo_rating_online"])
+ print("# Median")
+ pretty_print_elo_rating(results["elo_rating_median"])
+ print(f"last update: {results['last_updated_datetime']}")
+
+ with open("elo_results.pkl", "wb") as fout:
+ pickle.dump(results, fout)
diff --git a/graphgpt/serve/monitor/hf_space_leaderboard_app.py b/graphgpt/serve/monitor/hf_space_leaderboard_app.py
new file mode 100644
index 0000000..44ae4a8
--- /dev/null
+++ b/graphgpt/serve/monitor/hf_space_leaderboard_app.py
@@ -0,0 +1,86 @@
+"""A gradio app that renders a static leaderboard. This is used for Hugging Face Space."""
+import argparse
+import pickle
+
+import gradio as gr
+
+
+notebook_url = "https://colab.research.google.com/drive/17L9uCiAivzWfzOxo2Tb9RMauT7vS6nVU?usp=sharing"
+
+
+def make_leaderboard_md(elo_results):
+ leaderboard_md = f"""
+# Leaderboard
+[[Blog](https://lmsys.org/blog/2023-05-03-arena/)] [[Vote](https://arena.lmsys.org/)] [[Github]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/KjdtsE9V)
+
+We use the Elo rating system to calculate the relative performance of the models. You can view the voting data, basic analyses, and calculation procedure in this [notebook]({notebook_url}). We will periodically release new leaderboards. If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model).
+Last updated: {elo_results["last_updated_datetime"]}
+{elo_results["leaderboard_table"]}
+"""
+ return leaderboard_md
+
+
+def build_leaderboard_tab(elo_results_file):
+ if elo_results_file is not None:
+ with open(elo_results_file, "rb") as fin:
+ elo_results = pickle.load(fin)
+
+ md = make_leaderboard_md(elo_results)
+ p1 = elo_results["win_fraction_heatmap"]
+ p2 = elo_results["battle_count_heatmap"]
+ p3 = elo_results["average_win_rate_bar"]
+ p4 = elo_results["bootstrap_elo_rating"]
+ else:
+ md = "Loading ..."
+ p1 = p2 = p3 = p4 = None
+
+ md_1 = gr.Markdown(md)
+ gr.Markdown(
+ f"""## More Statistics\n
+We added some additional figures to show more statistics. The code for generating them is also included in this [notebook]({notebook_url}).
+Please note that you may see different orders from different ranking methods. This is expected for models that perform similarly, as demonstrated by the confidence interval in the bootstrap figure. Going forward, we prefer the classical Elo calculation because of its scalability and interpretability. You can find more discussions in this blog [post](https://lmsys.org/blog/2023-05-03-arena/).
+"""
+ )
+
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles"
+ )
+ plot_1 = gr.Plot(p1, show_label=False)
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 2: Battle Count for Each Combination of Models (without Ties)"
+ )
+ plot_2 = gr.Plot(p2, show_label=False)
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 3: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)"
+ )
+ plot_3 = gr.Plot(p3, show_label=False)
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 4: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)"
+ )
+ plot_4 = gr.Plot(p4, show_label=False)
+ return [md_1, plot_1, plot_2, plot_3, plot_4]
+
+
+def build_demo(elo_results_file):
+ with gr.Blocks(
+ title="Chatbot Arena Leaderboard",
+ theme=gr.themes.Base(),
+ ) as demo:
+ leader_components = build_leaderboard_tab(elo_results_file)
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--share", action="store_true")
+ args = parser.parse_args()
+
+ demo = build_demo("elo_results_20230508.pkl")
+ demo.launch(share=args.share)
diff --git a/graphgpt/serve/monitor/monitor.py b/graphgpt/serve/monitor/monitor.py
new file mode 100644
index 0000000..ac9d26e
--- /dev/null
+++ b/graphgpt/serve/monitor/monitor.py
@@ -0,0 +1,209 @@
+# sudo apt install pkg-config libicu-dev
+# pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate
+
+import argparse
+import pickle
+import os
+import threading
+import time
+
+import gradio as gr
+
+from fastchat.serve.monitor.basic_stats import report_basic_stats, get_log_files
+from fastchat.serve.monitor.clean_battle_data import clean_battle_data
+from fastchat.serve.monitor.elo_analysis import report_elo_analysis_results
+from fastchat.utils import build_logger, get_window_url_params_js
+
+
+notebook_url = "https://colab.research.google.com/drive/17L9uCiAivzWfzOxo2Tb9RMauT7vS6nVU?usp=sharing"
+
+
+logger = build_logger("monitor", "monitor.log")
+
+
+basic_component_values = [None] * 6
+leader_component_values = [None] * 5
+
+
+def make_leaderboard_md(elo_results):
+ leaderboard_md = f"""
+# Leaderboard
+[[Blog](https://lmsys.org/blog/2023-05-03-arena/)] [[GitHub]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/KjdtsE9V)
+
+We use the Elo rating system to calculate the relative performance of the models. You can view the voting data, basic analyses, and calculation procedure in this [notebook]({notebook_url}). We will periodically release new leaderboards. If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model).
+Last updated: {elo_results["last_updated_datetime"]}
+{elo_results["leaderboard_table"]}
+"""
+ return leaderboard_md
+
+
+def update_elo_components(max_num_files, elo_results_file):
+ log_files = get_log_files(max_num_files)
+
+ # Leaderboard
+ if elo_results_file is None:
+ battles = clean_battle_data(log_files)
+ elo_results = report_elo_analysis_results(battles)
+
+ leader_component_values[0] = make_leaderboard_md(elo_results)
+ leader_component_values[1] = elo_results["win_fraction_heatmap"]
+ leader_component_values[2] = elo_results["battle_count_heatmap"]
+ leader_component_values[3] = elo_results["average_win_rate_bar"]
+ leader_component_values[4] = elo_results["bootstrap_elo_rating"]
+
+ # Basic stats
+ basic_stats = report_basic_stats(log_files)
+ md0 = f"Last updated: {basic_stats['last_updated_datetime']}"
+
+ md1 = "### Action Histogram\n"
+ md1 += basic_stats["action_hist_md"] + "\n"
+
+ md2 = "### Anony. Vote Histogram\n"
+ md2 += basic_stats["anony_vote_hist_md"] + "\n"
+
+ md3 = "### Model Call Histogram\n"
+ md3 += basic_stats["model_hist_md"] + "\n"
+
+ md4 = "### Model Call (Last 24 Hours)\n"
+ md4 += basic_stats["num_chats_last_24_hours"] + "\n"
+
+ basic_component_values[0] = md0
+ basic_component_values[1] = basic_stats["chat_dates_bar"]
+ basic_component_values[2] = md1
+ basic_component_values[3] = md2
+ basic_component_values[4] = md3
+ basic_component_values[5] = md4
+
+
+def update_worker(max_num_files, interval, elo_results_file):
+ while True:
+ tic = time.time()
+ update_elo_components(max_num_files, elo_results_file)
+ duration = time.time() - tic
+ print(f"update duration: {duration:.2f} s")
+ time.sleep(max(interval - duration, 0))
+
+
+def load_demo(url_params, request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+ return basic_component_values + leader_component_values
+
+
+def build_basic_stats_tab():
+ empty = "Loading ..."
+ basic_component_values[:] = [empty, None, empty, empty, empty, empty]
+
+ md0 = gr.Markdown(empty)
+ gr.Markdown(
+ "#### Figure 1: Number of model calls and votes"
+ )
+ plot_1 = gr.Plot(show_label=False)
+ with gr.Row():
+ with gr.Column():
+ md1 = gr.Markdown(empty)
+ with gr.Column():
+ md2 = gr.Markdown(empty)
+ with gr.Row():
+ with gr.Column():
+ md3 = gr.Markdown(empty)
+ with gr.Column():
+ md4 = gr.Markdown(empty)
+ return [md0, plot_1, md1, md2, md3, md4]
+
+
+def build_leaderboard_tab(elo_results_file):
+ if elo_results_file is not None:
+ with open(elo_results_file, "rb") as fin:
+ elo_results = pickle.load(fin)
+
+ md = make_leaderboard_md(elo_results)
+ p1 = elo_results["win_fraction_heatmap"]
+ p2 = elo_results["battle_count_heatmap"]
+ p3 = elo_results["average_win_rate_bar"]
+ p4 = elo_results["bootstrap_elo_rating"]
+ else:
+ md = "Loading ..."
+ p1 = p2 = p3 = p4 = None
+
+ leader_component_values[:] = [md, p1, p2, p3, p4]
+
+ md_1 = gr.Markdown(md)
+ gr.Markdown(
+ f"""## More Statistics\n
+We added some additional figures to show more statistics. The code for generating them is also included in this [notebook]({notebook_url}).
+Please note that you may see different orders from different ranking methods. This is expected for models that perform similarly, as demonstrated by the confidence interval in the bootstrap figure. Going forward, we prefer the classical Elo calculation because of its scalability and interpretability. You can find more discussions in this blog [post](https://lmsys.org/blog/2023-05-03-arena/).
+"""
+ )
+
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles"
+ )
+ plot_1 = gr.Plot(p1, show_label=False)
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 2: Battle Count for Each Combination of Models (without Ties)"
+ )
+ plot_2 = gr.Plot(p2, show_label=False)
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 3: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)"
+ )
+ plot_3 = gr.Plot(p3, show_label=False)
+ with gr.Column():
+ gr.Markdown(
+ "#### Figure 4: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)"
+ )
+ plot_4 = gr.Plot(p4, show_label=False)
+ return [md_1, plot_1, plot_2, plot_3, plot_4]
+
+
+def build_demo(elo_results_file):
+ with gr.Blocks(
+ title="Monitor",
+ theme=gr.themes.Base(),
+ ) as demo:
+ with gr.Tabs() as tabs:
+ with gr.Tab("Leaderboard", id=0):
+ leader_components = build_leaderboard_tab(elo_results_file)
+
+ with gr.Tab("Basic Stats", id=1):
+ basic_components = build_basic_stats_tab()
+
+ url_params = gr.JSON(visible=False)
+ demo.load(
+ load_demo,
+ [url_params],
+ basic_components + leader_components,
+ _js=get_window_url_params_js,
+ )
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument("--concurrency-count", type=int, default=10)
+ parser.add_argument("--update-interval", type=int, default=300)
+ parser.add_argument("--max-num-files", type=int)
+ parser.add_argument("--elo-results-file", type=str)
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ update_thread = threading.Thread(
+ target=update_worker,
+ args=(args.max_num_files, args.update_interval, args.elo_results_file),
+ )
+ update_thread.start()
+
+ demo = build_demo(args.elo_results_file)
+ demo.queue(
+ concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+ ).launch(
+ server_name=args.host, server_port=args.port, share=args.share, max_threads=200
+ )
diff --git a/graphgpt/serve/openai_api_server.py b/graphgpt/serve/openai_api_server.py
new file mode 100644
index 0000000..6e3099a
--- /dev/null
+++ b/graphgpt/serve/openai_api_server.py
@@ -0,0 +1,722 @@
+"""A server that provides OpenAI-compatible RESTful APIs. It supports:
+
+- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)
+- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions)
+- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings)
+
+Usage:
+python3 -m fastchat.serve.openai_api_server
+"""
+import argparse
+import asyncio
+import json
+import logging
+
+import os
+from typing import Generator, Optional, Union, Dict, List, Any
+
+import fastapi
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, JSONResponse
+import httpx
+from pydantic import BaseSettings
+import shortuuid
+import tiktoken
+import uvicorn
+
+from fastchat.constants import WORKER_API_TIMEOUT, WORKER_API_EMBEDDING_BATCH_SIZE, ErrorCode
+from fastchat.model.model_adapter import get_conversation_template
+from fastapi.exceptions import RequestValidationError
+from fastchat.protocol.openai_api_protocol import (
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ ChatCompletionResponseStreamChoice,
+ ChatCompletionStreamResponse,
+ ChatMessage,
+ ChatCompletionResponseChoice,
+ CompletionRequest,
+ CompletionResponse,
+ CompletionResponseChoice,
+ DeltaMessage,
+ CompletionResponseStreamChoice,
+ CompletionStreamResponse,
+ EmbeddingsRequest,
+ EmbeddingsResponse,
+ ErrorResponse,
+ ModelCard,
+ ModelList,
+ ModelPermission,
+ TokenCheckRequest,
+ TokenCheckResponse,
+ UsageInfo,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class AppSettings(BaseSettings):
+ # The address of the model controller.
+ controller_address: str = "http://localhost:21001"
+
+
+app_settings = AppSettings()
+
+app = fastapi.FastAPI()
+headers = {"User-Agent": "FastChat API Server"}
+
+
+def create_error_response(code: int, message: str) -> JSONResponse:
+ return JSONResponse(
+ ErrorResponse(message=message, code=code).dict(), status_code=400
+ )
+
+
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request, exc):
+ return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
+
+
+async def check_model(request) -> Optional[JSONResponse]:
+ controller_address = app_settings.controller_address
+ ret = None
+ async with httpx.AsyncClient() as client:
+ try:
+ _worker_addr = await _get_worker_address(request.model, client)
+ except:
+ models_ret = await client.post(controller_address + "/list_models")
+ models = models_ret.json()["models"]
+ ret = create_error_response(
+ ErrorCode.INVALID_MODEL,
+ f"Only {'&&'.join(models)} allowed now, your model {request.model}",
+ )
+ return ret
+
+
+async def check_length(request, prompt, max_tokens):
+ async with httpx.AsyncClient() as client:
+ worker_addr = await _get_worker_address(request.model, client)
+
+ response = await client.post(
+ worker_addr + "/model_details",
+ headers=headers,
+ json={},
+ timeout=WORKER_API_TIMEOUT,
+ )
+ context_len = response.json()["context_length"]
+
+ response = await client.post(
+ worker_addr + "/count_token",
+ headers=headers,
+ json={"prompt": prompt},
+ timeout=WORKER_API_TIMEOUT,
+ )
+ token_num = response.json()["count"]
+
+ if token_num + max_tokens > context_len:
+ return create_error_response(
+ ErrorCode.CONTEXT_OVERFLOW,
+ f"This model's maximum context length is {context_len} tokens. "
+ f"However, you requested {max_tokens + token_num} tokens "
+ f"({token_num} in the messages, "
+ f"{max_tokens} in the completion). "
+ f"Please reduce the length of the messages or completion.",
+ )
+ else:
+ return None
+
+
+def check_requests(request) -> Optional[JSONResponse]:
+ # Check all params
+ if request.max_tokens is not None and request.max_tokens <= 0:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
+ )
+ if request.n is not None and request.n <= 0:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.n} is less than the minimum of 1 - 'n'",
+ )
+ if request.temperature is not None and request.temperature < 0:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.temperature} is less than the minimum of 0 - 'temperature'",
+ )
+ if request.temperature is not None and request.temperature > 2:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
+ )
+ if request.top_p is not None and request.top_p < 0:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.top_p} is less than the minimum of 0 - 'top_p'",
+ )
+ if request.top_p is not None and request.top_p > 1:
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
+ )
+ if request.stop is not None and (
+ not isinstance(request.stop, str) and not isinstance(request.stop, list)
+ ):
+ return create_error_response(
+ ErrorCode.PARAM_OUT_OF_RANGE,
+ f"{request.stop} is not valid under any of the given schemas - 'stop'",
+ )
+
+ return None
+
+
+def process_input(model_name, input):
+ if isinstance(input, str):
+ input = [input]
+ elif isinstance(input, list):
+ if isinstance(input[0], int):
+ decoding = tiktoken.model.encoding_for_model(model_name)
+ input = [decoding.decode(input)]
+ elif isinstance(input[0], list):
+ decoding = tiktoken.model.encoding_for_model(model_name)
+ input = [decoding.decode(text) for text in input]
+
+ return input
+
+
+def get_gen_params(
+ model_name: str,
+ messages: Union[str, List[Dict[str, str]]],
+ *,
+ temperature: float,
+ top_p: float,
+ max_tokens: Optional[int],
+ echo: Optional[bool],
+ stream: Optional[bool],
+ stop: Optional[Union[str, List[str]]],
+) -> Dict[str, Any]:
+ conv = get_conversation_template(model_name)
+
+ if isinstance(messages, str):
+ prompt = messages
+ else:
+ for message in messages:
+ msg_role = message["role"]
+ if msg_role == "system":
+ conv.system = message["content"]
+ elif msg_role == "user":
+ conv.append_message(conv.roles[0], message["content"])
+ elif msg_role == "assistant":
+ conv.append_message(conv.roles[1], message["content"])
+ else:
+ raise ValueError(f"Unknown role: {msg_role}")
+
+ # Add a blank message for the assistant.
+ conv.append_message(conv.roles[1], None)
+
+ is_chatglm = "chatglm" in model_name.lower()
+ if is_chatglm:
+ prompt = conv.messages[conv.offset :]
+ else:
+ prompt = conv.get_prompt()
+
+ if max_tokens is None:
+ max_tokens = 512
+
+ gen_params = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": temperature,
+ "top_p": top_p,
+ "max_new_tokens": max_tokens,
+ "echo": echo,
+ "stream": stream,
+ }
+
+ if stop is None:
+ gen_params.update(
+ {"stop": conv.stop_str, "stop_token_ids": conv.stop_token_ids}
+ )
+ else:
+ gen_params.update({"stop": stop})
+
+ logger.debug(f"==== request ====\n{gen_params}")
+ return gen_params
+
+
+async def _get_worker_address(model_name: str, client: httpx.AsyncClient) -> str:
+ """
+ Get worker address based on the requested model
+
+ :param model_name: The worker's model name
+ :param client: The httpx client to use
+ :return: Worker address from the controller
+ :raises: :class:`ValueError`: No available worker for requested model
+ """
+ controller_address = app_settings.controller_address
+
+ ret = await client.post(
+ controller_address + "/get_worker_address", json={"model": model_name}
+ )
+ worker_addr = ret.json()["address"]
+ # No available worker
+ if worker_addr == "":
+ raise ValueError(f"No available worker for {model_name}")
+
+ logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}")
+ return worker_addr
+
+
+@app.get("/v1/models")
+async def show_available_models():
+ controller_address = app_settings.controller_address
+ async with httpx.AsyncClient() as client:
+ ret = await client.post(controller_address + "/refresh_all_workers")
+ ret = await client.post(controller_address + "/list_models")
+ models = ret.json()["models"]
+ models.sort()
+ # TODO: return real model permission details
+ model_cards = []
+ for m in models:
+ model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()]))
+ return ModelList(data=model_cards)
+
+
+# TODO: Have check_length and count_tokens share code.
+@app.post("/v1/token_check")
+async def count_tokens(request: TokenCheckRequest):
+ """
+ Checks the token count against your message
+ This is not part of the OpenAI API spec.
+ """
+ async with httpx.AsyncClient() as client:
+ worker_addr = await _get_worker_address(request.model, client)
+
+ response = await client.post(
+ worker_addr + "/model_details",
+ headers=headers,
+ json={},
+ timeout=WORKER_API_TIMEOUT,
+ )
+ context_len = response.json()["context_length"]
+
+ response = await client.post(
+ worker_addr + "/count_token",
+ headers=headers,
+ json={"prompt": request.prompt},
+ timeout=WORKER_API_TIMEOUT,
+ )
+ token_num = response.json()["count"]
+
+ can_fit = True
+ if token_num + request.max_tokens > context_len:
+ can_fit = False
+
+ return TokenCheckResponse(fits=can_fit, contextLength=context_len, tokenCount=token_num)
+
+
+@app.post("/v1/chat/completions")
+async def create_chat_completion(request: ChatCompletionRequest):
+ """Creates a completion for the chat message"""
+ error_check_ret = await check_model(request)
+ if error_check_ret is not None:
+ return error_check_ret
+ error_check_ret = check_requests(request)
+ if error_check_ret is not None:
+ return error_check_ret
+
+ gen_params = get_gen_params(
+ request.model,
+ request.messages,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ max_tokens=request.max_tokens,
+ echo=False,
+ stream=request.stream,
+ stop=request.stop,
+ )
+ error_check_ret = await check_length(
+ request, gen_params["prompt"], gen_params["max_new_tokens"]
+ )
+ if error_check_ret is not None:
+ return error_check_ret
+
+ if request.stream:
+ generator = chat_completion_stream_generator(
+ request.model, gen_params, request.n
+ )
+ return StreamingResponse(generator, media_type="text/event-stream")
+
+ choices = []
+ # TODO: batch the requests. maybe not necessary if using CacheFlow worker
+ chat_completions = []
+ for i in range(request.n):
+ content = asyncio.create_task(chat_completion(request.model, gen_params))
+ chat_completions.append(content)
+ try:
+ all_tasks = await asyncio.gather(*chat_completions)
+ except Exception as e:
+ return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
+ usage = UsageInfo()
+ for i, content in enumerate(all_tasks):
+ if content["error_code"] != 0:
+ return create_error_response(content["error_code"], content["text"])
+ choices.append(
+ ChatCompletionResponseChoice(
+ index=i,
+ message=ChatMessage(role="assistant", content=content["text"]),
+ finish_reason=content.get("finish_reason", "stop"),
+ )
+ )
+ task_usage = UsageInfo.parse_obj(content["usage"])
+ for usage_key, usage_value in task_usage.dict().items():
+ setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+
+ return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
+
+
+async def chat_completion_stream_generator(
+ model_name: str, gen_params: Dict[str, Any], n: int
+) -> Generator[str, Any, None]:
+ """
+ Event stream format:
+ https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
+ """
+ id = f"chatcmpl-{shortuuid.random()}"
+ finish_stream_events = []
+ for i in range(n):
+ # First chunk with role
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(role="assistant"),
+ finish_reason=None,
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=id, choices=[choice_data], model=model_name
+ )
+ yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+ previous_text = ""
+ async for content in chat_completion_stream(model_name, gen_params):
+ if content["error_code"] != 0:
+ yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+ decoded_unicode = content["text"].replace("\ufffd", "")
+ delta_text = decoded_unicode[len(previous_text) :]
+ previous_text = decoded_unicode
+
+ if len(delta_text) == 0:
+ delta_text = None
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(content=delta_text),
+ finish_reason=content.get("finish_reason", None),
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=id, choices=[choice_data], model=model_name
+ )
+ if delta_text is None:
+ if content.get("finish_reason", None) is not None:
+ finish_stream_events.append(chunk)
+ continue
+ yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+ # There is no "content" field in the last delta message, so use exclude_none to drop the "content" field.
+ for finish_chunk in finish_stream_events:
+ yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
+ yield "data: [DONE]\n\n"
+
+
+async def chat_completion_stream(model_name: str, gen_params: Dict[str, Any]):
+ controller_url = app_settings.controller_address
+ async with httpx.AsyncClient() as client:
+ worker_addr = await _get_worker_address(model_name, client)
+ delimiter = b"\0"
+ async with client.stream(
+ "POST",
+ worker_addr + "/worker_generate_stream",
+ headers=headers,
+ json=gen_params,
+ timeout=WORKER_API_TIMEOUT,
+ ) as response:
+ # content = await response.aread()
+ async for raw_chunk in response.aiter_raw():
+ for chunk in raw_chunk.split(delimiter):
+ if not chunk:
+ continue
+ data = json.loads(chunk.decode())
+ yield data
+
+
+async def chat_completion(
+ model_name: str, gen_params: Dict[str, Any]
+) -> Optional[Dict[str, Any]]:
+ async with httpx.AsyncClient() as client:
+ worker_addr = await _get_worker_address(model_name, client)
+
+ output = None
+ delimiter = b"\0"
+
+ async with client.stream(
+ "POST",
+ worker_addr + "/worker_generate_stream",
+ headers=headers,
+ json=gen_params,
+ timeout=WORKER_API_TIMEOUT,
+ ) as response:
+ content = await response.aread()
+
+ for chunk in content.split(delimiter):
+ if not chunk:
+ continue
+ data = json.loads(chunk.decode())
+ output = data
+
+ return output
+
+
+@app.post("/v1/completions")
+async def create_completion(request: CompletionRequest):
+ error_check_ret = await check_model(request)
+ if error_check_ret is not None:
+ return error_check_ret
+ error_check_ret = check_requests(request)
+ if error_check_ret is not None:
+ return error_check_ret
+
+ request.prompt = process_input(request.model, request.prompt)
+
+ for text in request.prompt:
+ error_check_ret = await check_length(request, text, request.max_tokens)
+ if error_check_ret is not None:
+ return error_check_ret
+
+ if request.stream:
+ generator = generate_completion_stream_generator(request, request.n)
+ return StreamingResponse(generator, media_type="text/event-stream")
+ else:
+ text_completions = []
+ for text in request.prompt:
+ payload = get_gen_params(
+ request.model,
+ text,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ max_tokens=request.max_tokens,
+ echo=request.echo,
+ stream=request.stream,
+ stop=request.stop,
+ )
+ for i in range(request.n):
+ content = asyncio.create_task(generate_completion(payload))
+ text_completions.append(content)
+
+ try:
+ all_tasks = await asyncio.gather(*text_completions)
+ except Exception as e:
+ return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
+
+ choices = []
+ usage = UsageInfo()
+ for i, content in enumerate(all_tasks):
+ if content["error_code"] != 0:
+ return create_error_response(content["error_code"], content["text"])
+ choices.append(
+ CompletionResponseChoice(
+ index=i,
+ text=content["text"],
+ logprobs=content.get("logprobs", None),
+ finish_reason=content.get("finish_reason", "stop"),
+ )
+ )
+ task_usage = UsageInfo.parse_obj(content["usage"])
+ for usage_key, usage_value in task_usage.dict().items():
+ setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+
+ return CompletionResponse(
+ model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
+ )
+
+
+async def generate_completion_stream_generator(request: CompletionRequest, n: int):
+ model_name = request.model
+ id = f"cmpl-{shortuuid.random()}"
+ finish_stream_events = []
+ for text in request.prompt:
+ for i in range(n):
+ previous_text = ""
+ payload = get_gen_params(
+ request.model,
+ text,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ max_tokens=request.max_tokens,
+ echo=request.echo,
+ stream=request.stream,
+ stop=request.stop,
+ )
+ async for content in generate_completion_stream(payload):
+ if content["error_code"] != 0:
+ yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+ decoded_unicode = content["text"].replace("\ufffd", "")
+ delta_text = decoded_unicode[len(previous_text) :]
+ previous_text = decoded_unicode
+ # todo: index is not apparent
+ choice_data = CompletionResponseStreamChoice(
+ index=i,
+ text=delta_text,
+ logprobs=content.get("logprobs", None),
+ finish_reason=content.get("finish_reason", None),
+ )
+ chunk = CompletionStreamResponse(
+ id=id,
+ object="text_completion",
+ choices=[choice_data],
+ model=model_name,
+ )
+ if len(delta_text) == 0:
+ if content.get("finish_reason", None) is not None:
+ finish_stream_events.append(chunk)
+ continue
+ yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+ # The last delta message carries no new text, so exclude unset fields when serializing the finish chunks.
+ for finish_chunk in finish_stream_events:
+ yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+ yield "data: [DONE]\n\n"
+
+
+async def generate_completion_stream(payload: Dict[str, Any]):
+ controller_address = app_settings.controller_address
+ async with httpx.AsyncClient() as client:
+ worker_addr = await _get_worker_address(payload["model"], client)
+
+ delimiter = b"\0"
+ async with client.stream(
+ "POST",
+ worker_addr + "/worker_generate_completion_stream",
+ headers=headers,
+ json=payload,
+ timeout=WORKER_API_TIMEOUT,
+ ) as response:
+ # content = await response.aread()
+ async for raw_chunk in response.aiter_raw():
+ for chunk in raw_chunk.split(delimiter):
+ if not chunk:
+ continue
+ data = json.loads(chunk.decode())
+ yield data
+
+
+async def generate_completion(payload: Dict[str, Any]):
+ controller_address = app_settings.controller_address
+ async with httpx.AsyncClient() as client:
+ worker_addr = await _get_worker_address(payload["model"], client)
+
+ response = await client.post(
+ worker_addr + "/worker_generate_completion",
+ headers=headers,
+ json=payload,
+ timeout=WORKER_API_TIMEOUT,
+ )
+ completion = response.json()
+ return completion
+
+
+@app.post("/v1/embeddings")
+@app.post("/v1/engines/{model_name}/embeddings")
+async def create_embeddings(request: EmbeddingsRequest, model_name: str = None):
+ """Creates embeddings for the text"""
+ if request.model is None:
+ request.model = model_name
+ error_check_ret = await check_model(request)
+ if error_check_ret is not None:
+ return error_check_ret
+
+ request.input = process_input(request.model, request.input)
+
+ data = []
+ token_num = 0
+ batch_size = WORKER_API_EMBEDDING_BATCH_SIZE
+ batches = [
+ request.input[i : min(i + batch_size, len(request.input))]
+ for i in range(0, len(request.input), batch_size)
+ ]
+ for num_batch, batch in enumerate(batches):
+ payload = {
+ "model": request.model,
+ "input": batch,
+ }
+ embedding = await get_embedding(payload)
+ data += [
+ {
+ "object": "embedding",
+ "embedding": emb,
+ "index": num_batch * batch_size + i,
+ }
+ for i, emb in enumerate(embedding["embedding"])
+ ]
+ token_num += embedding["token_num"]
+ return EmbeddingsResponse(
+ data=data,
+ model=request.model,
+ usage=UsageInfo(
+ prompt_tokens=token_num,
+ total_tokens=token_num,
+ completion_tokens=None,
+ ),
+ ).dict(exclude_none=True)
+
+
+async def get_embedding(payload: Dict[str, Any]):
+ controller_address = app_settings.controller_address
+ model_name = payload["model"]
+ async with httpx.AsyncClient() as client:
+ worker_addr = await _get_worker_address(model_name, client)
+
+ response = await client.post(
+ worker_addr + "/worker_get_embeddings",
+ headers=headers,
+ json=payload,
+ timeout=WORKER_API_TIMEOUT,
+ )
+ embedding = response.json()
+ return embedding
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="FastChat ChatGPT-Compatible RESTful API server."
+ )
+ parser.add_argument("--host", type=str, default="localhost", help="host name")
+ parser.add_argument("--port", type=int, default=8000, help="port number")
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ parser.add_argument(
+ "--allow-credentials", action="store_true", help="allow credentials"
+ )
+ parser.add_argument(
+ "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
+ )
+ parser.add_argument(
+ "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
+ )
+ parser.add_argument(
+ "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
+ )
+ args = parser.parse_args()
+
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=args.allowed_origins,
+ allow_credentials=args.allow_credentials,
+ allow_methods=args.allowed_methods,
+ allow_headers=args.allowed_headers,
+ )
+ app_settings.controller_address = args.controller_address
+
+ logger.info(f"args: {args}")
+
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
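For readers wiring a client against this server, a minimal sketch of calling the chat completions route with `requests`. The port 8000 matches the default above; the model name "vicuna-7b" is an assumption — use whatever model the controller actually serves:

```python
# Minimal client sketch for the OpenAI-compatible server above.
# Assumes the server runs on localhost:8000 and that a model named
# "vicuna-7b" is registered with the controller; adjust both as needed.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "vicuna-7b",
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.7,
        "max_tokens": 64,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```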
diff --git a/graphgpt/serve/register_worker.py b/graphgpt/serve/register_worker.py
new file mode 100644
index 0000000..2c2c402
--- /dev/null
+++ b/graphgpt/serve/register_worker.py
@@ -0,0 +1,26 @@
+"""
+Manually register workers.
+
+Usage:
+python3 -m fastchat.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--controller-address", type=str)
+ parser.add_argument("--worker-name", type=str)
+ parser.add_argument("--check-heart-beat", action="store_true")
+ args = parser.parse_args()
+
+ url = args.controller_address + "/register_worker"
+ data = {
+ "worker_name": args.worker_name,
+ "check_heart_beat": args.check_heart_beat,
+ "worker_status": None,
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
diff --git a/graphgpt/serve/test_message.py b/graphgpt/serve/test_message.py
new file mode 100644
index 0000000..203a449
--- /dev/null
+++ b/graphgpt/serve/test_message.py
@@ -0,0 +1,81 @@
+"""Send a test message."""
+import argparse
+import json
+
+import requests
+
+from fastchat.model.model_adapter import get_conversation_template
+
+
+def main():
+ model_name = args.model_name
+
+ if args.worker_address:
+ worker_addr = args.worker_address
+ else:
+ controller_addr = args.controller_address
+ ret = requests.post(controller_addr + "/refresh_all_workers")
+ ret = requests.post(controller_addr + "/list_models")
+ models = ret.json()["models"]
+ models.sort()
+ print(f"Models: {models}")
+
+ ret = requests.post(
+ controller_addr + "/get_worker_address", json={"model": model_name}
+ )
+ worker_addr = ret.json()["address"]
+ print(f"worker_addr: {worker_addr}")
+
+ if worker_addr == "":
+ print(f"No available workers for {model_name}")
+ return
+
+ conv = get_conversation_template(model_name)
+ conv.append_message(conv.roles[0], args.message)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ headers = {"User-Agent": "FastChat Client"}
+ gen_params = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": args.temperature,
+ "max_new_tokens": args.max_new_tokens,
+ "stop": conv.stop_str,
+ "stop_token_ids": conv.stop_token_ids,
+ "echo": False,
+ }
+ response = requests.post(
+ worker_addr + "/worker_generate_stream",
+ headers=headers,
+ json=gen_params,
+ stream=True,
+ )
+
+ print(f"{conv.roles[0]}: {args.message}")
+ print(f"{conv.roles[1]}: ", end="")
+ prev = 0
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode())
+ output = data["text"].strip()
+ print(output[prev:], end="", flush=True)
+ prev = len(output)
+ print("")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ parser.add_argument("--worker-address", type=str)
+ parser.add_argument("--model-name", type=str, required=True)
+ parser.add_argument("--temperature", type=float, default=0.0)
+ parser.add_argument("--max-new-tokens", type=int, default=32)
+ parser.add_argument(
+ "--message", type=str, default="Tell me a story with more than 1000 words."
+ )
+ args = parser.parse_args()
+
+ main()
diff --git a/graphgpt/serve/test_throughput.py b/graphgpt/serve/test_throughput.py
new file mode 100644
index 0000000..9cc5f45
--- /dev/null
+++ b/graphgpt/serve/test_throughput.py
@@ -0,0 +1,115 @@
+"""Benchmarking script to test the throughput of serving workers."""
+import argparse
+import json
+
+import requests
+import threading
+import time
+
+from fastchat.conversation import default_conversation
+
+
+def main():
+ if args.worker_address:
+ worker_addr = args.worker_address
+ else:
+ controller_addr = args.controller_address
+ ret = requests.post(controller_addr + "/refresh_all_workers")
+ ret = requests.post(controller_addr + "/list_models")
+ models = ret.json()["models"]
+ models.sort()
+ print(f"Models: {models}")
+
+ ret = requests.post(
+ controller_addr + "/get_worker_address", json={"model": args.model_name}
+ )
+ worker_addr = ret.json()["address"]
+ print(f"worker_addr: {worker_addr}")
+
+ if worker_addr == "":
+ return
+
+ conv = default_conversation.copy()
+ conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words")
+ prompt_template = conv.get_prompt()
+ prompts = [prompt_template for _ in range(args.n_thread)]
+
+ headers = {"User-Agent": "fastchat Client"}
+ ploads = [
+ {
+ "model": args.model_name,
+ "prompt": prompts[i],
+ "max_new_tokens": args.max_new_tokens,
+ "temperature": 0.0,
+ # "stop": conv.sep,
+ }
+ for i in range(len(prompts))
+ ]
+
+ def send_request(results, i):
+ if args.test_dispatch:
+ ret = requests.post(
+ controller_addr + "/get_worker_address", json={"model": args.model_name}
+ )
+ thread_worker_addr = ret.json()["address"]
+ else:
+ thread_worker_addr = worker_addr
+ print(f"thread {i} goes to {thread_worker_addr}")
+ response = requests.post(
+ thread_worker_addr + "/worker_generate_stream",
+ headers=headers,
+ json=ploads[i],
+ stream=False,
+ )
+ k = list(
+ response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")
+ )
+ # print(k)
+ response_new_words = json.loads(k[-2].decode("utf-8"))["text"]
+ error_code = json.loads(k[-2].decode("utf-8"))["error_code"]
+ # print(f"=== Thread {i} ===, words: {1}, error code: {error_code}")
+ results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" "))
+
+ # use N threads to prompt the backend
+ tik = time.time()
+ threads = []
+ results = [None] * args.n_thread
+ for i in range(args.n_thread):
+ t = threading.Thread(target=send_request, args=(results, i))
+ t.start()
+ # time.sleep(0.5)
+ threads.append(t)
+
+ for t in threads:
+ t.join()
+
+ print(f"Time (POST): {time.time() - tik} s")
+ # n_words = 0
+ # for i, response in enumerate(results):
+ # # print(prompt[i].replace(conv.sep, "\n"), end="")
+ # # make sure the streaming finishes at EOS or stopping criteria
+ # k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"))
+ # response_new_words = json.loads(k[-2].decode("utf-8"))["text"]
+ # # print(response_new_words)
+ # n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" "))
+ n_words = sum(results)
+ time_seconds = time.time() - tik
+ print(
+ f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, "
+ f"throughput: {n_words / time_seconds} words/s."
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
+ parser.add_argument("--worker-address", type=str)
+ parser.add_argument("--model-name", type=str, default="vicuna")
+ parser.add_argument("--max-new-tokens", type=int, default=2048)
+ parser.add_argument("--n-thread", type=int, default=8)
+ parser.add_argument("--test-dispatch", action="store_true")
+ args = parser.parse_args()
+
+ main()
diff --git a/graphgpt/train/graphchat_trainer.py b/graphgpt/train/graphchat_trainer.py
new file mode 100644
index 0000000..11e4b62
--- /dev/null
+++ b/graphgpt/train/graphchat_trainer.py
@@ -0,0 +1,49 @@
+import os
+import torch
+import torch.nn as nn
+
+from transformers import Trainer
+from typing import Dict, Optional, Sequence
+
+
+def unwrap_model(model: nn.Module) -> nn.Module:
+ """
+ Recursively unwraps a model from potential containers (as used in distributed training).
+
+ Args:
+ model (`torch.nn.Module`): The model to unwrap.
+ """
+ # since there could be multiple levels of wrapping, unwrap recursively
+ if hasattr(model, "module"):
+ return unwrap_model(model.module)
+ else:
+ return model
+
+
+class GraphChatTrainer(Trainer):
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ if getattr(self.args, 'tune_graph_mlp_adapter', False):
+ # Save the model
+ _state_dict = state_dict
+ if _state_dict is None:
+ # Only save the model itself if we are using distributed training
+ model_to_save = unwrap_model(self.model)
+ _state_dict = model_to_save.state_dict()
+
+ weight_to_save = {}
+ keys_to_match = ['graph_projector', 'embed_tokens', 'embed_in']
+ for k, v in _state_dict.items():
+ if any(key_match in k for key_match in keys_to_match):
+ weight_to_save[k] = v
+
+ current_folder = output_dir.split('/')[-1]
+ parent_folder = os.path.dirname(output_dir)
+ if current_folder.startswith('checkpoint-'):
+ mm_projector_folder = os.path.join(parent_folder, "graph_projector")
+ os.makedirs(mm_projector_folder, exist_ok=True)
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
+ else:
+ torch.save(weight_to_save, os.path.join(output_dir, 'graph_projector.bin'))
+
+ super(GraphChatTrainer, self)._save(output_dir, state_dict)
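The `_save` override above writes only the tensors whose names match `graph_projector`, `embed_tokens`, or `embed_in` into `graph_projector.bin`. A hedged sketch of loading that partial checkpoint back into an already-instantiated model (the path and helper name are illustrative, not part of this repo):

```python
# Sketch: re-loading the partial state dict written by GraphChatTrainer._save().
# The checkpoint path is illustrative; point it at your own output_dir.
import torch
from torch import nn

def load_graph_projector(model: nn.Module, ckpt_path: str) -> nn.Module:
    """Load the subset of weights saved by GraphChatTrainer._save()."""
    weights = torch.load(ckpt_path, map_location="cpu")
    # strict=False: the file intentionally holds only graph_projector /
    # embed_tokens / embed_in tensors, not the full model.
    model.load_state_dict(weights, strict=False)
    return model
```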
diff --git a/graphgpt/train/llama_flash_attn_monkey_patch.py b/graphgpt/train/llama_flash_attn_monkey_patch.py
new file mode 100644
index 0000000..00fc39e
--- /dev/null
+++ b/graphgpt/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,114 @@
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+from einops import rearrange
+
+from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel
+
+ attention_mask: [bsz, q_len]
+ """
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ # [bsz, q_len, nh, hd]
+ # [bsz, nh, q_len, hd]
+
+ kv_seq_len = key_states.shape[-2]
+ assert past_key_value is None, "past_key_value is not supported"
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+ # [bsz, nh, t, hd]
+ assert not output_attentions, "output_attentions is not supported"
+ assert not use_cache, "use_cache is not supported"
+
+ # Flash attention codes from
+ # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
+
+ # transform the data into the format required by flash attention
+ qkv = torch.stack(
+ [query_states, key_states, value_states], dim=2
+ ) # [bsz, nh, 3, q_len, hd]
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
+ # the attention_mask should be the same as the key_padding_mask
+ key_padding_mask = attention_mask
+
+ if key_padding_mask is None:
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
+ max_s = q_len
+ cu_q_lens = torch.arange(
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+ )
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+ else:
+ nheads = qkv.shape[-2]
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+ x_unpad = rearrange(
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+ )
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
+ x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = rearrange(
+ pad_input(
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+ ),
+ "b s (h d) -> b s h d",
+ h=nheads,
+ )
+ return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+ # [bsz, seq_len]
+ return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+ _prepare_decoder_attention_mask
+ )
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
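A minimal sketch of how this patch is typically applied: call it before the LLaMA weights are instantiated so that `LlamaAttention.forward` is already replaced. The checkpoint path is a placeholder:

```python
# Sketch: apply the monkey patch before constructing the model.
# The checkpoint path below is a placeholder.
import torch
import transformers

from graphgpt.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()
model = transformers.LlamaForCausalLM.from_pretrained(
    "path/to/llama-checkpoint",
    torch_dtype=torch.bfloat16,
)
```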
diff --git a/graphgpt/train/train_graph.py b/graphgpt/train/train_graph.py
new file mode 100644
index 0000000..00d0681
--- /dev/null
+++ b/graphgpt/train/train_graph.py
@@ -0,0 +1,966 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+from dataclasses import dataclass, field
+import json
+import logging
+import pathlib
+from typing import Dict, Optional, Sequence, List
+
+import torch
+
+import transformers
+from torch.utils.data import Dataset
+from graphgpt.train.graphchat_trainer import GraphChatTrainer
+
+from graphgpt import conversation as conversation_lib
+from graphgpt.model import *
+
+from PIL import Image
+import torch.nn as nn
+from torch_geometric.data import Data
+
+# TODO: import and use code from ../data/dataset.py
+
+IGNORE_INDEX = -100
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = "</s>"
+DEFAULT_BOS_TOKEN = "<s>"
+DEFAULT_UNK_TOKEN = "<unk>"
+DEFAULT_GRAPH_TOKEN = "<graph>"
+DEFAULT_GRAPH_PATCH_TOKEN = "<g_patch>"
+DEFAULT_G_START_TOKEN = "<g_start>"
+DEFAULT_G_END_TOKEN = "<g_end>"
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+ version: Optional[str] = field(default="v0")
+ freeze_backbone: bool = field(default=False)
+ tune_graph_mlp_adapter: bool = field(default=False)
+ graph_tower: Optional[str] = field(default=None)
+ graph_select_layer: Optional[int] = field(default=-1) # default to the last layer
+ pretrain_graph_mlp_adapter: Optional[str] = field(default=None)
+ use_graph_start_end: bool = field(default=False)
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default=None,
+ metadata={"help": "Path to the training data."})
+ lazy_preprocess: bool = False
+ is_graph: bool = False
+ sep_graph_conv_front: bool = False
+ graph_token_len: int = 0
+ graph_content: Optional[str] = field(default=None)
+ graph_data_path: Optional[str] = field(default=None)
+ image_aspect_ratio: str = 'square'
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ remove_unused_columns: bool = field(default=False)
+ freeze_graph_mlp_adapter: bool = field(default=False)
+ force_fsdp: bool = field(default=False)
+ model_max_length: int = field(
+ default=512,
+ metadata={
+ "help":
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+ double_quant: bool = field(
+ default=True,
+ metadata={"help": "Compress the quantization statistics through double quantization."}
+ )
+ quant_type: str = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+ )
+ bits: int = field(
+ default=16,
+ metadata={"help": "How many bits to use."}
+ )
+ lora_enable: bool = False
+ lora_r: int = 64
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ disable_tqdm: bool = False
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+ for k, t in maybe_lora_bias.items():
+ if k in lora_bias_names:
+ to_return[k] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def find_all_linear_names(model):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ for name, module in model.named_modules():
+ if isinstance(module, cls):
+ names = name.split('.')
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+
+ if 'lm_head' in lora_module_names: # needed for 16-bit
+ lora_module_names.remove('lm_head')
+ return list(lora_module_names)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+ output_dir: str):
+ """Collects the state dict and dump to disk."""
+ if trainer.deepspeed:
+ torch.cuda.synchronize()
+ trainer.save_model(output_dir)
+ return
+
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {
+ key: value.cpu()
+ for key, value in state_dict.items()
+ }
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
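A hedged sketch of how `smart_tokenizer_and_embedding_resize` is meant to be called when the tokenizer lacks a pad token; the checkpoint path is a placeholder, and `DEFAULT_PAD_TOKEN` is the constant defined above:

```python
# Sketch: register the pad token and resize embeddings so newly added
# special tokens start from the mean of the existing embeddings.
# The checkpoint path is a placeholder.
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("path/to/llama-checkpoint")
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")

if tokenizer.pad_token is None:
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
        tokenizer=tokenizer,
        model=model,
    )
```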
+def _tokenize_fn(strings: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ) for text in strings
+ ]
+ input_ids = labels = [
+ tokenized.input_ids[0] for tokenized in tokenized_list
+ ]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+ for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+
+def _mask_targets(target, tokenized_lens, speakers):
+ # cur_idx = 0
+ cur_idx = tokenized_lens[0]
+ tokenized_lens = tokenized_lens[1:]
+ target[:cur_idx] = IGNORE_INDEX
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if speaker == "human":
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
+ cur_idx += tokenized_len
+
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = conversation_lib.default_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = conversation_lib.default_conversation.roles[1]
+ else:
+ from_str = 'unknown'
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
+ sentence["value"] + END_SIGNAL)
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+
+def preprocess_graph(
+ sources: Sequence[str],
+ graph_cfg: dict,
+ cur_token_len: int,
+) -> Dict:
+ is_graph = graph_cfg['is_graph']
+ # image_token_len = multimodal_cfg['image_token_len']
+ graph_token_len = cur_token_len
+ if not is_graph:
+ return sources
+
+ for source in sources:
+ if graph_cfg['sep_graph_conv_front']:
+ assert DEFAULT_GRAPH_TOKEN in source[0]['value']
+ source[0]['value'] = source[0]['value'].replace(DEFAULT_GRAPH_TOKEN, '').strip()
+ source[0]['value'] = DEFAULT_GRAPH_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
+ for sentence in source:
+ replace_token = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len
+ if graph_cfg['use_graph_start_end']:
+ replace_token = DEFAULT_G_START_TOKEN + replace_token + DEFAULT_G_END_TOKEN
+ sentence["value"] = sentence["value"].replace(DEFAULT_GRAPH_TOKEN, replace_token)
+
+ return sources
+
+def preprocess_graph_LP(
+ sources: Sequence[str],
+ graph_cfg: dict,
+ cur_token_len_1: int,
+ cur_token_len_2: int,
+) -> Dict:
+ is_graph = graph_cfg['is_graph']
+ # image_token_len = multimodal_cfg['image_token_len']
+ graph_token_len_1 = cur_token_len_1
+ graph_token_len_2 = cur_token_len_2
+
+ if not is_graph:
+ return sources
+
+ for source in sources:
+ if graph_cfg['sep_graph_conv_front']:
+ assert DEFAULT_GRAPH_TOKEN in source[0]['value']
+ source[0]['value'] = source[0]['value'].replace(DEFAULT_GRAPH_TOKEN, '').strip()
+ source[0]['value'] = DEFAULT_GRAPH_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
+ for sentence in source:
+ replace_token_1 = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len_1
+ replace_token_2 = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len_2
+ if graph_cfg['use_graph_start_end']:
+ replace_token_1 = DEFAULT_G_START_TOKEN + replace_token_1 + DEFAULT_G_END_TOKEN
+ replace_token_2 = DEFAULT_G_START_TOKEN + replace_token_2 + DEFAULT_G_END_TOKEN
+
+ if DEFAULT_GRAPH_TOKEN in sentence["value"]:
+ first_index = sentence["value"].find(DEFAULT_GRAPH_TOKEN)
+ sentence["value"] = sentence["value"][:first_index] + replace_token_1 + sentence["value"][first_index+len(DEFAULT_GRAPH_TOKEN):]
+
+ # Replace the second placeholder with the tokens for the second graph
+ second_index = sentence["value"].find(DEFAULT_GRAPH_TOKEN)
+ sentence["value"] = sentence["value"][:second_index] + replace_token_2 + sentence["value"][second_index+len(DEFAULT_GRAPH_TOKEN):]
+
+
+ # sentence["value"] = sentence["value"].replace(DEFAULT_GRAPH_TOKEN, replace_token)
+
+ # print(sources)
+
+ return sources
+
+
+def preprocess_v1(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+def preprocess_mpt(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1]
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep)
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
+ for conv_idx in range(3, len(rounds), 2):
+ re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
+ cur_len = 0
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(re_rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+ round_len = len(tokenizer(rou).input_ids) + len(tokenizer(conv.sep).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids)
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ """
+ Given a list of sources, each is a conversation list. This transform:
+    1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversation;
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
+ """
+ if conversation_lib.default_conversation.version == "v1":
+ return preprocess_v1(sources, tokenizer)
+ if conversation_lib.default_conversation.version == "mpt":
+ return preprocess_mpt(sources, tokenizer)
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ header = f"{conversation_lib.default_conversation.system}\n\n"
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+ # tokenize conversations
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
+ input_ids = conversations_tokenized["input_ids"]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
+ tokenizer)["input_ids_lens"]
+ speakers = [sentence["from"] for sentence in source]
+ _mask_targets(target, tokenized_lens, speakers)
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+class SupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer):
+ super(SupervisedDataset, self).__init__()
+ logging.warning("Loading data...")
+ list_data_dict = json.load(open(data_path, "r"))
+
+ logging.warning("Formatting inputs...")
+ sources = [example["conversations"] for example in list_data_dict]
+ data_dict = preprocess(sources, tokenizer)
+
+ self.input_ids = data_dict["input_ids"]
+ self.labels = data_dict["labels"]
+
+ def __len__(self):
+ return len(self.input_ids)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+
+class LazySupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ graph_cfg: dict,
+ **kwargs,):
+ super(LazySupervisedDataset, self).__init__()
+ logging.warning("Loading data...")
+ list_data_dict = json.load(open(data_path, "r"))
+
+ logging.warning("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.list_data_dict = list_data_dict
+ self.graph_cfg = graph_cfg
+ graph_data_path = kwargs.get('graph_data_path')
+ self.graph_data_all = torch.load(graph_data_path)
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ sources = self.list_data_dict[i]
+ if isinstance(i, int):
+ sources = [sources]
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+
+ task_type = self.list_data_dict[i]['id'].split("_")[-1]
+ if task_type != 'LP':
+ if 'graph' in sources[0]:
+ graph_dict = self.list_data_dict[i]['graph']
+ graph_edge_index = torch.Tensor(copy.deepcopy(graph_dict['edge_index'])).long()
+ graph_node_list = copy.deepcopy(graph_dict['node_list'])
+ target_node = copy.deepcopy(graph_dict['node_idx'])
+ graph_type = copy.deepcopy(self.list_data_dict[i]['id']).split('_')[0]
+ graph_node_rep = self.graph_data_all[graph_type].x[graph_node_list] ##
+
+                cur_token_len = len(graph_node_rep)   # one graph token per node in the sub-graph
+ sources = preprocess_graph(
+ copy.deepcopy([e["conversations"] for e in sources]),
+ self.graph_cfg, cur_token_len)
+ else:
+ sources = copy.deepcopy([e["conversations"] for e in sources])
+ else:
+ if 'graph' in sources[0]:
+ graph_dict = self.list_data_dict[i]['graph']
+ graph_edge_index_1 = torch.Tensor(copy.deepcopy(graph_dict['edge_index_1'])).long()
+ graph_node_list_1 = copy.deepcopy(graph_dict['node_list_1'])
+ target_node_1 = copy.deepcopy(graph_dict['node_idx_1'])
+ graph_type = copy.deepcopy(self.list_data_dict[i]['id']).split('_')[0]
+ graph_node_rep_1 = self.graph_data_all[graph_type].x[graph_node_list_1] ##
+
+                cur_token_len_1 = len(graph_node_rep_1)   # one graph token per node in the first sub-graph
+
+ graph_edge_index_2 = torch.Tensor(copy.deepcopy(graph_dict['edge_index_2'])).long()
+ graph_node_list_2 = copy.deepcopy(graph_dict['node_list_2'])
+ target_node_2 = copy.deepcopy(graph_dict['node_idx_2'])
+ graph_node_rep_2 = self.graph_data_all[graph_type].x[graph_node_list_2] ##
+
+                cur_token_len_2 = len(graph_node_rep_2)   # one graph token per node in the second sub-graph
+ sources = preprocess_graph_LP(
+ copy.deepcopy([e["conversations"] for e in sources]),
+ self.graph_cfg, cur_token_len_1, cur_token_len_2)
+ else:
+ sources = copy.deepcopy([e["conversations"] for e in sources])
+ data_dict = preprocess(
+ sources,
+ self.tokenizer)
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
+ labels=data_dict["labels"][0])
+
+        # graph exists in the data
+ if task_type != 'LP':
+ if 'graph' in self.list_data_dict[i]:
+ # data_dict['graph_node'] = graph_node_rep
+ # data_dict['graph_edge'] = graph_edge_index
+ # data_dict['target_node'] = target_node
+ data_dict['graph_data'] = Data(graph_node = graph_node_rep, edge_index=graph_edge_index, target_node = torch.tensor([target_node]))
+
+ elif self.graph_cfg['is_graph']:
+            # no graph in this sample, but the model is graph-based, so attach a dummy graph
+ node_feas = self.graph_cfg['graph_processor'].node_feas
+ data_dict['graph_data'] = Data(graph_node = torch.zeros(3, node_feas), edge_index=torch.zeros(2, 3), target_node = torch.tensor([0]))
+ else:
+ if 'graph' in self.list_data_dict[i]:
+ # data_dict['graph_node'] = graph_node_rep
+ # data_dict['graph_edge'] = graph_edge_index
+ # data_dict['target_node'] = target_node
+ data_dict['graph_data'] = {
+ 'graph_1': Data(graph_node = graph_node_rep_1, edge_index=graph_edge_index_1, target_node = torch.tensor([target_node_1])),
+ 'graph_2': Data(graph_node = graph_node_rep_2, edge_index=graph_edge_index_2, target_node = torch.tensor([target_node_2]))
+ }
+
+ elif self.graph_cfg['is_graph']:
+            # no graph in this sample, but the model is graph-based, so attach a dummy graph
+ node_feas = self.graph_cfg['graph_processor'].node_feas
+ data_dict['graph_data'] = Data(graph_node = torch.zeros(3, node_feas), edge_index=torch.zeros(2, 3), target_node = torch.tensor([0]))
+ return data_dict
+
+class LazySupervisedDataset_back(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ graph_cfg: dict,
+ **kwargs,):
+        super(LazySupervisedDataset_back, self).__init__()
+ logging.warning("Loading data...")
+ list_data_dict = json.load(open(data_path, "r"))
+
+ logging.warning("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.list_data_dict = list_data_dict
+ self.graph_cfg = graph_cfg
+ graph_data_path = kwargs.get('graph_data_path')
+ self.graph_data_all = torch.load(graph_data_path)
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ sources = self.list_data_dict[i]
+ if isinstance(i, int):
+ sources = [sources]
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+ if 'graph' in sources[0]:
+ graph_dict = self.list_data_dict[i]['graph']
+ graph_edge_index = torch.Tensor(copy.deepcopy(graph_dict['edge_index'])).long()
+ graph_node_list = copy.deepcopy(graph_dict['node_list'])
+ target_node = copy.deepcopy(graph_dict['node_idx'])
+ graph_type = copy.deepcopy(self.list_data_dict[i]['id']).split('_')[0]
+ graph_node_rep = self.graph_data_all[graph_type].x[graph_node_list] ##
+
+            cur_token_len = len(graph_node_rep)   # one graph token per node in the sub-graph
+ sources = preprocess_graph(
+ copy.deepcopy([e["conversations"] for e in sources]),
+ self.graph_cfg, cur_token_len)
+ else:
+ sources = copy.deepcopy([e["conversations"] for e in sources])
+ data_dict = preprocess(
+ sources,
+ self.tokenizer)
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
+ labels=data_dict["labels"][0])
+
+        # graph exists in the data
+ if 'graph' in self.list_data_dict[i]:
+ # data_dict['graph_node'] = graph_node_rep
+ # data_dict['graph_edge'] = graph_edge_index
+ # data_dict['target_node'] = target_node
+ data_dict['graph_data'] = Data(graph_node = graph_node_rep, edge_index=graph_edge_index, target_node = torch.tensor([target_node]))
+
+ elif self.graph_cfg['is_graph']:
+        # no graph in this sample, but the model is graph-based, so attach a dummy graph
+ node_feas = self.graph_cfg['graph_processor'].node_feas
+ data_dict['graph_data'] = Data(graph_node = torch.zeros(3, node_feas), edge_index=torch.zeros(2, 3), target_node = torch.tensor([0]))
+ return data_dict
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """Collate examples for supervised fine-tuning."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances]
+ for key in ("input_ids", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id)
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX)
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
+ if 'graph_data' in instances[0]:
+ # graph_node_reps = [instance['graph_node'] for instance in instances]
+ # edge_index_reps = [instance['graph_edge'] for instance in instances]
+ # target_node_reps = [instance['target_node'] for instance in instances]
+ graph_data_batch = [instance['graph_data'] for instance in instances]
+ # if all(x is not None and x.shape == images[0].shape for x in images):
+ # batch['images'] = torch.stack(images)
+ # else:
+ # batch['images'] = images
+ # batch['graph_node_reps'] = graph_node_reps
+ # batch['edge_index_reps'] = edge_index_reps
+ # batch['edge_index_reps'] = target_node_reps
+ batch['graph_data'] = graph_data_batch
+
+ return batch
+
+
+def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
+ data_args) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ dataset_cls = (LazySupervisedDataset
+ if data_args.lazy_preprocess else SupervisedDataset)
+ train_dataset = dataset_cls(tokenizer=tokenizer,
+ data_path=data_args.data_path,
+ graph_cfg=dict(
+ is_graph=data_args.is_graph,
+ sep_graph_conv_front=data_args.sep_graph_conv_front,
+ graph_token_len=data_args.graph_token_len,
+ graph_content=data_args.graph_content,
+ use_graph_start_end=getattr(data_args, 'use_graph_start_end', False)
+ ),
+ graph_data_path = data_args.graph_data_path)
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+ return dict(train_dataset=train_dataset,
+ eval_dataset=None,
+ data_collator=data_collator)
+
+def train():
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+
+ bnb_model_from_pretrained_args = {}
+
+    ## load the model in 4-bit or 8-bit precision
+ if training_args.bits in [4, 8]:
+ from transformers import BitsAndBytesConfig
+ from peft import prepare_model_for_int8_training
+ bnb_model_from_pretrained_args.update(dict(
+ device_map={"": training_args.device},
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=training_args.double_quant,
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
+ )
+ ))
+
+ if model_args.graph_tower is not None:
+ model = GraphLlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args
+ ) ## TODO: add real Graph Llama model
+ else:
+ model = transformers.LlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args
+ )
+ model.config.pretrain_graph_model_path = model.config.pretrain_graph_model_path + model_args.graph_tower
+ model.config.use_cache = False
+
+ if model_args.freeze_backbone:
+ model.model.requires_grad_(False)
+
+ if training_args.bits in [4, 8]:
+ model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+ model = prepare_model_for_int8_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+
+ if training_args.gradient_checkpointing and model_args.graph_tower is None:
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if training_args.lora_enable:
+ from peft import LoraConfig, get_peft_model
+ lora_config = LoraConfig(
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ target_modules=find_all_linear_names(model),
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ if training_args.bits == 16:
+ if training_args.bf16:
+ model.to(torch.bfloat16)
+ if training_args.fp16:
+ model.to(torch.float16)
+ logging.warning("Adding LoRA adapters...")
+ model = get_peft_model(model, lora_config)
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+
+ if model_args.version == "v0":
+ if tokenizer.pad_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
+ tokenizer=tokenizer,
+ model=model,
+ )
+ if "llama" in model_args.model_name_or_path:
+ tokenizer.add_special_tokens({
+ "eos_token": DEFAULT_EOS_TOKEN,
+ "bos_token": DEFAULT_BOS_TOKEN,
+ "unk_token": DEFAULT_UNK_TOKEN,
+ })
+ else:
+ tokenizer.pad_token = tokenizer.unk_token
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"]
+
+ if model_args.graph_tower is not None:
+ model_graph_dict = model.get_model().initialize_graph_modules(
+ graph_tower=model_args.graph_tower,
+ graph_select_layer=model_args.graph_select_layer,
+ pretrain_graph_mlp_adapter=model_args.pretrain_graph_mlp_adapter,
+ fsdp=training_args.fsdp
+ )
+ model.get_graph_tower().to(dtype=torch.float16, device=training_args.device)
+ # graph_config = model_graph_dict['graph_config']
+
+ # data_args.graph_token_len = model_graph_dict['graph_token_len']
+ # data_args.graph_processor = model_graph_dict['graph_processor']
+ data_args.is_graph = True
+
+ model.config.tune_graph_mlp_adapter = training_args.tune_graph_mlp_adapter = model_args.tune_graph_mlp_adapter
+ if model_args.tune_graph_mlp_adapter:
+ model.requires_grad_(False)
+ for p in model.get_model().graph_projector.parameters():
+ p.requires_grad = True
+
+ model.config.freeze_graph_mlp_adapter = training_args.freeze_graph_mlp_adapter
+ if training_args.freeze_graph_mlp_adapter:
+ for p in model.get_model().graph_projector.parameters():
+ p.requires_grad = False
+
+ if training_args.bits in [4, 8]:
+ model.get_model().graph_projector.to(dtype=compute_dtype, device=training_args.device)
+
+ model.config.use_graph_start_end = data_args.use_graph_start_end = model_args.use_graph_start_end
+ # graph_config.use_graph_start_end = training_args.use_graph_start_end = model_args.use_graph_start_end
+ training_args.use_graph_start_end = model_args.use_graph_start_end
+ model.config.sep_graph_conv_front = data_args.sep_graph_conv_front
+ model.initialize_graph_tokenizer(use_graph_start_end=model_args.use_graph_start_end, tokenizer=tokenizer, device=training_args.device,
+ tune_graph_mlp_adapter=model_args.tune_graph_mlp_adapter, pretrain_graph_mlp_adapter=model_args.pretrain_graph_mlp_adapter)
+
+ params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
+ if len(params_no_grad) > 0:
+ if training_args.fsdp is not None and len(training_args.fsdp) > 0:
+ if len(params_no_grad) < 10:
+            if len(params_no_grad) < 10:
+                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(len(params_no_grad), params_no_grad))
+            else:
+                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'.format(len(params_no_grad), ', '.join(params_no_grad[:10])))
+            print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.")
+ print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
+
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+ def patch_FSDP_use_orig_params(func):
+ def wrap_func(*args, **kwargs):
+ use_orig_params = kwargs.pop('use_orig_params', True)
+ return func(*args, **kwargs, use_orig_params=use_orig_params)
+ return wrap_func
+
+ FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
+
+ if training_args.bits in [4, 8]:
+ from peft.tuners.lora import LoraLayer
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if training_args.bf16:
+ module = module.to(torch.bfloat16)
+ if 'norm' in name:
+ module = module.to(torch.float32)
+ if 'lm_head' in name or 'embed_tokens' in name:
+ if hasattr(module, 'weight'):
+ if training_args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
+ data_args=data_args)
+ trainer = GraphChatTrainer(model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ **data_module)
+
+ print('************************** parameters: #', sum(p.numel() for p in model.parameters() if p.requires_grad))
+ tuned_params = []
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ tuned_params.append(name)
+ print(tuned_params)
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+
+ if training_args.lora_enable:
+ state_dict = get_peft_state_maybe_zero_3(
+ model.named_parameters(), training_args.lora_bias
+ )
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
+ model.named_parameters()
+ )
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+ model.config.save_pretrained(training_args.output_dir)
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
+ else:
+ safe_save_model_for_hf_trainer(trainer=trainer,
+ output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/graphgpt/train/train_lora.py b/graphgpt/train/train_lora.py
new file mode 100644
index 0000000..ac52f81
--- /dev/null
+++ b/graphgpt/train/train_lora.py
@@ -0,0 +1,157 @@
+# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>
+
+# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+import logging
+import pathlib
+import typing
+
+from deepspeed import zero
+from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from peft import LoraConfig, get_peft_model
+import transformers
+from transformers import Trainer
+
+from fastchat.train.train import (
+ DataArguments,
+ ModelArguments,
+ TrainingArguments,
+ make_supervised_data_module,
+)
+
+from fastchat.train.llama_flash_attn_monkey_patch import (
+ replace_llama_attn_with_flash_attn,
+)
+
+replace_llama_attn_with_flash_attn()
+
+
+@dataclass
+class LoraArguments:
+ lora_r: int = 8
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_target_modules: typing.List[str] = field(
+ default_factory=lambda: ["q_proj", "v_proj"]
+ )
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+
+
+def maybe_zero_3(param):
+ if hasattr(param, "ds_id"):
+ assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+        for k, t in maybe_lora_bias.items():
+            if k in lora_bias_names:
+                to_return[k] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
+ return to_return
+
+
+def train():
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
+ )
+ (
+ model_args,
+ data_args,
+ training_args,
+ lora_args,
+ ) = parser.parse_args_into_dataclasses()
+
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ )
+ lora_config = LoraConfig(
+ r=lora_args.lora_r,
+ lora_alpha=lora_args.lora_alpha,
+ target_modules=lora_args.lora_target_modules,
+ lora_dropout=lora_args.lora_dropout,
+ bias=lora_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ model = get_peft_model(model, lora_config)
+ if training_args.deepspeed is not None and training_args.local_rank == 0:
+ model.print_trainable_parameters()
+
+ if training_args.gradient_checkpointing:
+ logging.warning(
+ "gradient checkpointing with lora makes requires_grad "
+ "incorrect and needs a monkey patch in Trainer or the "
+ "wrapped model's forward. ref: "
+ "https://github.com/lm-sys/FastChat/pull/138#issuecomment-1509172198"
+ )
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+ tokenizer.pad_token = tokenizer.unk_token
+
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+ trainer = Trainer(
+ model=model, tokenizer=tokenizer, args=training_args, **data_module
+ )
+
+ model.config.use_cache = False
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+
+ # Save states. Weights might be a placeholder in zero3 and need a gather
+ state_dict = get_peft_state_maybe_zero_3(
+ model.named_parameters(), lora_args.lora_bias
+ )
+ if training_args.local_rank == 0:
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/graphgpt/train/train_mem.py b/graphgpt/train/train_mem.py
new file mode 100644
index 0000000..bfe35ea
--- /dev/null
+++ b/graphgpt/train/train_mem.py
@@ -0,0 +1,13 @@
+# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
+
+# Need to call this before importing transformers.
+from graphgpt.train.llama_flash_attn_monkey_patch import (
+ replace_llama_attn_with_flash_attn,
+)
+
+replace_llama_attn_with_flash_attn()
+
+from graphgpt.train.train_graph import train
+
+if __name__ == "__main__":
+ train()
diff --git a/graphgpt/utils.py b/graphgpt/utils.py
new file mode 100644
index 0000000..75188b0
--- /dev/null
+++ b/graphgpt/utils.py
@@ -0,0 +1,240 @@
+from asyncio import AbstractEventLoop
+import json
+import logging
+import logging.handlers
+import os
+import platform
+import sys
+from typing import AsyncGenerator, Generator
+import warnings
+
+import requests
+import torch
+
+from fastchat.constants import LOGDIR
+
+
+handler = None
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+def build_logger(logger_name, logger_filename):
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ if sys.version_info[1] >= 9:
+            # Python >= 3.9 supports the encoding argument (ensures UTF-8 logs, notably on Windows)
+ logging.basicConfig(level=logging.INFO, encoding="utf-8")
+ else:
+ if platform.system() == "Windows":
+ warnings.warn(
+ "If you are running on Windows, "
+ "we recommend you use Python >= 3.9 for UTF-8 encoding."
+ )
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # Add a file handler for all loggers
+ if handler is None:
+ os.makedirs(LOGDIR, exist_ok=True)
+ filename = os.path.join(LOGDIR, logger_filename)
+ handler = logging.handlers.TimedRotatingFileHandler(
+ filename, when="D", utc=True, encoding="utf-8"
+ )
+ handler.setFormatter(formatter)
+
+ for name, item in logging.root.manager.loggerDict.items():
+ if isinstance(item, logging.Logger):
+ item.addHandler(handler)
+
+ return logger
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ""
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ""
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == "\n":
+ encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
+ self.logger.log(self.log_level, encoded_message.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != "":
+ encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
+ self.logger.log(self.log_level, encoded_message.rstrip())
+ self.linebuf = ""
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def get_gpu_memory(max_gpus=None):
+ """Get available memory for each GPU."""
+ gpu_memory = []
+ num_gpus = (
+ torch.cuda.device_count()
+ if max_gpus is None
+ else min(max_gpus, torch.cuda.device_count())
+ )
+
+ for gpu_id in range(num_gpus):
+ with torch.cuda.device(gpu_id):
+ device = torch.cuda.current_device()
+ gpu_properties = torch.cuda.get_device_properties(device)
+ total_memory = gpu_properties.total_memory / (1024**3)
+ allocated_memory = torch.cuda.memory_allocated() / (1024**3)
+ available_memory = total_memory - allocated_memory
+ gpu_memory.append(available_memory)
+ return gpu_memory
+
+
+def violates_moderation(text):
+ """
+ Check whether the text violates OpenAI moderation API.
+ """
+ url = "https://api.openai.com/v1/moderations"
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
+ }
+ text = text.replace("\n", "")
+ data = "{" + '"input": ' + f'"{text}"' + "}"
+ data = data.encode("utf-8")
+ try:
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
+ flagged = ret.json()["results"][0]["flagged"]
+ except requests.exceptions.RequestException as e:
+ flagged = False
+ except KeyError as e:
+ flagged = False
+
+ return flagged
+
+
+# Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings.
+# Use this function to make sure they can be correctly loaded.
+def clean_flant5_ckpt(ckpt_path):
+ index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
+ index_json = json.load(open(index_file, "r"))
+
+ weightmap = index_json["weight_map"]
+
+ share_weight_file = weightmap["shared.weight"]
+ share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[
+ "shared.weight"
+ ]
+
+ for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]:
+ weight_file = weightmap[weight_name]
+ weight = torch.load(os.path.join(ckpt_path, weight_file))
+ weight[weight_name] = share_weight
+ torch.save(weight, os.path.join(ckpt_path, weight_file))
+
+
+def pretty_print_semaphore(semaphore):
+ """Print a semaphore in better format."""
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+
+
+"""A javascript function to get url parameters for the gradio web server."""
+get_window_url_params_js = """
+function() {
+ const params = new URLSearchParams(window.location.search);
+ url_params = Object.fromEntries(params);
+ console.log("url_params", url_params);
+ return url_params;
+ }
+"""
+
+
+def iter_over_async(
+ async_gen: AsyncGenerator, event_loop: AbstractEventLoop
+) -> Generator:
+ """
+ Convert async generator to sync generator
+
+ :param async_gen: the AsyncGenerator to convert
+ :param event_loop: the event loop to run on
+ :returns: Sync generator
+ """
+ ait = async_gen.__aiter__()
+
+ async def get_next():
+ try:
+ obj = await ait.__anext__()
+ return False, obj
+ except StopAsyncIteration:
+ return True, None
+
+ while True:
+ done, obj = event_loop.run_until_complete(get_next())
+ if done:
+ break
+ yield obj
+
+
+def detect_language(text: str) -> str:
+ """Detect the langauge of a string."""
+ import polyglot # pip3 install polyglot pyicu pycld2
+ from polyglot.detect import Detector
+ from polyglot.detect.base import logger as polyglot_logger
+ import pycld2
+
+ polyglot_logger.setLevel("ERROR")
+
+ try:
+ lang_code = Detector(text).language.name
+ except (pycld2.error, polyglot.detect.base.UnknownLanguage):
+ lang_code = "unknown"
+ return lang_code
diff --git a/images/.DS_Store b/images/.DS_Store
new file mode 100644
index 0000000..58595eb
Binary files /dev/null and b/images/.DS_Store differ
diff --git a/images/graphgpt.png b/images/graphgpt.png
new file mode 100644
index 0000000..a73ef65
Binary files /dev/null and b/images/graphgpt.png differ
diff --git a/playground/inspect_conv.py b/playground/inspect_conv.py
new file mode 100644
index 0000000..99306e5
--- /dev/null
+++ b/playground/inspect_conv.py
@@ -0,0 +1,87 @@
+import argparse
+import code
+import datetime
+import json
+import os
+from pytz import timezone
+import time
+
+import pandas as pd
+from tqdm import tqdm
+
+
+def get_log_files(max_num_files=None):
+ dates = []
+ for month in [4, 5]:
+ for day in range(1, 32):
+ dates.append(f"2023-{month:02d}-{day:02d}")
+
+ num_servers = 12
+ filenames = []
+ for d in dates:
+ for i in range(num_servers):
+ name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
+ if os.path.exists(name):
+ filenames.append(name)
+ max_num_files = max_num_files or len(filenames)
+ filenames = filenames[-max_num_files:]
+ return filenames
+
+
+def pretty_print_conversation(messages):
+ for role, msg in messages:
+ print(f"[[{role}]]: {msg}")
+
+
+def inspect_convs(log_files):
+ data = []
+ for filename in tqdm(log_files, desc="read files"):
+ for retry in range(5):
+ try:
+ lines = open(filename).readlines()
+ break
+ except FileNotFoundError:
+ time.sleep(2)
+
+ for l in lines:
+ row = json.loads(l)
+
+ if "states" not in row:
+ continue
+ if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]:
+ continue
+
+ model_names = row["states"][0]["model_name"], row["states"][1]["model_name"]
+ if row["type"] == "leftvote":
+ winner, loser = model_names[0], model_names[1]
+ winner_conv, loser_conv = row["states"][0], row["states"][1]
+ elif row["type"] == "rightvote":
+ loser, winner = model_names[0], model_names[1]
+ loser_conv, winner_conv = row["states"][0], row["states"][1]
+
+ if loser == "bard" and winner == "vicuna-13b":
+ print("=" * 20)
+ print(f"Winner: {winner}")
+ pretty_print_conversation(winner_conv["messages"])
+ print(f"Loser: {loser}")
+ pretty_print_conversation(loser_conv["messages"])
+ print("=" * 20)
+ input()
+
+ # if row["type"] == "bothbad_vote" and "gpt-4" in model_names:
+ # print("=" * 20)
+ # print(f"Model A: {model_names[0]}")
+ # pretty_print_conversation(row["states"][0]["messages"])
+ # print(f"Model B: {model_names[1]}")
+ # pretty_print_conversation(row["states"][1]["messages"])
+ # print("=" * 20)
+ # input()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--max-num-files", type=int)
+ args = parser.parse_args()
+
+ log_files = get_log_files(args.max_num_files)
+ inspect_convs(log_files)
diff --git a/playground/test_embedding/README.md b/playground/test_embedding/README.md
new file mode 100644
index 0000000..57ac73c
--- /dev/null
+++ b/playground/test_embedding/README.md
@@ -0,0 +1,15 @@
+## Machine Learning with Embeddings
+You can use embeddings to
+- Evaluate text similarity, see [test_sentence_similarity.py](test_sentence_similarity.py)
+- Build your own classifier, see [test_classification.py](test_classification.py)
+- Search relative texts, see [test_semantic_search.py](test_semantic_search.py)
+
+To run these tests, you need to download the data [here](https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews). You also need an OpenAI API key for comparison.
+
+Run with:
+```bash
+cd playground/test_embedding
+python3 test_classification.py
+```
+
+The script will train classifiers based on `vicuna-7b`, `text-similarity-ada-001` and `text-embedding-ada-002` and report the accuracy of each classifier.
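+
+To query the embedding endpoint directly, the snippet below is a minimal sketch (it assumes the OpenAI-compatible server is already running at `http://localhost:8000` and serving `vicuna-7b-v1.1`, mirroring the test scripts in this folder):
+
+```python
+import requests
+
+# Request one embedding from the local /v1/embeddings endpoint and print its dimension.
+resp = requests.post(
+    "http://localhost:8000/v1/embeddings",
+    headers={"Content-Type": "application/json"},
+    json={"model": "vicuna-7b-v1.1", "input": "delicious beans"},
+)
+embedding = resp.json()["data"][0]["embedding"]
+print(len(embedding))
+```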
diff --git a/playground/test_embedding/test_classification.py b/playground/test_embedding/test_classification.py
new file mode 100644
index 0000000..393827b
--- /dev/null
+++ b/playground/test_embedding/test_classification.py
@@ -0,0 +1,83 @@
+import json
+import os
+
+import numpy as np
+import openai
+import pandas as pd
+import requests
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import classification_report, accuracy_score
+
+
+np.set_printoptions(threshold=10000)
+
+
+def get_embedding_from_api(word, model="vicuna-7b-v1.1"):
+ if "ada" in model:
+ resp = openai.Embedding.create(
+ model=model,
+ input=word,
+ )
+ embedding = np.array(resp["data"][0]["embedding"])
+ return embedding
+
+ url = "http://localhost:8000/v1/embeddings"
+ headers = {"Content-Type": "application/json"}
+ data = json.dumps({"model": model, "input": word})
+
+ response = requests.post(url, headers=headers, data=data)
+ if response.status_code == 200:
+ embedding = np.array(response.json()["data"][0]["embedding"])
+ return embedding
+ else:
+ print(f"Error: {response.status_code} - {response.text}")
+ return None
+
+
+def create_embedding_data_frame(data_path, model, max_tokens=500):
+ df = pd.read_csv(data_path, index_col=0)
+ df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]]
+ df = df.dropna()
+ df["combined"] = (
+ "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip()
+ )
+ top_n = 1000
+ df = df.sort_values("Time").tail(top_n * 2)
+ df.drop("Time", axis=1, inplace=True)
+
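+    # NOTE: character count is used here as a cheap proxy for token length.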
+ df["n_tokens"] = df.combined.apply(lambda x: len(x))
+ df = df[df.n_tokens <= max_tokens].tail(top_n)
+ df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model))
+ return df
+
+
+def train_random_forest(df):
+ X_train, X_test, y_train, y_test = train_test_split(
+ list(df.embedding.values), df.Score, test_size=0.2, random_state=42
+ )
+
+ clf = RandomForestClassifier(n_estimators=100)
+ clf.fit(X_train, y_train)
+ preds = clf.predict(X_test)
+
+ report = classification_report(y_test, preds)
+ accuracy = accuracy_score(y_test, preds)
+ return clf, accuracy, report
+
+
+input_datapath = "amazon_fine_food_review.csv"
+if not os.path.exists(input_datapath):
+ raise Exception(
+ f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews"
+ )
+
+df = create_embedding_data_frame(input_datapath, "vicuna-7b-v1.1")
+clf, accuracy, report = train_random_forest(df)
+print(f"Vicuna-7b-v1.1 accuracy:{accuracy}")
+df = create_embedding_data_frame(input_datapath, "text-similarity-ada-001")
+clf, accuracy, report = train_random_forest(df)
+print(f"text-similarity-ada-001 accuracy:{accuracy}")
+df = create_embedding_data_frame(input_datapath, "text-embedding-ada-002")
+clf, accuracy, report = train_random_forest(df)
+print(f"text-embedding-ada-002 accuracy:{accuracy}")
diff --git a/playground/test_embedding/test_semantic_search.py b/playground/test_embedding/test_semantic_search.py
new file mode 100644
index 0000000..879b240
--- /dev/null
+++ b/playground/test_embedding/test_semantic_search.py
@@ -0,0 +1,99 @@
+import json
+import os
+
+import numpy as np
+import openai
+import pandas as pd
+import requests
+from scipy.spatial.distance import cosine
+
+
+def cosine_similarity(vec1, vec2):
+ try:
+ return 1 - cosine(vec1, vec2)
+    except Exception:
+ print(vec1.shape, vec2.shape)
+
+
+def get_embedding_from_api(word, model="vicuna-7b-v1.1"):
+ if "ada" in model:
+ resp = openai.Embedding.create(
+ model=model,
+ input=word,
+ )
+ embedding = np.array(resp["data"][0]["embedding"])
+ return embedding
+
+ url = "http://localhost:8000/v1/embeddings"
+ headers = {"Content-Type": "application/json"}
+ data = json.dumps({"model": model, "input": word})
+
+ response = requests.post(url, headers=headers, data=data)
+ if response.status_code == 200:
+ embedding = np.array(response.json()["data"][0]["embedding"])
+ return embedding
+ else:
+ print(f"Error: {response.status_code} - {response.text}")
+ return None
+
+
+def create_embedding_data_frame(data_path, model, max_tokens=500):
+ df = pd.read_csv(data_path, index_col=0)
+ df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]]
+ df = df.dropna()
+ df["combined"] = (
+ "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip()
+ )
+ top_n = 1000
+ df = df.sort_values("Time").tail(top_n * 2)
+ df.drop("Time", axis=1, inplace=True)
+
+ df["n_tokens"] = df.combined.apply(lambda x: len(x))
+ df = df[df.n_tokens <= max_tokens].tail(top_n)
+ df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model))
+ return df
+
+
+def search_reviews(df, product_description, n=3, pprint=False, model="vicuna-7b-v1.1"):
+ product_embedding = get_embedding_from_api(product_description, model=model)
+ df["similarity"] = df.embedding.apply(
+ lambda x: cosine_similarity(x, product_embedding)
+ )
+
+ results = (
+ df.sort_values("similarity", ascending=False)
+ .head(n)
+ .combined.str.replace("Title: ", "")
+ .str.replace("; Content:", ": ")
+ )
+ if pprint:
+ for r in results:
+ print(r[:200])
+ print()
+ return results
+
+
+def print_model_search(input_path, model):
+ print(f"Model: {model}")
+ df = create_embedding_data_frame(input_path, model)
+ print("search: delicious beans")
+ results = search_reviews(df, "delicious beans", n=5, model=model)
+ print(results)
+ print("search: whole wheat pasta")
+ results = search_reviews(df, "whole wheat pasta", n=5, model=model)
+ print(results)
+ print("search: bad delivery")
+ results = search_reviews(df, "bad delivery", n=5, model=model)
+ print(results)
+
+
+input_datapath = "amazon_fine_food_review.csv"
+if not os.path.exists(input_datapath):
+ raise Exception(
+ f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews"
+ )
+
+
+print_model_search(input_datapath, "vicuna-7b-v1.1")
+print_model_search(input_datapath, "text-similarity-ada-001")
+print_model_search(input_datapath, "text-embedding-ada-002")
diff --git a/playground/test_embedding/test_sentence_similarity.py b/playground/test_embedding/test_sentence_similarity.py
new file mode 100644
index 0000000..0b9a540
--- /dev/null
+++ b/playground/test_embedding/test_sentence_similarity.py
@@ -0,0 +1,67 @@
+import json
+import os
+
+import numpy as np
+import openai
+import requests
+from scipy.spatial.distance import cosine
+
+
+def get_embedding_from_api(word, model="vicuna-7b-v1.1"):
+ if "ada" in model:
+ resp = openai.Embedding.create(
+ model=model,
+ input=word,
+ )
+ embedding = np.array(resp["data"][0]["embedding"])
+ return embedding
+
+ url = "http://localhost:8000/v1/embeddings"
+ headers = {"Content-Type": "application/json"}
+ data = json.dumps({"model": model, "input": word})
+
+ response = requests.post(url, headers=headers, data=data)
+ if response.status_code == 200:
+ embedding = np.array(response.json()["data"][0]["embedding"])
+ return embedding
+ else:
+ print(f"Error: {response.status_code} - {response.text}")
+ return None
+
+
+def cosine_similarity(vec1, vec2):
+ return 1 - cosine(vec1, vec2)
+
+
+def print_cosine_similarity(embeddings, texts):
+ for i in range(len(texts)):
+ for j in range(i + 1, len(texts)):
+ sim = cosine_similarity(embeddings[texts[i]], embeddings[texts[j]])
+ print(f"Cosine similarity between '{texts[i]}' and '{texts[j]}': {sim:.2f}")
+
+
+texts = [
+ "The quick brown fox",
+ "The quick brown dog",
+ "The fast brown fox",
+ "A completely different sentence",
+]
+
+embeddings = {}
+for text in texts:
+ embeddings[text] = get_embedding_from_api(text)
+
+print("Vicuna-7B:")
+print_cosine_similarity(embeddings, texts)
+
+for text in texts:
+ embeddings[text] = get_embedding_from_api(text, model="text-similarity-ada-001")
+
+print("text-similarity-ada-001:")
+print_cosine_similarity(embeddings, texts)
+
+for text in texts:
+ embeddings[text] = get_embedding_from_api(text, model="text-embedding-ada-002")
+
+print("text-embedding-ada-002:")
+print_cosine_similarity(embeddings, texts)
diff --git a/playground/test_openai_api/anthropic_api.py b/playground/test_openai_api/anthropic_api.py
new file mode 100644
index 0000000..c60a701
--- /dev/null
+++ b/playground/test_openai_api/anthropic_api.py
@@ -0,0 +1,27 @@
+import os
+
+from fastchat.model import get_conversation_template
+
+
+def claude():
+ import anthropic
+ c = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
+
+ model = "claude-v1"
+ conv = get_conversation_template(model)
+ conv.append_message(conv.roles[0], "Hello!")
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ response = c.completion_stream(
+ prompt=prompt,
+ stop_sequences=[anthropic.HUMAN_PROMPT],
+ max_tokens_to_sample=256,
+ model=model,
+ stream=True,
+ )
+ for data in response:
+ print(data["completion"])
+
+
+claude()
diff --git a/playground/test_openai_api/openai_api.py b/playground/test_openai_api/openai_api.py
new file mode 100644
index 0000000..a36c680
--- /dev/null
+++ b/playground/test_openai_api/openai_api.py
@@ -0,0 +1,26 @@
+import os
+
+from fastchat.model import get_conversation_template
+
+def chatgpt():
+ import openai
+ model = "gpt-3.5-turbo"
+ conv = get_conversation_template(model)
+ conv.append_message(conv.roles[0], "Hello!")
+ conv.append_message(conv.roles[1], None)
+
+ messages = conv.to_openai_api_messages()
+ print(messages)
+
+ res = openai.ChatCompletion.create(model=model, messages=messages)
+ msg = res["choices"][0]["message"]["content"]
+ print(msg)
+
+ res = openai.ChatCompletion.create(model=model, messages=messages, stream=True)
+ msg = ""
+ for chunk in res:
+ msg += chunk["choices"][0]["delta"].get("content", "")
+ print(msg)
+
+
+chatgpt()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a6b1f63
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,209 @@
+absl-py==1.4.0
+accelerate==0.21.0
+aiofiles==23.1.0
+aiohttp==3.8.4
+aiosignal==1.3.1
+altair==5.0.1
+anthropic==0.3.9
+anyio==3.7.0
+appdirs==1.4.4
+aspy.yaml==1.3.0
+astor==0.8.1
+astroid==2.5.1
+asttokens==2.2.1
+async-timeout==4.0.2
+attrs==20.3.0
+backcall==0.2.0
+bitsandbytes==0.39.1
+black==23.3.0
+blessed==1.20.0
+blinker==1.6.2
+cached-property==1.5.2
+cachetools==5.3.1
+certifi==2020.12.5
+cfgv==3.2.0
+chardet==4.0.0
+charset-normalizer==3.1.0
+click==8.1.3
+contourpy==1.1.0
+cpm-kernels==1.0.11
+cycler==0.11.0
+datasets==2.10.1
+decorator==4.4.2
+deepspeed==0.10.0
+dill==0.3.4
+distlib==0.3.1
+distro==1.8.0
+docker-pycreds==0.4.0
+einops==0.6.1
+evaluate==0.4.0
+exceptiongroup==1.1.1
+executing==1.2.0
+fastapi==0.98.0
+ffmpy==0.3.0
+filelock==3.0.12
+fire==0.5.0
+flash-attn==1.0.4
+Flask==2.3.2
+Flask-Cors==4.0.0
+fonttools==4.40.0
+frozenlist==1.3.3
+fsspec==2023.6.0
+ftfy==6.1.1
+gast==0.4.0
+gitdb==4.0.10
+GitPython==3.1.31
+google-auth==2.22.0
+google-auth-oauthlib==1.0.0
+gpustat==1.1
+gradio==3.23.0
+grpcio==1.56.2
+h11==0.14.0
+hjson==3.1.0
+httpcore==0.17.2
+httpx==0.24.1
+huggingface-hub==0.15.1
+icetk==0.0.7
+identify==2.2.0
+idna==2.10
+importlib-metadata==6.8.0
+importlib-resources==5.12.0
+iniconfig==1.1.1
+ipykernel==4.6.0
+ipython==8.13.0
+ipython-genutils==0.2.0
+isort==5.8.0
+itsdangerous==2.1.2
+jedi==0.18.2
+Jinja2==3.1.2
+joblib==1.3.1
+jsonschema==4.17.3
+jupyter-client==6.1.12
+jupyter-core==4.7.1
+kiwisolver==1.4.4
+lazy-object-proxy==1.5.2
+linkify-it-py==2.0.2
+loguru==0.7.0
+loralib==0.1.1
+Markdown==3.4.4
+markdown-it-py==2.2.0
+markdown2==2.4.9
+MarkupSafe==2.1.3
+matplotlib==3.7.1
+matplotlib-inline==0.1.6
+mccabe==0.6.1
+mdit-py-plugins==0.3.3
+mdurl==0.1.2
+mpi4py==3.1.4
+msgpack==1.0.5
+multidict==6.0.4
+multiprocess==0.70.12.2
+mypy-extensions==1.0.0
+nh3==0.2.13
+ninja==1.11.1
+nodeenv==1.5.0
+numpy==1.24.2
+nvidia-ml-py==12.535.77
+oauthlib==3.2.2
+openai==0.27.8
+orjson==3.9.1
+packaging==23.1
+pandas==2.0.3
+parso==0.8.3
+pathspec==0.11.1
+pathtools==0.1.2
+peft==0.4.0
+pexpect==4.8.0
+pickleshare==0.7.5
+Pillow==8.1.2
+pkgutil_resolve_name==1.3.10
+platformdirs==3.8.0
+pluggy==0.13.1
+pre-commit==1.10.4
+prompt-toolkit==3.0.38
+protobuf==3.20.0
+psutil==5.9.5
+ptyprocess==0.7.0
+pure-eval==0.2.2
+py==1.10.0
+py-cpuinfo==9.0.0
+pyarrow==12.0.1
+pyasn1==0.5.0
+pyasn1-modules==0.3.0
+pydantic==1.10.9
+pydub==0.25.1
+pyg-lib==0.2.0
+Pygments==2.15.1
+PyGObject==3.26.1
+pylint==2.7.2
+pyparsing==2.4.7
+pyrsistent==0.19.3
+pytest==6.2.2
+python-apt==1.6.5+ubuntu0.7
+python-dateutil==2.8.2
+python-multipart==0.0.6
+pytz==2023.3
+PyYAML==5.4.1
+pyzmq==22.0.3
+ray==2.6.1
+regex==2023.6.3
+requests==2.31.0
+requests-oauthlib==1.3.1
+responses==0.18.0
+rich==13.4.2
+rsa==4.9
+safetensors==0.3.1
+scikit-learn==1.2.2
+scipy==1.10.1
+semantic-version==2.10.0
+sentencepiece==0.1.99
+sentry-sdk==1.26.0
+setproctitle==1.3.2
+shortuuid==1.0.11
+simplegeneric==0.8.1
+six==1.15.0
+smmap==5.0.0
+sniffio==1.3.0
+ssh-import-id==5.7
+stack-data==0.6.2
+starlette==0.27.0
+svgwrite==1.4.3
+tensorboard==2.13.0
+tensorboard-data-server==0.7.1
+tensorboardX==2.6.1
+termcolor==2.3.0
+threadpoolctl==3.2.0
+tiktoken==0.4.0
+tokenize-rt==5.1.0
+tokenizers==0.13.3
+toml==0.10.2
+tomli==2.0.1
+toolz
+# torch==1.13.0
+torch-cluster==1.6.1
+torch-geometric==2.3.1
+torch-scatter==2.1.0
+torch-sparse==0.6.16
+torch-spline-conv==1.2.2
+# torchvision==0.14.0
+tornado==6.1
+tqdm
+traitlets==5.0.5
+transformers==4.31.0
+trl
+typing_extensions==4.7.0
+tzdata==2023.3
+uc-micro-py==1.0.2
+unattended-upgrades==0.1
+urllib3
+uvicorn==0.22.0
+virtualenv==20.4.3
+wandb==0.15.11
+wavedrom==2.0.3.post3
+wcwidth==0.2.5
+websockets==11.0.3
+Werkzeug==2.3.6
+wrapt==1.12.1
+xxhash==3.2.0
+yarl==1.9.2
+zipp==3.15.0
diff --git a/scripts/eval_script/graphgpt_eval.sh b/scripts/eval_script/graphgpt_eval.sh
new file mode 100644
index 0000000..412fcf6
--- /dev/null
+++ b/scripts/eval_script/graphgpt_eval.sh
@@ -0,0 +1,10 @@
+# fill in the following paths to run evaluation with GraphGPT!
+output_model=
+datapath=
+graph_data_path=
+res_path=
+start_id=
+end_id=
+num_gpus=
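+
+# For reference, a hypothetical filled-in example (paths are illustrative only, adjust to your setup):
+#   output_model=./checkpoints/graphgpt-7b
+#   datapath=./data/eval/eval_instruct.json
+#   graph_data_path=./graph_data/graph_data_all.pt
+#   res_path=./output_eval_res
+#   start_id=0
+#   end_id=20000
+#   num_gpus=4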
+
+python3.8 ./graphgpt/eval/run_graphgpt.py --model-name ${output_model} --prompting_file ${datapath} --graph_data_path ${graph_data_path} --output_res_path ${res_path} --start_id ${start_id} --end_id ${end_id} --num_gpus ${num_gpus}
\ No newline at end of file
diff --git a/scripts/extract_graph_projector.py b/scripts/extract_graph_projector.py
new file mode 100644
index 0000000..6060a2f
--- /dev/null
+++ b/scripts/extract_graph_projector.py
@@ -0,0 +1,42 @@
+import os
+import argparse
+import torch
+import json
+from collections import defaultdict
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Extract MMProjector weights')
+ parser.add_argument('--model_name_or_path', type=str, help='model folder')
+ parser.add_argument('--output', type=str, help='output file')
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ keys_to_match = ['graph_projector', 'embed_tokens', 'transformer.wte']
+ ckpt_to_key = defaultdict(list)
+ try:
+ model_indices = json.load(open(os.path.join(args.model_name_or_path, 'pytorch_model.bin.index.json')))
+ for k, v in model_indices['weight_map'].items():
+ if any(key_match in k for key_match in keys_to_match):
+ ckpt_to_key[v].append(k)
+ except FileNotFoundError:
+ # Smaller models or model checkpoints saved by DeepSpeed.
+ v = 'pytorch_model.bin'
+ for k in torch.load(os.path.join(args.model_name_or_path, v), map_location='cpu').keys():
+ if any(key_match in k for key_match in keys_to_match):
+ ckpt_to_key[v].append(k)
+
+ loaded_weights = {}
+
+ for ckpt_name, weight_keys in ckpt_to_key.items():
+ ckpt = torch.load(os.path.join(args.model_name_or_path, ckpt_name), map_location='cpu')
+ for k in weight_keys:
+ loaded_weights[k] = ckpt[k]
+
+ print(loaded_weights.keys())
+
+ torch.save(loaded_weights, args.output)
diff --git a/scripts/serving/controller.yaml b/scripts/serving/controller.yaml
new file mode 100644
index 0000000..35c3a88
--- /dev/null
+++ b/scripts/serving/controller.yaml
@@ -0,0 +1,29 @@
+resources:
+ cloud: gcp
+ region: us-central1
+
+num_nodes: 1
+
+workdir: .
+
+file_mounts:
+ ~/chatlogs:
+ name: skypilot-chatbot-logs
+ store: gcs
+ mode: MOUNT
+
+setup: |
+ conda activate chatbot
+ if [ $? -eq 0 ]; then
+ echo 'conda env exists'
+ else
+ # Setup the environment
+ conda create -n chatbot python=3.10 -y
+ conda activate chatbot
+ pip3 install -e .
+ fi
+
+run: |
+ conda activate chatbot
+ python3 -m fastchat.serve.controller --host 0.0.0.0 --port 21001 &
+ python3 -m fastchat.serve.gradio_web_server --share
diff --git a/scripts/serving/model_worker.yaml b/scripts/serving/model_worker.yaml
new file mode 100644
index 0000000..d241e74
--- /dev/null
+++ b/scripts/serving/model_worker.yaml
@@ -0,0 +1,57 @@
+resources:
+ accelerators: A100:1
+ cloud: gcp
+ region: us-central1
+
+num_nodes: 1
+
+workdir: .
+
+file_mounts:
+ /artifacts:
+ name: skypilot-chatbot
+ store: gcs
+ mode: MOUNT
+
+ ~/chatlogs:
+ name: skypilot-chatbot-logs
+ store: gcs
+ mode: MOUNT
+
+setup: |
+ conda activate chatbot
+ if [ $? -eq 0 ]; then
+ echo 'conda env exists'
+ else
+ # Setup the environment
+ conda create -n chatbot python=3.10 -y
+ conda activate chatbot
+
+ pip3 install -e .
+
+ # Install pytorch
+ pip install torch==1.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
+
+ # Install transformers from source (the main branch includes LLaMA support)
+ pip install git+https://github.com/huggingface/transformers
+
+ # Install alpaca
+ git clone https://github.com/tatsu-lab/stanford_alpaca.git
+ cd stanford_alpaca
+ pip install -r requirements.txt
+ cd -
+ fi
+
+ ln -s /artifacts/chatbot/13b/ckpt/ ~/alpaca-13b
+
+run: |
+ conda activate chatbot
+ WORKER_IP=$(hostname -I | cut -d' ' -f1)
+ CONTROLLER_PORT=21001
+ WORKER_PORT=21002
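+ # CONTROLLER_IP is not set in this file; export it (or hard-code the address
+ # of the VM launched from controller.yaml) so the worker can register itself.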
+ python3 -m fastchat.serve.model_worker \
+ --model ~/alpaca-13b \
+ --controller-address http://${CONTROLLER_IP}:${CONTROLLER_PORT} \
+ --worker-address http://${WORKER_IP}:${WORKER_PORT} \
+ --host 0.0.0.0 \
+ --port ${WORKER_PORT}
diff --git a/scripts/tune_script/extract_projector.sh b/scripts/tune_script/extract_projector.sh
new file mode 100644
index 0000000..6012a76
--- /dev/null
+++ b/scripts/tune_script/extract_projector.sh
@@ -0,0 +1,7 @@
+# fill in the following paths to extract the projector for the second tuning stage!
+src_model=
+output_proj=
+
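+# src_model: the stage-1 output checkpoint directory;
+# output_proj: the projector file later passed to graphgpt_stage2.sh as ${tuned_proj}.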
+python3.8 ./scripts/extract_graph_projector.py \
+ --model_name_or_path ${src_model} \
+ --output ${output_proj}
\ No newline at end of file
diff --git a/scripts/tune_script/graphgpt_stage1.sh b/scripts/tune_script/graphgpt_stage1.sh
new file mode 100644
index 0000000..075f179
--- /dev/null
+++ b/scripts/tune_script/graphgpt_stage1.sh
@@ -0,0 +1,39 @@
+# fill in the following paths to run the first stage of our GraphGPT!
+model_path=
+instruct_ds=
+graph_data_path=
+pretra_gnn=
+output_model=
+
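+# Stage 1 tunes the graph-text alignment projector (--tune_graph_mlp_adapter True)
+# at a relatively high learning rate (2e-3); its weights are then pulled out with
+# extract_projector.sh and reused in stage 2.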
+wandb offline
+python -m torch.distributed.run --nnodes=1 --nproc_per_node=4 --master_port=20001 \
+ graphgpt/train/train_mem.py \
+ --model_name_or_path ${model_path} \
+ --version v1 \
+ --data_path ${instruct_ds} \
+ --graph_content ./arxiv_ti_ab.json \
+ --graph_data_path ${graph_data_path} \
+ --graph_tower ${pretra_gnn} \
+ --tune_graph_mlp_adapter True \
+ --graph_select_layer -2 \
+ --use_graph_start_end \
+ --bf16 True \
+ --output_dir ${output_model} \
+ --num_train_epochs 3 \
+ --per_device_train_batch_size 2 \
+ --per_device_eval_batch_size 2 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 2400 \
+ --save_total_limit 1 \
+ --learning_rate 2e-3 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --lazy_preprocess True \
+ --report_to wandb
diff --git a/scripts/tune_script/graphgpt_stage2.sh b/scripts/tune_script/graphgpt_stage2.sh
new file mode 100644
index 0000000..5bc3a9a
--- /dev/null
+++ b/scripts/tune_script/graphgpt_stage2.sh
@@ -0,0 +1,42 @@
+# fill in the following paths to run the second stage of our GraphGPT!
+model_path=
+instruct_ds=
+graph_data_path=
+pretra_gnn=
+tuned_proj=
+output_model=
+
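+# Stage 2 loads the projector produced by extract_projector.sh
+# (--pretrain_graph_mlp_adapter ${tuned_proj}) and continues instruction tuning
+# at a lower learning rate (2e-5).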
+wandb offline
+python -m torch.distributed.run --nnodes=1 --nproc_per_node=4 --master_port=20001 \
+ graphgpt/train/train_mem.py \
+ --model_name_or_path ${model_path} \
+ --version v1 \
+ --data_path ${instruct_ds} \
+ --graph_content ./arxiv_ti_ab.json \
+ --graph_data_path ${graph_data_path} \
+ --graph_tower ${pretra_gnn} \
+ --pretrain_graph_mlp_adapter ${tuned_proj} \
+ --tune_graph_mlp_adapter True \
+ --graph_select_layer -2 \
+ --use_graph_start_end True \
+ --bf16 True \
+ --output_dir ${output_model} \
+ --num_train_epochs 2 \
+ --per_device_train_batch_size 1 \
+ --per_device_eval_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 50000 \
+ --save_total_limit 1 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --dataloader_num_workers 4 \
+ --lazy_preprocess True \
+ --report_to wandb
diff --git a/tests/test_openai_curl.sh b/tests/test_openai_curl.sh
new file mode 100644
index 0000000..48766f3
--- /dev/null
+++ b/tests/test_openai_curl.sh
@@ -0,0 +1,32 @@
+set -x
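+
+# Assumes an OpenAI-compatible API server is listening on localhost:8000 and
+# serving a model registered as "vicuna-7b-v1.1" (e.g. launched with FastChat's
+# openai_api_server); adjust the host, port, and model name to your deployment.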
+
+curl http://localhost:8000/v1/models
+
+echo
+
+curl http://localhost:8000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "vicuna-7b-v1.1",
+ "messages": [{"role": "user", "content": "Hello! What is your name?"}]
+ }'
+
+echo
+
+curl http://localhost:8000/v1/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "vicuna-7b-v1.1",
+ "prompt": "Once upon a time",
+ "max_tokens": 41,
+ "temperature": 0.5
+ }'
+
+echo
+
+curl http://localhost:8000/v1/embeddings \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "vicuna-7b-v1.1",
+ "input": "Hello world!"
+ }'
diff --git a/tests/test_openai_langchain.py b/tests/test_openai_langchain.py
new file mode 100644
index 0000000..78b7f9f
--- /dev/null
+++ b/tests/test_openai_langchain.py
@@ -0,0 +1,38 @@
+# export OPENAI_API_BASE=http://localhost:8000/v1
+# export OPENAI_API_KEY=EMPTY
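+#
+# These tests route LangChain's OpenAI wrappers to the local server configured
+# above; the model name passed to OpenAI() below must match a name the server
+# exposes via /v1/models (the curl and SDK tests in this repo use "vicuna-7b-v1.1").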
+
+from langchain import OpenAI, LLMChain, PromptTemplate
+from langchain.memory import ConversationBufferWindowMemory
+from langchain.embeddings import OpenAIEmbeddings
+import numpy as np
+
+template = """{history}
+Human: {human_input}
+Assistant:"""
+
+def test_embedding():
+ embeddings = OpenAIEmbeddings()
+ texts = ["Why does the chicken cross the road", "To be honest", "Long time ago"]
+ query_result = embeddings.embed_query(texts[0])
+ doc_result = embeddings.embed_documents(texts)
+ assert np.allclose(query_result, doc_result[0], atol=1e-3)
+
+def test_chain():
+
+ prompt = PromptTemplate(
+ input_variables=["history", "human_input"],
+ template=template
+ )
+ chain = LLMChain(
+ llm=OpenAI(model="text-embedding-ada-002", temperature=1),
+ prompt=prompt,
+ verbose=True,
+ memory=ConversationBufferWindowMemory(k=2),
+ )
+ output = chain.predict(human_input="ls ~")
+ print(output)
+
+if __name__ == "__main__":
+ test_embedding()
+ test_chain()
+
diff --git a/tests/test_openai_sdk.py b/tests/test_openai_sdk.py
new file mode 100644
index 0000000..197838b
--- /dev/null
+++ b/tests/test_openai_sdk.py
@@ -0,0 +1,46 @@
+import openai
+
+openai.api_key = "EMPTY"  # Not supported yet
+openai.api_base = "http://localhost:8000/v1"
+
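+# The model name must match one reported by the server's /v1/models endpoint
+# (see test_list_models below).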
+model = "vicuna-7b-v1.1"
+
+
+def test_list_models():
+ model_list = openai.Model.list()
+ print(model_list["data"][0]["id"])
+
+
+def test_completion():
+ prompt = "Once upon a time"
+ completion = openai.Completion.create(model=model, prompt=prompt, max_tokens=64)
+ print(prompt + completion.choices[0].text)
+
+
+def test_embedding():
+ embedding = openai.Embedding.create(model=model, input="Hello world!")
+ print(len(embedding["data"][0]["embedding"]))
+
+
+def test_chat_completion():
+ completion = openai.ChatCompletion.create(
+ model=model, messages=[{"role": "user", "content": "Hello! What is your name?"}]
+ )
+ print(completion.choices[0].message.content)
+
+
+def test_chat_completion_stream():
+ messages = [{"role": "user", "content": "Hello! What is your name?"}]
+ res = openai.ChatCompletion.create(model=model, messages=messages, stream=True)
+ for chunk in res:
+ content = chunk["choices"][0]["delta"].get("content", "")
+ print(content, end="", flush=True)
+ print()
+
+
+if __name__ == "__main__":
+ test_list_models()
+ test_completion()
+ test_embedding()
+ test_chat_completion()
+ test_chat_completion_stream()