diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..ebd09f8 Binary files /dev/null and b/.DS_Store differ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..b260f52 --- /dev/null +++ b/README.md @@ -0,0 +1,407 @@ +# GraphGPT: Graph Instruction Tuning for Large Language Models +[Jiabin Tang](https://tjb-tech.github.io/), [Yuhao Yang](http://yuh-yang.github.io), [Wei Wei](#), [Lei Shi](#), [Lixin Su](#), [Suqi Cheng](#), [Dawei Yin](https://www.yindawei.com/) and [Chao Huang](https://sites.google.com/view/chaoh/home)*. +(*Correspondence ) + +**[Data Intelligence Lab](https://sites.google.com/view/chaoh/home)@[University of Hong Kong](https://www.hku.hk/)**, Baidu Inc. + +----- + + + + +[![YouTube](https://badges.aleen42.com/src/youtube.svg)](#) + + +This repository hosts the code, data and model weight of **GraphGPT**. + +----------- + +## 🎉 News + +- [x] [2023.10.15] 🚀🚀 Release the code of GraphGPT. + + +## 👉 TODO +- [ ] Release our utilized Instruction data. +- [ ] Release checkpoints of our GraphGPT and pre-trained graph encoder. +- [ ] Exploring the potential of our GraphGPT for more graph learning tasks. +- [ ] ... + +----------- + + + + + + +## Brief Introduction + + +we present the **GraphGPT** framework that aligns LLMs with graph structural knowledge with a graph instruction tuning paradigm. + + +- **Structural Information Encoding with Text-Graph Grounding.** To enhance the understanding of graph structural information by large language models, our framework emphasizes aligning the encoding of graph structures with the natural language space. This alignment aims to enable language models to effectively comprehend and interpret the structural elements of the graph, leveraging their inherent language understanding capabilities. To achieve this objective, we introduce a text-graph grounding paradigm that generates prompts designed to preserve the graph’s structural context for language models. This paradigm acts as a bridge, connecting the semantic understanding of textual information with the inherent structural relationships found within the graph. +- **Dual-Stage Graph Instruction Tuning.** The dual-stage graph instruction tuning paradigm proposed in this work builds upon the concept of instruction tuning, which has been recently introduced to enhance the adaptability of language models for specific domains. In this paradigm, we aim to align the language capacity of the model with the nuances of graph learning tasks, enabling the language model to generate more accurate and contextually appropriate responses for graph-structured data. +- **Chain-of-Thought (CoT) Distillation.** When faced with diverse graph data, language models may encounter new or unfamiliar patterns and structures. This distribution shift can pose challenges in generating accurate and coherent responses, especially when the number of node classes varies across different types of graph data. To address this challenge and boost accuracy in the presence of distribution shift, it is essential to equip our GraphGPT with step-by-step reasoning abilities. In this regard, we propose utilizing the Chain-of-Thought (COT) technique [47], which explicitly models the flow of thoughts and reasoning steps. By incorporating COT, our language model improves the coherence and consistency of generated text. It enables the model to follow a logical progression of ideas, enhancing its ability to understand and reason about the given graph data. + + +For more technical details, kindly refer to the [paper]() and the project [website](https://graphgpt.github.io/) of our Graph. + + +----------- + + + +## Getting Started + + + +### Table of Contents: +* 1. Code Structure +* 2. Environment Preparation +* 3. Training GraphGPT + * 3.1. Prepare Pre-trained Checkpoint + * 3.2. Self-Supervised Instruction Tuning + * 3.3. Extract the Trained Projector + * 3.4. Task-Specific Instruction Tuning +* 4. Evaluating GraphGPT + * 4.1. Preparing Checkpoints and Data + * 4.2. Running Evaluation + +**** + + + + + +### 1. Code Structure + +``` +. +├── README.md +├── assets +│   ├── demo_narrow.gif +│   ├── screenshot_cli.png +│   ├── screenshot_gui.png +│   ├── server_arch.png +│   └── vicuna_logo.jpeg +├── format.sh +├── graphgpt +│   ├── __init__.py +│   ├── constants.py +│   ├── conversation.py +│   ├── eval +│   │   ├── README.md +│   │   ├── requirements.txt +│   │   ├── run_graphgpt.py +│   │   ├── run_graphgpt_LP.py +│   │   ├── run_vicuna.py +│   │   └── script +│   │   └── run_model_qa.yaml +│   ├── model +│   │   ├── GraphLlama.py +│   │   ├── __init__.py +│   │   ├── apply_delta.py +│   │   ├── apply_lora.py +│   │   ├── builder.py +│   │   ├── compression.py +│   │   ├── convert_fp16.py +│   │   ├── graph_layers +│   │   │   ├── __init__.py +│   │   │   ├── bpe_simple_vocab_16e6.txt.gz +│   │   │   ├── clip_graph.py +│   │   │   ├── graph_transformer.py +│   │   │   ├── mpnn.py +│   │   │   └── simple_tokenizer.py +│   │   ├── make_delta.py +│   │   ├── model_adapter.py +│   │   ├── model_registry.py +│   │   ├── monkey_patch_non_inplace.py +│   │   └── utils.py +│   ├── protocol +│   │   └── openai_api_protocol.py +│   ├── serve +│   │   ├── __init__.py +│   │   ├── api_provider.py +│   │   ├── bard_worker.py +│   │   ├── cacheflow_worker.py +│   │   ├── cli.py +│   │   ├── controller.py +│   │   ├── gateway +│   │   │   ├── README.md +│   │   │   └── nginx.conf +│   │   ├── gradio_block_arena_anony.py +│   │   ├── gradio_block_arena_named.py +│   │   ├── gradio_css.py +│   │   ├── gradio_patch.py +│   │   ├── gradio_web_server.py +│   │   ├── gradio_web_server_multi.py +│   │   ├── huggingface_api.py +│   │   ├── inference.py +│   │   ├── model_worker.py +│   │   ├── monitor +│   │   │   ├── basic_stats.py +│   │   │   ├── clean_battle_data.py +│   │   │   ├── elo_analysis.py +│   │   │   ├── hf_space_leaderboard_app.py +│   │   │   └── monitor.py +│   │   ├── openai_api_server.py +│   │   ├── register_worker.py +│   │   ├── test_message.py +│   │   └── test_throughput.py +│   ├── train +│   │   ├── graphchat_trainer.py +│   │   ├── llama_flash_attn_monkey_patch.py +│   │   ├── train_graph.py +│   │   ├── train_lora.py +│   │   └── train_mem.py +│   └── utils.py +├── playground +│   ├── inspect_conv.py +│   ├── test_embedding +│   │   ├── README.md +│   │   ├── test_classification.py +│   │   ├── test_semantic_search.py +│   │   └── test_sentence_similarity.py +│   └── test_openai_api +│   ├── anthropic_api.py +│   └── openai_api.py +├── pyproject.toml +├── scripts +│   ├── eval_script +│   │   └── graphgpt_eval.sh +│   ├── extract_graph_projector.py +│   ├── serving +│   │   ├── controller.yaml +│   │   └── model_worker.yaml +│   └── tune_script +│   ├── extract_projector.sh +│   ├── graphgpt_stage1.sh +│   └── graphgpt_stage2.sh +└── tests + ├── test_openai_curl.sh + ├── test_openai_langchain.py + └── test_openai_sdk.py +``` + + + + + +### 2. Environment Preparation [Back to Top] +Please first clone the repo and install the required environment, which can be done by running the following commands: +```shell +conda env create -n graphgpt python=3.8 + +conda activate graphgpt + +# Torch with CUDA 11.7 +pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117 +# To support vicuna base model +pip3 install "fschat[model_worker,webui]" +# Clone our GraphGPT +git clone https://github.com/HKUDS/GraphGPT.git +cd GraphGPT +# Install required libaries +pip install -r requirements.txt +``` + + + +### 3. Training GraphGPT + +GraphGPT tuning paradigm consists of two stages: (1) self-supervised instruction tuning: use approximately 600K filtered CC3M to connect a *frozen pretrained* vision encoder to a *frozen LLM*; (2) task-specific instruction tuning: use 150K GPT-generated multimodal instruction-following to teach the model to follow multimodal instructions. + + + +#### 3.1. Preparing Pre-trained Checkpoint [Back to Top] +GraphGPT is trained based on following excellent existing models. +Please follow the instructions to prepare the checkpoints. + +- `Vicuna`: + Prepare our base model Vicuna, which is an instruction-tuned chatbot and base model in our implementation. Please download its weights [here](https://github.com/lm-sys/FastChat#model-weights). We generally utilize v1.1 and v1.5 model with 7B parameters. + +- `Graph Encoder`: + is used to encode graph structures. We empoly text-graph grounding approach to obtain the pre-trained graph transformer model, which you could download by [graph transformer]() and put it at [[./GraphGPT]](./GraphGPT) + +- `Graph Data`: + + is a combination of all utilized pyg graph data that contain node features, edge_index and son on. You can download by [all_graph_data.pt]() and put it at [[./GraphGPT/graph_data]](./GraphGPT/graph_data) + + + +#### 3.2. Self-Supervised Instruction Tuning [Back to Top] + +* **Prepare data:** Please download our instruction tuning data [graph_matching.json](https://huggingface.co/datasets/Jiabin99/graph_matching) for the graph matching task. + +* **Start tuning:** After the aforementioned steps, you could start the first stage tuning by filling blanks at [graphgpt_stage1.sh](https://github.com/HKUDS/GraphGPT/scripts/tune_script/graphgpt_stage1.sh). There is an example as below: + +```shell +# to fill in the following path to run the first stage of our GraphGPT! +model_path=../vicuna-7b-v1.5-16k +instruct_ds=./data/stage_1/graph_matching.json +graph_data_path=./graph_data/all_graph_data.pt +pretra_gnn=./clip_gt_arxiv +output_model=./checkpoints/stage_1 + +wandb offline +python -m torch.distributed.run --nnodes=1 --nproc_per_node=4 --master_port=20001 \ + graphgpt/train/train_mem.py \ + --model_name_or_path ${model_path} \ + --version v1 \ + --data_path ${instruct_ds} \ + --graph_content ./arxiv_ti_ab.json \ + --graph_data_path ${graph_data_path} \ + --graph_tower ${pretra_gnn} \ + --tune_graph_mlp_adapter True \ + --graph_select_layer -2 \ + --use_graph_start_end \ + --bf16 True \ + --output_dir ${output_model} \ + --num_train_epochs 3 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 2400 \ + --save_total_limit 1 \ + --learning_rate 2e-3 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True \ + --report_to wandb +``` + + + +#### 3.3. Extract the Trained Projector [Back to Top] + +We could extract the trained projector in the stage 1 by filling blanks at [extract_projector.sh](https://github.com/HKUDS/GraphGPT/scripts/tune_script/extract_projector.sh). There is an example as below: + +```shell +# to fill in the following path to extract projector for the second tuning stage! +src_model=./checkpoints/stage_1 +output_proj=./checkpoints/stage_1_projector/stage_1_projector.bin + +python3.8 ./scripts/extract_graph_projector.py \ + --model_name_or_path ${src_model} \ + --output ${output_proj} +``` + + + +#### 3.4. Task-Specific Instruction Tuning [Back to Top] + +* **Prepare data:** The choices of our task-specific instruction data could be diverse, e.g., standard or COT (Chain-of-Thought) node classifiction, link prediction or mixing data for multitasking. Please refer to the [task_specific](https://huggingface.co/datasets/Jiabin99/task_specific). + +* **Start tuning:** After the aforementioned steps, you could start the second stage tuning by filling blanks at [graphgpt_stage2.sh](https://github.com/HKUDS/GraphGPT/scripts/tune_script/graphgpt_stage2.sh). There is an example as below: + +```shell +# to fill in the following path to run the second stage of our GraphGPT! +model_path=../vicuna-7b-v1.5-16k +instruct_ds=./data/stage_2/data_all_mix.json +graph_data_path=./graph_data/all_graph_data.pt +pretra_gnn=./clip_gt_arxiv +tuned_proj=./checkpoints/stage_1_projector/stage_1_projector.bin +output_model=./checkpoints/stage_2 + +wandb offline +python -m torch.distributed.run --nnodes=1 --nproc_per_node=4 --master_port=20001 \ + graphgpt/train/train_mem.py \ + --model_name_or_path ${model_path} \ + --version v1 \ + --data_path ${instruct_ds} \ + --graph_content ./arxiv_ti_ab.json \ + --graph_data_path ${graph_data_path} \ + --graph_tower ${pretra_gnn} \ + --pretrain_graph_mlp_adapter ${tuned_proj} \ + --tune_graph_mlp_adapter True \ + --graph_select_layer -2 \ + --use_graph_start_end True\ + --bf16 True \ + --output_dir ${output_model} \ + --num_train_epochs 2 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb + +``` + + + + + +## 4. Evaluating GraphGPT [Back to Top] + + + + +#### 4.1. Preparing Checkpoints and Data + +* **Checkpoints:** You could try to evaluate GraphGPT by using your own model or our released checkpoints. +* **Data:** We split test sets for different graph datasets and make the instruction data for evaluation. Please refer to the [evaluating](https://huggingface.co/datasets/Jiabin99/evaluating). + + + +#### 4.2. Running Evaluation + +You could start the second stage tuning by filling blanks at [graphgpt_eval.sh](https://github.com/HKUDS/GraphGPT/scripts/eval_script/graphgpt_eval.sh). There is an example as below: +```shell +# to fill in the following path to extract projector for the second tuning stage! +output_model=./checkpoints/stage_2 +datapath=./data/eval/arxiv_nc.json +graph_data_path=./graph_data/all_graph_data.pt +res_path=./output_stage_2_arxiv_nc +start_id=0 +end_id=20000 +num_gpus=2 + +python3.8 ./graphgpt/eval/run_graphgpt.py --model-name ${output_model} --prompting_file ${datapath} --graph_data_path ${graph_data_path} --output_res_path ${res_path} --start_id ${start_id} --end_id ${end_id} --num_gpus ${num_gpus} +``` +--------- + + +## Contact + +For any questions or feedback, feel free to contact [Jiabin Tang](mailto:jiabintang77@gmail.com). + + +## Citation + +If you find GraphGPT useful in your research or applications, please kindly cite: +``` +@articles{tang2023graphgpt, + title={GraphGPT: Graph Instruction Tuning for Large Language Models}, + author={Jiabin Tang, Yuhao Yang, Wei Wei, Lei Shi, Lixin Su, Suqi Cheng, Dawei Yin, Chao Huang}, + journal = {arXiv preprint}, + year={2023} +} +``` + + + +## Acknowledgements +You may refer to related work that serves as foundations for our framework and code repository, +[Vicuna](https://github.com/lm-sys/FastChat), [LLaVa](https://github.com/haotian-liu/LLaVA), We also partially draw inspirations from [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4). The design of our website and README.md was inspired by [NExT-GPT](https://next-gpt.github.io/). Thanks for their wonderful works. + diff --git a/assets/demo_narrow.gif b/assets/demo_narrow.gif new file mode 100644 index 0000000..aa1af38 Binary files /dev/null and b/assets/demo_narrow.gif differ diff --git a/assets/screenshot_cli.png b/assets/screenshot_cli.png new file mode 100644 index 0000000..7a7dd5d Binary files /dev/null and b/assets/screenshot_cli.png differ diff --git a/assets/screenshot_gui.png b/assets/screenshot_gui.png new file mode 100644 index 0000000..ecb41d2 Binary files /dev/null and b/assets/screenshot_gui.png differ diff --git a/assets/server_arch.png b/assets/server_arch.png new file mode 100644 index 0000000..16708a8 Binary files /dev/null and b/assets/server_arch.png differ diff --git a/assets/vicuna_logo.jpeg b/assets/vicuna_logo.jpeg new file mode 100644 index 0000000..e7883dc Binary files /dev/null and b/assets/vicuna_logo.jpeg differ diff --git a/graphgpt/.DS_Store b/graphgpt/.DS_Store new file mode 100644 index 0000000..5b99346 Binary files /dev/null and b/graphgpt/.DS_Store differ diff --git a/graphgpt/__init__.py b/graphgpt/__init__.py new file mode 100644 index 0000000..5635676 --- /dev/null +++ b/graphgpt/__init__.py @@ -0,0 +1 @@ +__version__ = "0.2.11" diff --git a/graphgpt/__pycache__/__init__.cpython-311.pyc b/graphgpt/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..51d98e8 Binary files /dev/null and b/graphgpt/__pycache__/__init__.cpython-311.pyc differ diff --git a/graphgpt/__pycache__/constants.cpython-311.pyc b/graphgpt/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000..ea3afb2 Binary files /dev/null and b/graphgpt/__pycache__/constants.cpython-311.pyc differ diff --git a/graphgpt/__pycache__/conversation.cpython-311.pyc b/graphgpt/__pycache__/conversation.cpython-311.pyc new file mode 100644 index 0000000..bed88fa Binary files /dev/null and b/graphgpt/__pycache__/conversation.cpython-311.pyc differ diff --git a/graphgpt/__pycache__/utils.cpython-311.pyc b/graphgpt/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000..9ef66dd Binary files /dev/null and b/graphgpt/__pycache__/utils.cpython-311.pyc differ diff --git a/graphgpt/constants.py b/graphgpt/constants.py new file mode 100644 index 0000000..7d5183b --- /dev/null +++ b/graphgpt/constants.py @@ -0,0 +1,52 @@ +from enum import IntEnum +import os + +# For the gradio web server +SERVER_ERROR_MSG = ( + "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +) +MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE FIX YOUR INPUT AND TRY AGAIN." +CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." +INPUT_CHAR_LEN_LIMIT = 2560 +CONVERSATION_LEN_LIMIT = 50 +LOGDIR = "." + +# For the controller and workers(could be overwritten through ENV variables.) +CONTROLLER_HEART_BEAT_EXPIRATION = int( + os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90) +) +WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 30)) +WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100)) +WORKER_API_EMBEDDING_BATCH_SIZE = int(os.getenv("WORKER_API_EMBEDDING_BATCH_SIZE", 4)) + + +class ErrorCode(IntEnum): + """ + https://platform.openai.com/docs/guides/error-codes/api-errors + """ + + VALIDATION_TYPE_ERROR = 40001 + + INVALID_AUTH_KEY = 40101 + INCORRECT_AUTH_KEY = 40102 + NO_PERMISSION = 40103 + + INVALID_MODEL = 40301 + PARAM_OUT_OF_RANGE = 40302 + CONTEXT_OVERFLOW = 40303 + + RATE_LIMIT = 42901 + QUOTA_EXCEEDED = 42902 + ENGINE_OVERLOADED = 42903 + + INTERNAL_ERROR = 50001 + CUDA_OUT_OF_MEMORY = 50002 + GRADIO_REQUEST_ERROR = 50003 + GRADIO_STREAM_UNKNOWN_ERROR = 50004 + CONTROLLER_NO_WORKER = 50005 + CONTROLLER_WORKER_TIMEOUT = 50006 + +DEFAULT_GRAPH_TOKEN = "" +DEFAULT_GRAPH_PATCH_TOKEN = "" +DEFAULT_G_START_TOKEN = "" +DEFAULT_G_END_TOKEN = "" \ No newline at end of file diff --git a/graphgpt/conversation.py b/graphgpt/conversation.py new file mode 100644 index 0000000..add5e03 --- /dev/null +++ b/graphgpt/conversation.py @@ -0,0 +1,382 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Tuple + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + MPT = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + if self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + from PIL import Image + msg, image, image_process_mode = msg + if image_process_mode == "Pad": + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image) + elif image_process_mode == "Crop": + pass + elif image_process_mode == "Resize": + image = image.resize((224, 224)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + if return_pil: + images.append(image) + else: + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + images.append(img_b64_str) + return images + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + msg, image, image_process_mode = msg + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + # image = image.resize((224, 224)) + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + img_str = f'user upload image' + msg = msg.replace('', img_str) + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_v1 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Give three tips for staying healthy."), + ("Assistant", + "Sure, here are three tips for staying healthy:\n" + "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. " + "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, " + "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or " + "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening " + "activities at least two days per week.\n" + "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, " + "vegetables, whole grains, lean proteins, and healthy fats can help support " + "your overall health. Try to limit your intake of processed and high-sugar foods, " + "and aim to drink plenty of water throughout the day.\n" + "3. Get enough sleep: Getting enough quality sleep is essential for your physical " + "and mental health. Adults should aim for seven to nine hours of sleep per night. " + "Establish a regular sleep schedule and try to create a relaxing bedtime routine to " + "help improve the quality of your sleep.") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_v1_2 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "What are the key differences between renewable and non-renewable energy sources?"), + ("Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_vicuna_v1_1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_mpt = Conversation( + system="""<|im_start|>system +- You are a helpful language and vision assistant. +- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. +- You should follow the instructions carefully and explain your answers in detail.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_mpt_text = Conversation( + system="""<|im_start|>system +- You are a helpful assistant chatbot trained by MosaicML. +- You answer questions. +- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_bair_v1 = Conversation( + system="BEGINNING OF CONVERSATION:", + roles=("USER", "GPT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +simple_conv = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Hi!"), + ("Assistant", "Hi there! How can I help you today?") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +simple_conv_multimodal = Conversation( + system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab." + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Hi!"), + ("Assistant", "Hi there! How can I help you today?\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +simple_conv_mpt_multimodal = Conversation( + system="""<|im_start|>system +- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. +- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. +- You should follow the instructions carefully and explain your answers in detail.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +simple_conv_legacy = Conversation( + system="You are LLaVA, a large language model trained by UW Madison WAIV Lab." + "You are designed to assist human with a variety of tasks using natural language." + "Follow the instructions carefully.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Hi!\n\n### Response:"), + ("Assistant", "Hi there! How can I help you today?\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_llava_v1 = Conversation( + system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab." + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_graphchat_v1 = Conversation( + system="You are GraphGPT, a large language and graph-structral assistant trained by HKUDS Lab." + "You are able to understand the graph structures that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +default_conversation = conv_v1_2 +conv_templates = { + "default": conv_v1_2, + "simple": simple_conv, + "simple_legacy": simple_conv_legacy, + "multimodal": simple_conv_multimodal, + "mpt_multimodal": simple_conv_mpt_multimodal, + "llava_v1": conv_llava_v1, + "graphchat_v1": conv_graphchat_v1, + + + # fastchat + "v1": conv_v1_2, + "bair_v1": conv_bair_v1, + "vicuna_v1_1": conv_vicuna_v1_1, + "mpt": conv_mpt, + "mpt_text": conv_mpt_text, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/graphgpt/eval/.DS_Store b/graphgpt/eval/.DS_Store new file mode 100644 index 0000000..b002ded Binary files /dev/null and b/graphgpt/eval/.DS_Store differ diff --git a/graphgpt/eval/README.md b/graphgpt/eval/README.md new file mode 100644 index 0000000..403c9ac --- /dev/null +++ b/graphgpt/eval/README.md @@ -0,0 +1,187 @@ +# Evaluations + +This directory contains end-to-end pipelines for AI-enhanced evaluation. We will introduce the evaluation pipeline and the data format in this document. + +## Generate Answers + +### ChatGPT (gpt-3.5-turbo) + +Make sure you have setup the OpenAI API Key in your environment. Then run: + +```bash +python qa_baseline_gpt35.py --question table/question.jsonl --output table/answer/answer_gpt35.jsonl +``` + +### Bard + +Unfortunately, Bard has not release its public APIs till now. You may have to enter the anwsers manually. Or you could find a third-party project that interfaces with Bard. + +### Vicuna and others + +To generate answers with Vicuna or other models, specify path to the model checkpoint, a desired model ID and run: +```bash +python get_model_answer.py --model-id [MODEL-ID] --model-path /model/path --question-file table/question.jsonl --answer-file table/answer/answer.jsonl --num-gpus [NUM-GPUS] +``` +Then the answers to the questions will be saved in `table/answer/answer.jsonl`. +Note: we assume the model can be loaded with a single GPU. + +## Evaluate Answers Automatically + +### Generete Reviews with GPT-4 + +Note: Below script requires access to GPT-4 API. If you only have access to GPT-4 on web interface, you can evaluate the answers by manually formatting the prompt. See more details in the **Reviewers** and **Prompts** sections in **Data Format**. +It is critical to follow the prompt templates; otherwise GPT-4 may not give fair reviews. `table/review/*.jsonl` are some review examples generated by GPT-4 or you can view them on our eval [webpage](https://vicuna.lmsys.org/eval/). + +To use the script for generating reviews with GPT-4, you need to `export` your OpenAI API key in environment variable. Then run: +```bash +python eval_gpt_review.py -q table/question.jsonl -a /path/to/answer_1.jsonl /path/to/answer_2.jsonl -p table/prompt.jsonl -r table/reviewer.jsonl -o /path/to/review_output.jsonl +``` +The GPT-4 reviews will be saved in `/path/to/review_output.jsonl`. Note: we implement some simple parsing code to extract the score pairs from GPT-4's reviews. However, you need to double check whether the parsed score pair are correct. Sometime the parsing logic may fail if GPT-4 doesn't give a structured answer. + +## Visualize Results + +You can generate the data for the webpage by running: + +```bash +python eval/generate_webpage_data_from_table.py +``` + +Then you can serve a static website in `webpage` to see the results. + +## Data Format + +If you want to have a deeper understanding of our evaluation pipeline or want to contribute to the evaluation process, you need to learn the data format we used for evaluation. + +Our evaluation data are encoded with [JSON Lines](https://jsonlines.org/). + +### Random ID Generation + +We use the `shortuuid` Python library for generating short random UUIDs. + +```python +import shortuuid +shortuuid.uuid() -> str +``` + +### Models + +`model.jsonl` contains model information we used for generating anwsers. + +Each row contains a record of a model with the following field: + +* `model_id` (str): A unique ID for a model. Models with different IDs is supposed to have different performance. This ID is generated by `{model_name}:{model_version}`. +* `model_name` (str): The name of a model. This is not unique, because a model could be trained and updated continuously, but it is still considered as the same model with different versions. +* `model_version` (str): The version of a model. +* `model_metadata` (Any): Any metadata of a model (descriptions etc). This is optional. + +For example: + +```json +{ + "model_id": "vicuna-13b:v1", + "model_name": "vicuna-13b", + "model_version": "v1", + "model_metadata": "learning rate 1e-5, 3 epochs, 13b" +} +``` + +### Prompts + +We store prompts in `prompt.jsonl`. Each row contains a record of a prompt with the following field: + +* `prompt_id` (int): A unique integer ID for a prompt. Prompts with different IDs are supposed to have different purpose. +* `system_prompt` (str): The system prompt given to a model. This is the prompt that the model sees first. +* `prompt_template` (str): The prompt body. This is the user prompt that the model sees after the system prompt. It is a Python f-string template, so that we can fill in the inputs later. +* `defaults` (dict): A dictionary of default values for the prompt template. It can be empty. +* `description` (str): A description of the functionality of the prompt. + +For example: + +```json +{ + "prompt_id": 1, + "system_prompt": "You are a helpful assistant.", + "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", + "defaults": {"prompt": "Which assistant is more helpful?"}, + "description": "Compare two assistants' answers to a question." +} +``` + +### Reviewers + +`reviewer.jsonl` contains reviewer information we used for reviewing answers generated by different models. Each row contains a record of a reviewer with the following field: + +* `reviewer_id` (str): A unique ID for a reviewer. Reviewers with different IDs is supposed to have different reviewing performance. +* `prompt_id` (str): The ID of the prompt given to the reviewer (e.g., an AI assistant). Different prompts could result in different reviewing performance. +* `metadata` (dict): Metadata of a reviewer about its configurations. +* `description` (str): A description of the reviewer. +* `category` (str): The category that the reviewer belongs to. + +For example: + +```json +{ + "reviewer_id": "gpt-4-0328-default", + "prompt_id": 1, + "temperature": 0.2, + "max_tokens": 8192, + "description": "GPT-4 for general questions.", + "category": "general" +} +``` + +### Questions + +`question.jsonl` contains questions we used for evaluation. Each row contains a record of a question with the following field: + +* `question_id` (int): A unique integer for a question. Questions with different IDs is supposed to be different. +* `text` (str): The question text. +* `category` (str): The category of the question. Questions with the same category are supposed to be similar or originate from the same source. + +### Answers + +`answer/xxx.jsonl` contains answers generated by different models. Each row contains a record of an answer with the following field: + +* `answer_id` (str): A unique UUID for an answer. Answers with different IDs is supposed to be different. +* `question_id` (int): The ID of the question the answer is generated for. +* `model_id` (str): The ID of the model the answer is generated by. +* `text` (str): The answer text. +* `metadata` (dict): Any metadata of the answer. + +Example: + +```json +{ + "answer_id": "[short uuid]", + "question_id": 1, + "model_id": "vicuna-13b:v1", + "text": "Here are five tips...", + "metadata": {} +} +``` + +### Reviews + +`review/xxx.jsonl` contains reviews given by reviewers, comparing peformance between a pair of models. Each row contains a record of a review with the following field: + +* `review_id` (str): A unique UUID for a review. Reviews with different IDs is supposed to be different. +* `question_id` (int): The ID of the question the review is given for. +* `answer1_id` (str): The ID of the first answer. +* `answer2_id` (str): The ID of the second answer. +* `text` (str): The review text. +* `score` (list): A list of scores given by the reviewer. The first score is for the first answer, and the second score is for the second answer. +* `reviewer_id` (str): The ID of the reviewer. +* `metadata` (dict): Any metadata of the review. + +```json +{ + "review_id": "[short uuid]", + "question_id": 1, + "answer1_id": "[answer1_id]", + "answer2_id": "[answer2_id]", + "text": "Assistant 2 is better...", + "score": [9.0, 7.5], + "reviewer_id": "gpt-4-0328-default", + "metadata": {} +} +``` diff --git a/graphgpt/eval/requirements.txt b/graphgpt/eval/requirements.txt new file mode 100644 index 0000000..c2490e1 --- /dev/null +++ b/graphgpt/eval/requirements.txt @@ -0,0 +1,2 @@ +shortuuid +ray \ No newline at end of file diff --git a/graphgpt/eval/run_graphgpt.py b/graphgpt/eval/run_graphgpt.py new file mode 100644 index 0000000..ac5486f --- /dev/null +++ b/graphgpt/eval/run_graphgpt.py @@ -0,0 +1,243 @@ +import argparse +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +import os +from graphgpt.conversation import conv_templates, SeparatorStyle +from graphgpt.utils import disable_torch_init +from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria +from graphgpt.model import * +from graphgpt.model.utils import KeywordsStoppingCriteria +from torch_geometric.data import Data +import json +import copy + +import os +import requests +from PIL import Image +from io import BytesIO + +from tqdm import tqdm +import json +import os.path as osp + +import ray + +# os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +DEFAULT_GRAPH_TOKEN = "" +DEFAULT_GRAPH_PATCH_TOKEN = "" +DEFAULT_G_START_TOKEN = "" +DEFAULT_G_END_TOKEN = "" + + +def load_graph(instruct_item, graph_data_path): + graph_data_all = torch.load(graph_data_path) + graph_dict = instruct_item['graph'] + graph_edge_index = torch.Tensor(copy.deepcopy(graph_dict['edge_index'])).long() + graph_node_list = copy.deepcopy(graph_dict['node_list']) + target_node = copy.deepcopy(graph_dict['node_idx']) + graph_type = copy.deepcopy(instruct_item['id']).split('_')[0] + graph_node_rep = graph_data_all[graph_type].x[graph_node_list] ## + + cur_token_len = len(graph_node_rep) # FIXME: 14 is hardcoded patch size + + graph_ret = Data(graph_node = graph_node_rep, edge_index=graph_edge_index, target_node = torch.tensor([target_node])) + + return { + 'graph_data': graph_ret, + 'graph_token_len': cur_token_len + } + + +def load_prompting_file(file_path): + with open(file_path, 'r') as f: + data = json.load(f) + return data + +# def prepare_query(instruct_item): + + +def run_eval(args, num_gpus): + # split question file into num_gpus files + prompt_file = load_prompting_file(args.prompting_file) + prompt_file = prompt_file[args.start_id:args.end_id] + chunk_size = len(prompt_file) // num_gpus + ans_handles = [] + split_list = list(range(args.start_id, args.end_id, chunk_size)) + idx_list = list(range(0, len(prompt_file), chunk_size)) + if len(split_list) == num_gpus: + split_list.append(args.end_id) + idx_list.append(len(prompt_file)) + elif len(split_list) == num_gpus + 1: + split_list[-1] = args.end_id + idx_list[-1] = len(prompt_file) + else: + raise ValueError('error in the number of list') + + if osp.exists(args.output_res_path) is False: + os.mkdir(args.output_res_path) + + for idx in range(len(idx_list) - 1): + start_idx = idx_list[idx] + end_idx = idx_list[idx + 1] + + start_split = split_list[idx] + end_split = split_list[idx + 1] + ans_handles.append( + eval_model.remote( + args, prompt_file[start_idx:end_idx], start_split, end_split + ) + ) + + ans_jsons = [] + for ans_handle in ans_handles: + ans_jsons.extend(ray.get(ans_handle)) + + # with open(args.output_res_path, "w") as ans_file: + # for line in ans_jsons: + # ans_file.write(json.dumps(line) + "\n") + + +@ray.remote(num_gpus=1) +@torch.inference_mode() +def eval_model(args, prompt_file, start_idx, end_idx): + # load prompting file + # prompt_file = load_prompting_file(args.prompting_file) + + + # Model + disable_torch_init() + # model_name = os.path.expanduser(args.model_name) + print('start loading') + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + print('finish loading') + + print('start loading') + model = GraphLlamaForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, use_cache=True, low_cpu_mem_usage=True).cuda() + print('finish loading') + + use_graph_start_end = getattr(model.config, "use_graph_start_end", False) + tokenizer.add_tokens([DEFAULT_GRAPH_PATCH_TOKEN], special_tokens=True) + if use_graph_start_end: + tokenizer.add_tokens([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN], special_tokens=True) + + graph_tower = model.get_model().graph_tower + + # TODO: add graph tower + # if graph_tower.device.type == 'meta': + # print('meta') + clip_graph, args_graph= load_model_pretrained(CLIP, './clip_gt_arxiv_pub') + graph_tower = graph_transformer(args_graph) + graph_tower = transfer_param_tograph(clip_graph, graph_tower) + + model.get_model().graph_tower = graph_tower.cuda() + # else: + # print('other') + # print(next(graph_tower.parameters()).dtype) + graph_tower.to(device='cuda', dtype=torch.float16) + graph_config = graph_tower.config + graph_config.graph_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_GRAPH_PATCH_TOKEN])[0] + graph_config.use_graph_start_end = use_graph_start_end + if use_graph_start_end: + graph_config.graph_start_token, graph_config.graph_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN]) + # TODO: add graph token len + + res_data = [] + print(f'total: {len(prompt_file)}') + for idx, instruct_item in tqdm(enumerate(prompt_file)): + # instruct_item = prompt_file[0] + # if idx >= 3: + # break + graph_dict = load_graph(instruct_item, args.graph_data_path) + graph_token_len = graph_dict['graph_token_len'] + graph_data = graph_dict['graph_data'] + + qs = instruct_item["conversations"][0]["value"] + # if use_graph_start_end: + # qs = qs + '\n' + DEFAULT_G_START_TOKEN + DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + DEFAULT_G_END_TOKEN + # else: + # qs = qs + '\n' + DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + + replace_token = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + replace_token = DEFAULT_G_START_TOKEN + replace_token + DEFAULT_G_END_TOKEN + qs = qs.replace(DEFAULT_GRAPH_TOKEN, replace_token) + + # if "v1" in args.model_name.lower(): + # conv_mode = "graphchat_v1" + # else: + # raise ValueError('Don\'t support this model') + conv_mode = "graphchat_v1" + + if args.conv_mode is not None and conv_mode != args.conv_mode: + print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) + else: + args.conv_mode = conv_mode + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + inputs = tokenizer([prompt]) + + + + input_ids = torch.as_tensor(inputs.input_ids).cuda() + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + + graph_data.graph_node = graph_data.graph_node.to(torch.float16) + # graph_data.edge_index = graph_data.edge_index.to(torch.float16) + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + graph_data=graph_data.cuda(), + do_sample=True, + temperature=0.2, + max_new_tokens=1024, + stopping_criteria=[stopping_criteria]) + + input_token_len = input_ids.shape[1] + n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + if n_diff_input_output > 0: + print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') + outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] + outputs = outputs.strip() + # print(outputs) + + res_data.append({"id": instruct_item["id"], "node_idx": instruct_item["graph"]["node_idx"], "res": outputs}.copy()) + with open(osp.join(args.output_res_path, 'arxiv_test_res_{}_{}.json'.format(start_idx, end_idx)), "w") as fout: + json.dump(res_data, fout, indent=4) + return res_data + # with open(args.output_res_path, "w") as fout: + # json.dump(res_data, fout, indent=4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-name", type=str, default="facebook/opt-350m") + # parser.add_argument("--image-file", type=str, required=True) + # parser.add_argument("--query", type=str, required=True) + parser.add_argument("--prompting_file", type=str, default=None) + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--graph_data_path", type=str, default=None) + + parser.add_argument("--output_res_path", type=str, default=None) + parser.add_argument("--num_gpus", type=int, default=4) + + parser.add_argument("--start_id", type=int, default=0) + parser.add_argument("--end_id", type=int, default=20567) + + args = parser.parse_args() + + # eval_model(args) + + ray.init() + run_eval(args, args.num_gpus) + + +# protobuf 4.22.3 \ No newline at end of file diff --git a/graphgpt/eval/run_graphgpt_LP.py b/graphgpt/eval/run_graphgpt_LP.py new file mode 100644 index 0000000..519ed39 --- /dev/null +++ b/graphgpt/eval/run_graphgpt_LP.py @@ -0,0 +1,320 @@ +import argparse +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +import os +from graphgpt.conversation import conv_templates, SeparatorStyle +from graphgpt.utils import disable_torch_init +from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria +from graphgpt.model import * +from graphgpt.model.utils import KeywordsStoppingCriteria +from torch_geometric.data import Data +import json +import copy +import random + +import os +import requests +from PIL import Image +from io import BytesIO + +from tqdm import tqdm +import json +import os.path as osp + +import ray + +# os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +DEFAULT_GRAPH_TOKEN = "" +DEFAULT_GRAPH_PATCH_TOKEN = "" +DEFAULT_G_START_TOKEN = "" +DEFAULT_G_END_TOKEN = "" + + +def load_graph(instruct_item, graph_data_path): + graph_data_all = torch.load(graph_data_path) + graph_dict = instruct_item['graph'] + graph_edge_index = torch.Tensor(copy.deepcopy(graph_dict['edge_index'])).long() + graph_node_list = copy.deepcopy(graph_dict['node_list']) + target_node = copy.deepcopy(graph_dict['node_idx']) + graph_type = copy.deepcopy(instruct_item['id']).split('_')[0] + graph_node_rep = graph_data_all[graph_type].x[graph_node_list] ## + + cur_token_len = len(graph_node_rep) # FIXME: 14 is hardcoded patch size + + graph_ret = Data(graph_node = graph_node_rep, edge_index=graph_edge_index, target_node = torch.tensor([target_node])) + + return { + 'graph_data': graph_ret, + 'graph_token_len': cur_token_len + } + +def load_graph_LP(instruct_item, graph_data_path): + graph_data_all = torch.load(graph_data_path) + graph_dict = instruct_item['graph'] + graph_edge_index_1 = torch.Tensor(copy.deepcopy(graph_dict['edge_index_1'])).long() + graph_node_list_1 = copy.deepcopy(graph_dict['node_list_1']) + target_node_1 = copy.deepcopy(graph_dict['node_idx_1']) + graph_type = copy.deepcopy(instruct_item['id']).split('_')[0] + graph_node_rep_1 = graph_data_all[graph_type].x[graph_node_list_1] ## + + cur_token_len_1 = len(graph_node_rep_1) # FIXME: 14 is hardcoded patch size + + graph_edge_index_2 = torch.Tensor(copy.deepcopy(graph_dict['edge_index_2'])).long() + graph_node_list_2 = copy.deepcopy(graph_dict['node_list_2']) + target_node_2 = copy.deepcopy(graph_dict['node_idx_2']) + graph_node_rep_2 = graph_data_all[graph_type].x[graph_node_list_2] ## + + cur_token_len_2 = len(graph_node_rep_2) # FIXME: 14 is hardcoded patch + + graph_ret = { + 'graph_1': Data(graph_node = graph_node_rep_1, edge_index=graph_edge_index_1, target_node = torch.tensor([target_node_1])), + 'graph_2': Data(graph_node = graph_node_rep_2, edge_index=graph_edge_index_2, target_node = torch.tensor([target_node_2])) + } + + return { + 'graph_data': graph_ret, + 'graph_token_len_1': cur_token_len_1, + 'graph_token_len_2': cur_token_len_2 + } + + +def load_prompting_file(file_path): + with open(file_path, 'r') as f: + data = json.load(f) + return data + +# def prepare_query(instruct_item): + + +def run_eval(args, num_gpus): + # split question file into num_gpus files + prompt_file = load_prompting_file(args.prompting_file) + if args.is_shuffle: + print('shuffle the prompt file!') + random.seed(0) # 设置随机种子 + random.shuffle(prompt_file) + else: + print('Not shuffle the prompt file!') + + prompt_file = prompt_file[args.start_id:args.end_id] + chunk_size = len(prompt_file) // num_gpus + ans_handles = [] + split_list = list(range(args.start_id, args.end_id, chunk_size)) + idx_list = list(range(0, len(prompt_file), chunk_size)) + if len(split_list) == num_gpus: + split_list.append(args.end_id) + idx_list.append(len(prompt_file)) + elif len(split_list) == num_gpus + 1: + split_list[-1] = args.end_id + idx_list[-1] = len(prompt_file) + else: + raise ValueError('error in the number of list') + + if osp.exists(args.output_res_path) is False: + os.mkdir(args.output_res_path) + + for idx in range(len(idx_list) - 1): + start_idx = idx_list[idx] + end_idx = idx_list[idx + 1] + + start_split = split_list[idx] + end_split = split_list[idx + 1] + ans_handles.append( + eval_model.remote( + args, prompt_file[start_idx:end_idx], start_split, end_split + ) + ) + + ans_jsons = [] + for ans_handle in ans_handles: + ans_jsons.extend(ray.get(ans_handle)) + + # with open(args.output_res_path, "w") as ans_file: + # for line in ans_jsons: + # ans_file.write(json.dumps(line) + "\n") + + +@ray.remote(num_gpus=1) +@torch.inference_mode() +def eval_model(args, prompt_file, start_idx, end_idx): + # load prompting file + # prompt_file = load_prompting_file(args.prompting_file) + + + # Model + disable_torch_init() + # model_name = os.path.expanduser(args.model_name) + print('start loading') + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + print('finish loading') + + print('start loading') + model = GraphLlamaForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, use_cache=True, low_cpu_mem_usage=True).cuda() + print('finish loading') + + use_graph_start_end = getattr(model.config, "use_graph_start_end", False) + tokenizer.add_tokens([DEFAULT_GRAPH_PATCH_TOKEN], special_tokens=True) + if use_graph_start_end: + tokenizer.add_tokens([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN], special_tokens=True) + + graph_tower = model.get_model().graph_tower + + # TODO: add graph tower + # if graph_tower.device.type == 'meta': + # print('meta') + clip_graph, args_graph= load_model_pretrained(CLIP, './clip_gt_arxiv_pub') + graph_tower = graph_transformer(args_graph) + graph_tower = transfer_param_tograph(clip_graph, graph_tower) + + model.get_model().graph_tower = graph_tower.cuda() + # else: + # print('other') + # print(next(graph_tower.parameters()).dtype) + graph_tower.to(device='cuda', dtype=torch.float16) + graph_config = graph_tower.config + graph_config.graph_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_GRAPH_PATCH_TOKEN])[0] + graph_config.use_graph_start_end = use_graph_start_end + if use_graph_start_end: + graph_config.graph_start_token, graph_config.graph_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN]) + # TODO: add graph token len + + res_data = [] + print(f'total: {len(prompt_file)}') + for idx, instruct_item in tqdm(enumerate(prompt_file)): + # instruct_item = prompt_file[0] + # if idx >= 3: + # break + task_type = instruct_item['id'].split('_')[-1] + if task_type != 'LP': + graph_dict = load_graph(instruct_item, args.graph_data_path) + graph_token_len = graph_dict['graph_token_len'] + graph_data = graph_dict['graph_data'] + + qs = instruct_item["conversations"][0]["value"] + # if use_graph_start_end: + # qs = qs + '\n' + DEFAULT_G_START_TOKEN + DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + DEFAULT_G_END_TOKEN + # else: + # qs = qs + '\n' + DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + + replace_token = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + replace_token = DEFAULT_G_START_TOKEN + replace_token + DEFAULT_G_END_TOKEN + qs = qs.replace(DEFAULT_GRAPH_TOKEN, replace_token) + else: + graph_dict = load_graph_LP(instruct_item, args.graph_data_path) + graph_token_len_1 = graph_dict['graph_token_len_1'] + graph_token_len_2 = graph_dict['graph_token_len_2'] + graph_data = graph_dict['graph_data'] + + qs = instruct_item["conversations"][0]["value"] + # if use_graph_start_end: + # qs = qs + '\n' + DEFAULT_G_START_TOKEN + DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + DEFAULT_G_END_TOKEN + # else: + # qs = qs + '\n' + DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + + # replace_token = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + # replace_token = DEFAULT_G_START_TOKEN + replace_token + DEFAULT_G_END_TOKEN + # qs = qs.replace(DEFAULT_GRAPH_TOKEN, replace_token) + + replace_token_1 = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len_1 + replace_token_2 = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len_2 + + replace_token_1 = DEFAULT_G_START_TOKEN + replace_token_1 + DEFAULT_G_END_TOKEN + replace_token_2 = DEFAULT_G_START_TOKEN + replace_token_2 + DEFAULT_G_END_TOKEN + + if DEFAULT_GRAPH_TOKEN in qs: + first_index = qs.find(DEFAULT_GRAPH_TOKEN) + qs = qs[:first_index] + replace_token_1 + qs[first_index+len(DEFAULT_GRAPH_TOKEN):] + + second_index = qs.find(DEFAULT_GRAPH_TOKEN) + qs = qs[:second_index] + replace_token_2 + qs[second_index+len(DEFAULT_GRAPH_TOKEN):] + + # if "v1" in args.model_name.lower(): + # conv_mode = "graphchat_v1" + # else: + # raise ValueError('Don\'t support this model') + conv_mode = "graphchat_v1" + + if args.conv_mode is not None and conv_mode != args.conv_mode: + print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) + else: + args.conv_mode = conv_mode + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + inputs = tokenizer([prompt]) + + + + input_ids = torch.as_tensor(inputs.input_ids).cuda() + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + + if task_type != 'LP': + graph_data.graph_node = graph_data.graph_node.to(torch.float16) + graph_data = graph_data.cuda() + else: + graph_data['graph_1'].graph_node = graph_data['graph_1'].graph_node.to(torch.float16) + graph_data['graph_2'].graph_node = graph_data['graph_2'].graph_node.to(torch.float16) + + graph_data['graph_1'] = graph_data['graph_1'].cuda() + graph_data['graph_2'] = graph_data['graph_2'].cuda() + + # graph_data.edge_index = graph_data.edge_index.to(torch.float16) + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + graph_data=graph_data, + do_sample=True, + temperature=0.2, + max_new_tokens=1024, + stopping_criteria=[stopping_criteria]) + + input_token_len = input_ids.shape[1] + n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + if n_diff_input_output > 0: + print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') + outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] + outputs = outputs.strip() + # print(outputs) + + res_data.append({"id": instruct_item["id"], "node_idx_1": instruct_item["graph"]["node_idx_1"], "node_idx_2": instruct_item["graph"]["node_idx_2"], 'truth': instruct_item["conversations"][1]["value"], "res": outputs}.copy()) + with open(osp.join(args.output_res_path, 'arxiv_test_res_{}_{}.json'.format(start_idx, end_idx)), "w") as fout: + json.dump(res_data, fout, indent=4) + return res_data + # with open(args.output_res_path, "w") as fout: + # json.dump(res_data, fout, indent=4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-name", type=str, default="facebook/opt-350m") + # parser.add_argument("--image-file", type=str, required=True) + # parser.add_argument("--query", type=str, required=True) + parser.add_argument("--prompting_file", type=str, default=None) + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--graph_data_path", type=str, default=None) + + parser.add_argument("--output_res_path", type=str, default=None) + parser.add_argument("--num_gpus", type=int, default=4) + + parser.add_argument("--start_id", type=int, default=0) + parser.add_argument("--end_id", type=int, default=20567) + parser.add_argument("--is_shuffle", type=bool, default=False) + + args = parser.parse_args() + + # eval_model(args) + + ray.init() + run_eval(args, args.num_gpus) + + +# protobuf 4.22.3 \ No newline at end of file diff --git a/graphgpt/eval/run_vicuna.py b/graphgpt/eval/run_vicuna.py new file mode 100644 index 0000000..7cc30c8 --- /dev/null +++ b/graphgpt/eval/run_vicuna.py @@ -0,0 +1,203 @@ +import argparse +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +import os +from graphgpt.conversation import conv_templates, SeparatorStyle +from graphgpt.utils import disable_torch_init +from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria +from graphgpt.model import * +from graphgpt.model.utils import KeywordsStoppingCriteria +from torch_geometric.data import Data +import json +import copy +import re + +import os +import os.path as osp +import requests +from PIL import Image +from io import BytesIO + +from tqdm import tqdm +import json + +import ray + +# os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +DEFAULT_GRAPH_TOKEN = "" +DEFAULT_GRAPH_PATCH_TOKEN = "" +DEFAULT_G_START_TOKEN = "" +DEFAULT_G_END_TOKEN = "" + + +def load_graph(instruct_item, graph_data_path): + graph_data_all = torch.load(graph_data_path) + graph_dict = instruct_item['graph'] + graph_edge_index = torch.Tensor(copy.deepcopy(graph_dict['edge_index'])).long() + graph_node_list = copy.deepcopy(graph_dict['node_list']) + target_node = copy.deepcopy(graph_dict['node_idx']) + graph_type = copy.deepcopy(instruct_item['id']).split('_')[0] + graph_node_rep = graph_data_all[graph_type].x[graph_node_list] ## + + cur_token_len = len(graph_node_rep) # FIXME: 14 is hardcoded patch size + + graph_ret = Data(graph_node = graph_node_rep, edge_index=graph_edge_index, target_node = torch.tensor([target_node])) + + return { + 'graph_data': graph_ret, + 'graph_token_len': cur_token_len + } + + +def load_prompting_file(file_path): + with open(file_path, 'r') as f: + data = json.load(f) + return data + +# def prepare_query(instruct_item): + + +def run_eval(args, num_gpus): + # split question file into num_gpus files + prompt_file = load_prompting_file(args.prompting_file) + prompt_file = prompt_file[args.start_id:args.end_id] + chunk_size = len(prompt_file) // num_gpus + ans_handles = [] + split_list = list(range(args.start_id, args.end_id, chunk_size)) + idx_list = list(range(0, len(prompt_file), chunk_size)) + if len(split_list) == num_gpus: + split_list.append(args.end_id) + idx_list.append(len(prompt_file)) + elif len(split_list) == num_gpus + 1: + split_list[-1] = args.end_id + idx_list[-1] = len(prompt_file) + else: + raise ValueError('error in the number of list') + + if osp.exists(args.output_res_path) is False: + os.mkdir(args.output_res_path) + + for idx in range(len(idx_list) - 1): + start_idx = idx_list[idx] + end_idx = idx_list[idx + 1] + + start_split = split_list[idx] + end_split = split_list[idx + 1] + ans_handles.append( + eval_model.remote( + args, prompt_file[start_idx:end_idx], start_split, end_split + ) + ) + + ans_jsons = [] + + for ans_handle in ans_handles: + ans_jsons.extend(ray.get(ans_handle)) + + # with open(args.output_res_path, "w") as ans_file: + # for line in ans_jsons: + # ans_file.write(json.dumps(line) + "\n") + + +@ray.remote(num_gpus=1) +@torch.inference_mode() +def eval_model(args, prompt_file, start_id, end_id): + # load prompting file + # prompt_file = load_prompting_file(args.prompting_file) + + + # Model + + disable_torch_init() + # model_name = os.path.expanduser(args.model_name) + print('start loading') + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + print('finish loading') + + print('start loading') + model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, use_cache=True, low_cpu_mem_usage=True).cuda() + print('finish loading') + + + print(f'total: {len(prompt_file)}') + res_file = osp.join(args.output_res_path, f'arxiv_test_res_{start_id}_{end_id}.json') + if osp.exists(res_file): + with open(res_file, 'r') as f: + res_data = json.load(f) + ready_len = len(res_data) + if ready_len == (end_id - start_id): + return res_data + else: + res_data = [] + ready_len = 0 + print('*'*10, 'create res file', '*'*10) + with open(res_file, 'w') as f: + json.dump(res_data, f) + + + + for idx, instruct_item in tqdm(enumerate(prompt_file[ready_len:])): + # instruct_item = prompt_file[0] + # if idx >= 3: + # break + + qs = instruct_item["conversations"][0]["value"] + + pattern = r'' + + qs = re.sub(pattern, '', qs) + + conv_mode = "vicuna_v1_1" + + if args.conv_mode is not None and conv_mode != args.conv_mode: + print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) + else: + args.conv_mode = conv_mode + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + input_ids = tokenizer([prompt]).input_ids + + output_ids = model.generate( + torch.as_tensor(input_ids).cuda(), + do_sample=True, + temperature=0.7, + max_new_tokens=1024, + ) + output_ids = output_ids[0][len(input_ids[0]) :] + outputs = tokenizer.decode(output_ids, skip_special_tokens=True).strip() + + res_data.append({"id": instruct_item["id"], "node_idx": instruct_item["graph"]["node_idx"], "res": outputs}.copy()) + with open(res_file, 'w') as f: + json.dump(res_data, f) + return res_data + # with open(args.output_res_path, "w") as fout: + # json.dump(res_data, fout, indent=4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-name", type=str, default="facebook/opt-350m") + # parser.add_argument("--image-file", type=str, required=True) + # parser.add_argument("--query", type=str, required=True) + parser.add_argument("--prompting_file", type=str, default=None) + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--graph_data_path", type=str, default=None) + + parser.add_argument("--output_res_path", type=str, default=None) + parser.add_argument("--num_gpus", type=int, default=2) + + parser.add_argument("--start_id", type=int, default=0) + parser.add_argument("--end_id", type=int, default=20567) + + args = parser.parse_args() + + # eval_model(args) + + ray.init() + run_eval(args, args.num_gpus) + + +# protobuf 4.22.3 \ No newline at end of file diff --git a/graphgpt/eval/script/run_model_qa.yaml b/graphgpt/eval/script/run_model_qa.yaml new file mode 100644 index 0000000..64e3656 --- /dev/null +++ b/graphgpt/eval/script/run_model_qa.yaml @@ -0,0 +1,48 @@ +resources: + accelerators: A100:4 + cloud: gcp + +num_nodes: 1 + +workdir: . + +setup: | + conda activate chatbot + if [ $? -eq 0 ]; then + echo 'conda env exists' + else + # Setup the environment + conda create -n chatbot python=3.10 -y + fi + conda activate chatbot + + pip3 install -e . + + # Install pytorch + pip install torch==1.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 + + # Install huggingface with the LLaMA commit + pip install git+https://github.com/huggingface/transformers.git@c612628045822f909020f7eb6784c79700813eda + + cd fastchat/eval + pip install -r requirements.txt + + MODEL_NAME=vicuna-7b-20230322-fp16 + MODEL_PATH=~/${MODEL_NAME} + + if [ ! -f "$MODEL_PATH/ready" ]; then + echo "export MODEL_PATH=${MODEL_PATH}" >> ~/.bashrc + echo "export MODEL_NAME=${MODEL_NAME}" >> ~/.bashrc + mkdir -p $MODEL_PATH + gsutil -m cp gs://model-weights/${MODEL_NAME}/* $MODEL_PATH + touch $MODEL_PATH/ready + echo "model downloaded" + fi + +run: | + conda activate chatbot + python -m fastchat.eval.get_model_answer --model-path $MODEL_PATH \ + --model-id $MODEL_NAME \ + --question-file fastchat/eval/table/question.jsonl \ + --answer-file answer.jsonl \ + --num-gpus $SKYPILOT_NUM_GPUS_PER_NODE diff --git a/graphgpt/model/GraphLlama.py b/graphgpt/model/GraphLlama.py new file mode 100644 index 0000000..a9ab31a --- /dev/null +++ b/graphgpt/model/GraphLlama.py @@ -0,0 +1,435 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from transformers import AutoConfig, AutoModelForCausalLM, \ + LlamaConfig, LlamaModel, LlamaForCausalLM, \ + CLIPVisionModel, CLIPImageProcessor + +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + +from graphgpt.model.graph_layers import MPNN, GNN, CLIP, graph_transformer +from torch_geometric.data import Data +import json +import os.path as osp +import glob + +DEFAULT_GRAPH_TOKEN = "" +DEFAULT_GRAPH_PATCH_TOKEN = "" +DEFAULT_G_START_TOKEN = "" +DEFAULT_G_END_TOKEN = "" + + + + +class GraphLlamaConfig(LlamaConfig): + model_type = "GraphLlama" + +class GraphPretrainConfig: + def __init__(self, dictionary): + for key, value in dictionary.items(): + setattr(self, key, value) + +def load_model_pretrained(model_name, pretrain_model_path): + # load conig json + + assert osp.exists(osp.join(pretrain_model_path, 'config.json')), 'config.json missing' + with open(osp.join(pretrain_model_path, 'config.json'), 'r') as f: + config_dict = json.load(f) + args = GraphPretrainConfig(config_dict) + model = model_name(args) + pkl_files = glob.glob(osp.join(pretrain_model_path, '*.pkl')) + state_dict = torch.load(pkl_files[0]) + # print(state_dict.keys()) + if 'logit_scale' in state_dict.keys(): + state_dict.pop('logit_scale') + print('loading graph pre train model') + model.load_state_dict(state_dict) + + + return model, args +def transfer_param_tograph(clip_graph, gnn): + + print(clip_graph) + gnn_state_dict = clip_graph.gnn.state_dict() + gnn.load_state_dict(gnn_state_dict) + return gnn + + +class GraphLlamaModel(LlamaModel): + config_class = GraphLlamaConfig + + def __init__(self, config: LlamaConfig): + super(GraphLlamaModel, self).__init__(config) + + if hasattr(config, "graph_tower"): + # HACK: for FSDP + # self.vision_tower = [CLIPVisionModel.from_pretrained(config.graph_tower)] + # self.arxiv_projector = nn.Linear(config.graph_hidden_size, config.hidden_size) + if config.graph_tower == 'MPNN': + self.graph_tower = MPNN(in_channels = config.graph_hidden_size, hidden_channels = config.graph_hidden_size * 2, out_channels = config.graph_hidden_size, dropout = 0.1, num_layers = 2, if_param = False) + elif config.graph_tower == "clip_gcn_arxiv": + + clip_graph, args= load_model_pretrained(CLIP, config.pretrain_graph_model_path) + self.graph_tower = GNN(args) + self.graph_tower = transfer_param_tograph(clip_graph, self.graph_tower) + elif config.graph_tower == "clip_gt": + clip_graph, args= load_model_pretrained(CLIP, config.pretrain_graph_model_path) + self.graph_tower = graph_transformer(args) + self.graph_tower = transfer_param_tograph(clip_graph, self.graph_tower) + elif config.graph_tower == "clip_gt_arxiv": + clip_graph, args= load_model_pretrained(CLIP, config.pretrain_graph_model_path) + self.graph_tower = graph_transformer(args) + self.graph_tower = transfer_param_tograph(clip_graph, self.graph_tower) + elif config.graph_tower == "clip_gt_arxiv_pub": + clip_graph, args= load_model_pretrained(CLIP, config.pretrain_graph_model_path) + self.graph_tower = graph_transformer(args) + self.graph_tower = transfer_param_tograph(clip_graph, self.graph_tower) + + + + # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower) + + if hasattr(config, "use_graph_proj"): + self.graph_projector = nn.Linear(config.graph_hidden_size, config.hidden_size) + + def get_graph_tower(self): + graph_tower = getattr(self, 'graph_tower', None) + if type(graph_tower) is list: + graph_tower = graph_tower[0] + return graph_tower + + def initialize_graph_modules(self, graph_tower, graph_select_layer, + pretrain_graph_mlp_adapter=None, fsdp=None): # TODO: modify this function + self.config.graph_tower = graph_tower + + + if not hasattr(self, 'graph_tower'): + if self.config.graph_tower == 'MPNN': + graph_tower = MPNN(in_channels = self.config.graph_hidden_size, hidden_channels = self.config.graph_hidden_size * 2, out_channels = self.config.graph_hidden_size, dropout = 0.1, num_layers = 2, if_param = False) + elif self.config.graph_tower == "clip_gcn_arxiv": + + clip_graph, args= load_model_pretrained(CLIP, self.config.pretrain_graph_model_path) + graph_tower = GNN(args) + graph_tower = transfer_param_tograph(clip_graph, graph_tower) + elif self.config.graph_tower == "clip_gt": + clip_graph, args= load_model_pretrained(CLIP, self.config.pretrain_graph_model_path) + graph_tower = graph_transformer(args) + graph_tower = transfer_param_tograph(clip_graph, graph_tower) + # graph_tower = MPNN(in_channels = self.config.graph_hidden_size, hidden_channels = self.config.graph_hidden_size * 2, out_channels = self.config.graph_hidden_size, dropout = 0.1, num_layers = 2) + elif self.config.graph_tower == "clip_gt_arxiv": + clip_graph, args= load_model_pretrained(CLIP, self.config.pretrain_graph_model_path) + graph_tower = graph_transformer(args) + graph_tower = transfer_param_tograph(clip_graph, graph_tower) + elif self.config.graph_tower == "clip_gt_arxiv_pub": + clip_graph, args= load_model_pretrained(CLIP, self.config.pretrain_graph_model_path) + graph_tower = graph_transformer(args) + graph_tower = transfer_param_tograph(clip_graph, graph_tower) + else: + graph_tower = self.graph_tower + graph_tower.requires_grad_(False) + + if fsdp is not None and len(fsdp) > 0: + self.graph_tower = [graph_tower] + else: + self.graph_tower = graph_tower + + + + self.config.use_graph_proj = True + self.config.graph_select_layer = graph_select_layer + + if not hasattr(self, 'graph_projector'): + self.graph_projector = nn.Linear(self.config.graph_hidden_size, self.config.hidden_size) + + if pretrain_graph_mlp_adapter is not None: + graph_projector_weights = torch.load(pretrain_graph_mlp_adapter, map_location='cpu') + self.graph_projector.load_state_dict({k.split('.')[-1]: v for k, v in graph_projector_weights.items()}) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + # graph_node_reps: Optional[torch.FloatTensor] = None, + # edge_index_reps: Optional[torch.FloatTensor] = None, + graph_data: Optional[Data] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + # HACK: replace back original embeddings for LLaVA pretraining + orig_embeds_params = getattr(self, 'orig_embeds_params', None) + # if orig_embeds_params is not None: + # orig_embeds_params = orig_embeds_params[0] + # with torch.no_grad(): + # self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + graph_tower = self.get_graph_tower() + if graph_tower is not None and (input_ids.shape[1] != 1 or self.training) and graph_data is not None: + # TODO: this is a modified multimodal LLM -- Haotian Liu + with torch.no_grad(): + if type(graph_data) is list: + # variable length images + graph_node_features = [] + if type(graph_data[0]) is Data: + for g in graph_data: + # print(g) + node_forward_out = graph_tower(g) + graph_node_features.append(node_forward_out) + elif type(graph_data[0]) is dict: + for g_dict in graph_data: + node_forward_out_1 = graph_tower(g_dict['graph_1']) + node_forward_out_2 = graph_tower(g_dict['graph_2']) + graph_node_features.append(node_forward_out_1) + graph_node_features.append(node_forward_out_2) + else: + raise ValueError(f'graph_node_reps is expected to be a list but got {type(graph_data)}') + if type(graph_data) is list: + # if type(graph_node_features[0]) is not dict: + graph_node_features = [self.graph_projector(node_feature) for node_feature in graph_node_features] + # else: + # graph_node_features = [{'graph_1': self.graph_projector(node_feature['graph_1']), 'graph_2': self.graph_projector(node_feature['graph_2'])} for node_feature in graph_node_features] + else: + raise ValueError(f'graph_node_reps is expected to be a list but got {type(graph_data)}') + dummy_graph_features = torch.zeros(256, 128, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + dummy_graph_features = self.graph_projector(dummy_graph_features) + + new_input_embeds = [] + cur_graph_idx = 0 + for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): + if (cur_input_ids == graph_tower.config.graph_patch_token).sum() == 0: + # multimodal LLM, but the current sample is not multimodal + cur_input_embeds = cur_input_embeds + (0. * dummy_graph_features).sum() + new_input_embeds.append(cur_input_embeds) + cur_graph_idx += 1 + continue + if graph_tower.config.use_graph_start_end: + cur_graph_features = graph_node_features[cur_graph_idx] + num_patches = cur_graph_features.shape[0] + if (cur_input_ids == graph_tower.config.graph_start_token).sum() != (cur_input_ids == graph_tower.config.graph_end_token).sum(): + raise ValueError("The number of graph start tokens and graph end tokens should be the same.") + graph_start_tokens = torch.where(cur_input_ids == graph_tower.config.graph_start_token)[0] + # print(graph_start_tokens) + for graph_start_token_pos in graph_start_tokens: + cur_graph_features = graph_node_features[cur_graph_idx].to(device=cur_input_embeds.device) + num_patches = cur_graph_features.shape[0] + if cur_input_ids[graph_start_token_pos + num_patches + 1] != graph_tower.config.graph_end_token: + raise ValueError("The graph end token should follow the graph start token.") + if orig_embeds_params is not None: + cur_new_input_embeds = torch.cat((cur_input_embeds[:graph_start_token_pos].detach(), cur_input_embeds[graph_start_token_pos:graph_start_token_pos+1], cur_graph_features, cur_input_embeds[graph_start_token_pos + num_patches + 1:graph_start_token_pos + num_patches + 2], cur_input_embeds[graph_start_token_pos + num_patches + 2:].detach()), dim=0) + else: + cur_new_input_embeds = torch.cat((cur_input_embeds[:graph_start_token_pos+1], cur_graph_features, cur_input_embeds[graph_start_token_pos + num_patches + 1:]), dim=0) + cur_graph_idx += 1 + new_input_embeds.append(cur_new_input_embeds) + else: + cur_graph_features = graph_node_features[cur_graph_idx] + num_patches = cur_graph_features.shape[0] + if (cur_input_ids == graph_tower.config.graph_patch_token).sum() != num_patches: + raise ValueError("The number of graph patch tokens should be the same as the number of graph patches.") + masked_indices = torch.where(cur_input_ids == graph_tower.config.graph_patch_token)[0] + mask_index_start = masked_indices[0] + if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any(): + raise ValueError("The graph patch tokens should be consecutive.") + if orig_embeds_params is not None: + cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_graph_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0) + else: + cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_graph_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0) + new_input_embeds.append(cur_new_input_embeds) + cur_graph_idx += 1 + + # print(cur_graph_idx) + # print(len(graph_node_features)) + assert cur_graph_idx == len(graph_node_features) + inputs_embeds = torch.stack(new_input_embeds, dim=0) + + return super(GraphLlamaModel, self).forward( + input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, + inputs_embeds=inputs_embeds, use_cache=use_cache, + output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + +class GraphLlamaForCausalLM(LlamaForCausalLM): + config_class = GraphLlamaConfig + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = GraphLlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def get_graph_tower(self): + return self.get_model().get_graph_tower() + + def get_vision_tower(self): + model = self.get_model() + graph_tower = model.graph_tower + if type(graph_tower) is list: + graph_tower = graph_tower[0] + return graph_tower + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + # graph_node_reps: Optional[torch.FloatTensor] = None, + # edge_index_reps: Optional[torch.FloatTensor] = None, + graph_data: Optional[Data] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # graph_node_reps=graph_node_reps, + # edge_index_reps=edge_index_reps + graph_data = graph_data + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model/pipeline parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "graph_data": [kwargs.get("graph_data", None)], + # "edge_index_reps": kwargs.get("edge_index_reps", None), + } + ) + return model_inputs + + def initialize_graph_tokenizer(self, use_graph_start_end, tokenizer, device, + tune_graph_mlp_adapter=False, pretrain_graph_mlp_adapter=None): + vision_config = self.get_graph_tower().config + vision_config.use_graph_start_end = use_graph_start_end + tokenizer.add_tokens([DEFAULT_GRAPH_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if use_graph_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + vision_config.graph_start_token, vision_config.graph_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN]) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if tune_graph_mlp_adapter: + self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)] + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if pretrain_graph_mlp_adapter: + mm_projector_weights = torch.load(pretrain_graph_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + + vision_config.graph_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_GRAPH_PATCH_TOKEN])[0] + +AutoConfig.register("GraphLlama", GraphLlamaConfig) +AutoModelForCausalLM.register(GraphLlamaConfig, GraphLlamaForCausalLM) diff --git a/graphgpt/model/__init__.py b/graphgpt/model/__init__.py new file mode 100644 index 0000000..0c8966f --- /dev/null +++ b/graphgpt/model/__init__.py @@ -0,0 +1,8 @@ +from graphgpt.model.model_adapter import ( + load_model, + get_conversation_template, + add_model_args, +) + +from graphgpt.model.GraphLlama import GraphLlamaForCausalLM, load_model_pretrained, transfer_param_tograph +from graphgpt.model.graph_layers.clip_graph import GNN, graph_transformer, CLIP diff --git a/graphgpt/model/apply_delta.py b/graphgpt/model/apply_delta.py new file mode 100644 index 0000000..ba1c06d --- /dev/null +++ b/graphgpt/model/apply_delta.py @@ -0,0 +1,165 @@ +""" +Apply the delta weights on top of a base model. + +Usage: +python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta-v1.1 +""" +import argparse +import gc +import glob +import json +import os +import shutil +import tempfile + +from huggingface_hub import snapshot_download +import torch +from torch import nn +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig + + +GB = 1 << 30 + + +def split_files(model_path, tmp_path, split_size): + if not os.path.exists(model_path): + model_path = snapshot_download(repo_id=model_path) + if not os.path.exists(tmp_path): + os.makedirs(tmp_path) + + file_pattern = os.path.join(model_path, "pytorch_model-*.bin") + files = glob.glob(file_pattern) + + part = 0 + try: + for file_path in tqdm(files): + state_dict = torch.load(file_path) + new_state_dict = {} + + current_size = 0 + for name, param in state_dict.items(): + param_size = param.numel() * param.element_size() + + if current_size + param_size > split_size: + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + current_size = 0 + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + + new_state_dict[name] = param + current_size += param_size + + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + except Exception as e: + print(f"An error occurred during split_files: {e}") + shutil.rmtree(tmp_path) + raise + + +def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta_config = AutoConfig.from_pretrained(delta_path) + + if os.path.exists(target_model_path): + shutil.rmtree(target_model_path) + os.makedirs(target_model_path) + + split_size = 4 * GB + + with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path: + print(f"Split files for the base model to {tmp_base_path}") + split_files(base_model_path, tmp_base_path, split_size) + print(f"Split files for the delta weights to {tmp_delta_path}") + split_files(delta_path, tmp_delta_path, split_size) + + base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin") + base_files = glob.glob(base_pattern) + delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") + delta_files = glob.glob(delta_pattern) + delta_state_dict = torch.load(delta_files[0]) + + print("Applying the delta") + weight_map = {} + total_size = 0 + + for i, base_file in tqdm(enumerate(base_files)): + state_dict = torch.load(base_file) + file_name = f"pytorch_model-{i}.bin" + for name, param in state_dict.items(): + if name not in delta_state_dict: + for delta_file in delta_files: + delta_state_dict = torch.load(delta_file) + gc.collect() + if name in delta_state_dict: + break + + state_dict[name] += delta_state_dict[name] + weight_map[name] = file_name + total_size += param.numel() * param.element_size() + gc.collect() + torch.save(state_dict, os.path.join(target_model_path, file_name)) + + with open( + os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w" + ) as f: + json.dump( + {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f + ) + + print(f"Saving the target model to {target_model_path}") + delta_tokenizer.save_pretrained(target_model_path) + delta_config.save_pretrained(target_model_path) + + +def apply_delta(base_model_path, target_model_path, delta_path): + print(f"Loading the delta weights from {delta_path}") + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta = AutoModelForCausalLM.from_pretrained( + delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print("Applying the delta") + for name, param in tqdm(base.state_dict().items(), desc="Applying delta"): + assert name in delta.state_dict() + param.data += delta.state_dict()[name] + + print(f"Saving the target model to {target_model_path}") + base.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument( + "--low-cpu-mem", + action="store_true", + help="Lower the cpu memory usage. This will split large files and use " + "disk as swap to reduce the memory usage below 10GB.", + ) + args = parser.parse_args() + + if args.low_cpu_mem: + apply_delta_low_cpu_mem( + args.base_model_path, args.target_model_path, args.delta_path + ) + else: + apply_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/graphgpt/model/apply_lora.py b/graphgpt/model/apply_lora.py new file mode 100644 index 0000000..870e64a --- /dev/null +++ b/graphgpt/model/apply_lora.py @@ -0,0 +1,48 @@ +""" +Apply the LoRA weights on top of a base model. + +Usage: +python3 -m fastchat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B + +Dependency: +pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b +""" +import argparse + +import torch +from peft import PeftModel +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def apply_lora(base_model_path, target_model_path, lora_path): + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False) + + print(f"Loading the LoRA adapter from {lora_path}") + + lora_model = PeftModel.from_pretrained( + base, + lora_path, + torch_dtype=torch.float16, + ) + + print("Applying the LoRA") + model = lora_model.merge_and_unload() + + print(f"Saving the target model to {target_model_path}") + model.save_pretrained(target_model_path) + base_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--lora-path", type=str, required=True) + + args = parser.parse_args() + + apply_lora(args.base_model_path, args.target_model_path, args.lora_path) diff --git a/graphgpt/model/builder.py b/graphgpt/model/builder.py new file mode 100644 index 0000000..d8398a2 --- /dev/null +++ b/graphgpt/model/builder.py @@ -0,0 +1,145 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import shutil + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig +import torch +from graphgpt.model import * +from graphgpt.constants import DEFAULT_GRAPH_PATCH_TOKEN, DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN + + +def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"): + kwargs = {"device_map": device_map} + + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.float16 + + if 'graphchat' in model_name.lower(): + # Load LLaVA model + if 'lora' in model_name.lower() and model_base is not None: + lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + print('Loading LLaVA from base model...') + model = GraphLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) + token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features + if model.lm_head.weight.shape[0] != token_num: + model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) + model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) + + print('Loading additional LLaVA weights...') + if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): + non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') + else: + # this is probably from HF Hub + from huggingface_hub import hf_hub_download + def load_from_hf(repo_id, filename, subfolder=None): + cache_file = hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder) + return torch.load(cache_file, map_location='cpu') + non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') + non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} + if any(k.startswith('model.model.') for k in non_lora_trainables): + non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} + model.load_state_dict(non_lora_trainables, strict=False) + + from peft import PeftModel + print('Loading LoRA weights...') + model = PeftModel.from_pretrained(model, model_path) + print('Merging LoRA weights...') + model = model.merge_and_unload() + print('Model is loaded...') + elif model_base is not None: + # this may be mm projector only + print('Loading LLaVA from base model...') + if 'mpt' in model_name.lower(): + if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): + shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) + cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + + mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') + mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} + model.load_state_dict(mm_projector_weights, strict=False) + else: + if 'mpt' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + else: + # Load language model + if model_base is not None: + # PEFT model + from peft import PeftModel + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") + print(f"Loading LoRA weights from {model_path}") + model = PeftModel.from_pretrained(model, model_path) + print(f"Merging weights") + model = model.merge_and_unload() + print('Convert to FP16...') + model.to(torch.float16) + else: + use_fast = False + if 'mpt' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + + image_processor = None + + if 'llava' in model_name.lower(): + mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) + mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) + if mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + if mm_use_im_start_end: + tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + model.resize_token_embeddings(len(tokenizer)) + + vision_tower = model.get_vision_tower() + if not vision_tower.is_loaded: + vision_tower.load_model() + vision_tower.to(device='cuda', dtype=torch.float16) + image_processor = vision_tower.image_processor + + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 2048 + + return tokenizer, model, image_processor, context_len \ No newline at end of file diff --git a/graphgpt/model/compression.py b/graphgpt/model/compression.py new file mode 100644 index 0000000..e06c2b2 --- /dev/null +++ b/graphgpt/model/compression.py @@ -0,0 +1,228 @@ +import dataclasses +import gc +import glob +import os + +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +import torch +from torch import Tensor +import torch.nn as nn +from torch.nn import functional as F +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig + + +@dataclasses.dataclass +class CompressionConfig: + """Group-wise quantization.""" + + num_bits: int + group_size: int + group_dim: int + symmetric: bool + enabled: bool = True + + +default_compression_config = CompressionConfig( + num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True +) + + +class CLinear(nn.Module): + """Compressed Linear Layer.""" + + def __init__(self, weight=None, bias=None, device=None): + super().__init__() + if weight is None: + self.weight = None + elif isinstance(weight, Tensor): + self.weight = compress(weight.data.to(device), default_compression_config) + else: + self.weight = weight + self.bias = bias + + def forward(self, input: Tensor) -> Tensor: + weight = decompress(self.weight, default_compression_config) + return F.linear(input.to(weight.dtype), weight, self.bias) + + +def compress_module(module, target_device): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + setattr( + module, + attr_str, + CLinear(target_attr.weight, target_attr.bias, target_device), + ) + for name, child in module.named_children(): + compress_module(child, target_device) + + +def get_compressed_list(module, prefix=""): + compressed_list = [] + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + compressed_list.append(full_name) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + for each in get_compressed_list(child, child_prefix): + compressed_list.append(each) + return compressed_list + + +def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + setattr( + module, + attr_str, + CLinear( + compressed_state_dict[full_name], target_attr.bias, target_device + ), + ) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + apply_compressed_weight( + child, compressed_state_dict, target_device, child_prefix + ) + + +def load_compress_model(model_path, device, torch_dtype): + # partially load model + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + base_pattern = os.path.join(model_path, "pytorch_model-*.bin") + files = glob.glob(base_pattern) + + with init_empty_weights(): + config = AutoConfig.from_pretrained( + model_path, low_cpu_mem_usage=True, torch_dtype=torch_dtype + ) + model = AutoModelForCausalLM.from_config(config) + linear_weights = get_compressed_list(model) + + compressed_state_dict = {} + + for filename in tqdm(files): + tmp_state_dict = torch.load(filename) + for name in tmp_state_dict: + if name in linear_weights: + tensor = tmp_state_dict[name].to(device).data.to(torch_dtype) + compressed_state_dict[name] = compress( + tensor, default_compression_config + ) + else: + compressed_state_dict[name] = tmp_state_dict[name].to(device) + tmp_state_dict[name] = None + tensor = None + gc.collect() + torch.cuda.empty_cache() + + for name in model.state_dict(): + if name not in linear_weights: + set_module_tensor_to_device( + model, name, device, value=compressed_state_dict[name] + ) + apply_compressed_weight(model, compressed_state_dict, device) + + model.to(device) + + return model, tokenizer + + +def compress(tensor, config): + """Simulate group-wise quantization.""" + if not config.enabled: + return tensor + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + assert num_bits <= 8 + + original_shape = tensor.shape + num_groups = (original_shape[group_dim] + group_size - 1) // group_size + new_shape = ( + original_shape[:group_dim] + + (num_groups, group_size) + + original_shape[group_dim + 1 :] + ) + + # Pad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len != 0: + pad_shape = ( + original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :] + ) + tensor = torch.cat( + [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)], + dim=group_dim, + ) + data = tensor.view(new_shape) + + # Quantize + if symmetric: + B = 2 ** (num_bits - 1) - 1 + scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0] + data = data * scale + data = data.clamp_(-B, B).round_().to(torch.int8) + return data, scale, original_shape + else: + B = 2**num_bits - 1 + mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0] + mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0] + + scale = B / (mx - mn) + data = data - mn + data.mul_(scale) + + data = data.clamp_(0, B).round_().to(torch.uint8) + return data, mn, scale, original_shape + + +def decompress(packed_data, config): + """Simulate group-wise dequantization.""" + if not config.enabled: + return packed_data + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + + # Dequantize + if symmetric: + data, scale, original_shape = packed_data + data = data / scale + else: + data, mn, scale, original_shape = packed_data + data = data / scale + data.add_(mn) + + # Unpad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len: + padded_original_shape = ( + original_shape[:group_dim] + + (original_shape[group_dim] + pad_len,) + + original_shape[group_dim + 1 :] + ) + data = data.reshape(padded_original_shape) + indices = [slice(0, x) for x in original_shape] + return data[indices].contiguous() + else: + return data.view(original_shape) diff --git a/graphgpt/model/convert_fp16.py b/graphgpt/model/convert_fp16.py new file mode 100644 index 0000000..efc40aa --- /dev/null +++ b/graphgpt/model/convert_fp16.py @@ -0,0 +1,26 @@ +""" +Usage: +python3 -m fastchat.model.convert_fp16 --in in-folder --out out-folder +""" +import argparse + +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + + +def convert_fp16(in_checkpoint, out_checkpoint): + tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False) + model = AutoModelForCausalLM.from_pretrained( + in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + model.save_pretrained(out_checkpoint) + tokenizer.save_pretrained(out_checkpoint) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-checkpoint", type=str, help="Path to the model") + parser.add_argument("--out-checkpoint", type=str, help="Path to the output model") + args = parser.parse_args() + + convert_fp16(args.in_checkpoint, args.out_checkpoint) diff --git a/graphgpt/model/graph_layers/__init__.py b/graphgpt/model/graph_layers/__init__.py new file mode 100644 index 0000000..d8cbde6 --- /dev/null +++ b/graphgpt/model/graph_layers/__init__.py @@ -0,0 +1,3 @@ +from graphgpt.model.graph_layers.mpnn import MPNN +from graphgpt.model.graph_layers.clip_graph import CLIP, GNN +from graphgpt.model.graph_layers.graph_transformer import graph_transformer \ No newline at end of file diff --git a/graphgpt/model/graph_layers/bpe_simple_vocab_16e6.txt.gz b/graphgpt/model/graph_layers/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000..7b5088a Binary files /dev/null and b/graphgpt/model/graph_layers/bpe_simple_vocab_16e6.txt.gz differ diff --git a/graphgpt/model/graph_layers/clip_graph.py b/graphgpt/model/graph_layers/clip_graph.py new file mode 100644 index 0000000..62f1d58 --- /dev/null +++ b/graphgpt/model/graph_layers/clip_graph.py @@ -0,0 +1,314 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from typing import Any, Union, List +from graphgpt.model.graph_layers.simple_tokenizer import SimpleTokenizer as _Tokenizer +from torch_geometric.nn.conv import MessagePassing +from torch_scatter import scatter_add +from torch_geometric.utils import add_remaining_self_loops +from torch.nn import Parameter +from torch import nn, optim +from graphgpt.model.graph_layers.graph_transformer import graph_transformer +from transformers.configuration_utils import PretrainedConfig + +_tokenizer = _Tokenizer() + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class GNN(MessagePassing): + def __init__(self, args, **kwargs): + super(GNN, self).__init__(aggr='add', **kwargs) + self.config = PretrainedConfig() + self.vars = nn.ParameterList() + + w = nn.Parameter(torch.ones([args.gnn_hid, args.gnn_input])) + torch.nn.init.xavier_uniform_(w) + self.vars.append(w) + self.vars.append(nn.Parameter(torch.zeros(args.gnn_hid))) + + w = nn.Parameter(torch.ones([args.gnn_output, args.gnn_hid])) + torch.nn.init.xavier_uniform_(w) + self.vars.append(w) + self.vars.append(nn.Parameter(torch.zeros(args.gnn_output))) + + @staticmethod + def norm(edge_index, num_nodes, improved=False, dtype=None): + edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, + device=edge_index.device) + + fill_value = 1.0 if not improved else 2.0 + edge_index, edge_weight = add_remaining_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + + row, col = edge_index + deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + def forward(self, g, vars=None): + device = self.parameters()[0].device + g = g.to(device) + + edge_index = g.edge_index + x = g.graph_node + if vars is None: + vars = self.vars + improved = False + + w, b = vars[0], vars[1] + edge_index, norm = self.norm(edge_index, x.size(self.node_dim), improved, x.dtype) + x = self.propagate(edge_index, x=x, norm=norm) + w = w.to(x.device) + b = b.to(x.device) + x = F.linear(x, w, b) + x = F.leaky_relu(x) + + w, b = vars[2], vars[3] + edge_index, norm = self.norm(edge_index, x.size(self.node_dim), improved, x.dtype) + x = self.propagate(edge_index, x=x, norm=norm) + w = w.to(x.device) + b = b.to(x.device) + x = F.linear(x, w, b) + + return x + + def parameters(self): + return self.vars + + + +def Mv2SameDevice(var_list): + for vid in range(1, len(var_list)): + var_list[vid] = var_list[vid].to(var_list[0].device) + return var_list + + +class CLIP(nn.Module): + def __init__(self, + args + ): + super().__init__() + + self.context_length = args.context_length + self.args = args + self.edge_coef = args.edge_coef + + if args.gnn_type == 'gcn': + self.gnn = GNN(args) + elif args.gnn_type == 'gt': + self.gnn = graph_transformer(args) + self.transformer = Transformer( + width=args.transformer_width, + layers=args.transformer_layers, + heads=args.transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = args.vocab_size + self.token_embedding = nn.Embedding(args.vocab_size, + args.transformer_width) # the embedding for all possible tokens + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, args.transformer_width)) + self.ln_final = LayerNorm(args.transformer_width) + + self.text_projection = nn.Parameter(torch.empty(args.transformer_width, args.embed_dim)) + # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + if args.gnn_type == 'gcn': + self.dtype = self.gnn.vars[0].dtype + elif args.gnn_type == 'gt': + self.dtype = self.gnn.W_pos.dtype + + self.optim = optim.Adam([{'params': self.token_embedding.weight}, + {'params': self.positional_embedding}, + {'params': self.transformer.parameters()}, + {'params': self.text_projection}, + {'params': self.gnn.parameters()} + ], lr=args.lr) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def encode_image(self, idx_train, g): + embs = self.gnn(g) + idx_train = idx_train.to(embs.device) + idx_train = idx_train + train_embs = embs[idx_train] + return train_embs + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, + 2) # NLD -> LND, batch_size * context_length *emb_dim -> context_length * batch_size *emb_dim + x = self.transformer(x) + x = x.permute(1, 0, + 2) # LND -> NLD, context_length * batch_size *emb_dim -> batch_size * context_length *emb_dim + x = self.ln_final(x).type(self.dtype) + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot (end of token) embedding (eot_token is the highest number in each sequence) + # so there is node need to shorten the context length + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] # + x = x @ self.text_projection + return x + + def forward(self, g, s_n, t_n, s_n_text, t_n_text, training=True): + + s_image_features = self.encode_image(s_n, g) + + s_text_features = self.encode_text(s_n_text) + + t_text_features = self.encode_text(t_n_text) + t_text_features = t_text_features.reshape(s_image_features.shape[0], self.args.neigh_num, self.args.gnn_output) + t_text_features = torch.mean(t_text_features, dim=1, keepdim=False) + # normalized features + s_image_features = s_image_features / s_image_features.norm(dim=-1, keepdim=True) + s_text_features = s_text_features / s_text_features.norm(dim=-1, keepdim=True) + t_text_features = t_text_features / t_text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + + labels = torch.arange(s_image_features.shape[0]).cuda() + + # logit_scale = self.logit_scale.exp() # the temporature hyperparameter + # logit_scale, s_image_features, s_text_features = Mv2SameDevice([logit_scale, s_image_features, s_text_features]) + # logits = logit_scale * s_image_features @ s_text_features.t() + # loss_i = F.cross_entropy(logits, labels) + # loss_t = F.cross_entropy(logits.T, labels) + # node_loss = (loss_i + loss_t) / 2 + + # logit_scale, s_image_features, t_text_features = Mv2SameDevice([logit_scale, s_image_features, t_text_features]) + # logits = logit_scale * s_image_features @ t_text_features.t() + # loss_i = F.cross_entropy(logits, labels) + # loss_t = F.cross_entropy(logits.T, labels) + # gt_loss = (loss_i + loss_t)/2 + + # logit_scale, s_text_features, t_text_features = Mv2SameDevice([logit_scale, s_text_features, t_text_features]) + # logits = logit_scale * s_text_features @ t_text_features.t() + # loss_i = F.cross_entropy(logits, labels) + # loss_t = F.cross_entropy(logits.T, labels) + # tt_loss = (loss_i + loss_t)/2 + + + + # shape = [global_batch_size, global_batch_size] + # return all_loss + return s_image_features, s_text_features, t_text_features, labels + + +def tokenize(texts: Union[str, List[str]], context_length: int = 128, truncate: bool = True) -> torch.LongTensor: + + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + diff --git a/graphgpt/model/graph_layers/graph_transformer.py b/graphgpt/model/graph_layers/graph_transformer.py new file mode 100644 index 0000000..0e4e7c8 --- /dev/null +++ b/graphgpt/model/graph_layers/graph_transformer.py @@ -0,0 +1,142 @@ +import torch as t +from torch import nn +import torch.nn.functional as F +import math +from transformers.configuration_utils import PretrainedConfig + +init = nn.init.xavier_uniform_ +uniformInit = nn.init.uniform + +def PositionalEncoding(q_len, d_model, normalize=True): + pe = t.zeros(q_len, d_model) + position = t.arange(0, q_len).unsqueeze(1) + div_term = t.exp(t.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) + pe[:, 0::2] = t.sin(position * div_term) + pe[:, 1::2] = t.cos(position * div_term) + if normalize: + pe = pe - pe.mean() + pe = pe / (pe.std() * 10) + return pe + + +def pos_encoding(pe, learn_pe, nvar, d_model): + # Positional encoding + if pe == None: + W_pos = t.empty((nvar, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe + nn.init.uniform_(W_pos, -0.02, 0.02) + learn_pe = False + elif pe == 'zero': + W_pos = t.empty((nvar, 1)) + nn.init.uniform_(W_pos, -0.02, 0.02) + elif pe == 'zeros': + W_pos = t.empty((nvar, d_model)) + nn.init.uniform_(W_pos, -0.02, 0.02) + elif pe == 'normal' or pe == 'gauss': + W_pos = t.zeros((nvar, 1)) + t.nn.init.normal_(W_pos, mean=0.0, std=0.1) + elif pe == 'uniform': + W_pos = t.zeros((nvar, 1)) + nn.init.uniform_(W_pos, a=0.0, b=0.1) + elif pe == 'sincos': W_pos = PositionalEncoding(nvar, d_model, normalize=True) + else: raise ValueError(f"{pe} is not a valid pe (positional encoder. Available types: 'gauss'=='normal', \ + 'zeros', 'zero', uniform', 'sincos', None.)") + return nn.Parameter(W_pos, requires_grad=learn_pe) + + +class graph_transformer(nn.Module): + def __init__(self, args): + super(graph_transformer, self).__init__() + self.config = PretrainedConfig() + self.gtLayers = nn.Sequential(*[GTLayer(args) for i in range(args.gt_layers)]) + + self.W_pos = pos_encoding('zeros', True, 1, args.att_d_model) + + self.W_P = nn.Linear(args.gnn_input, args.att_d_model) + self.dropout = nn.Dropout(0.1) + self.inverW_P = nn.Linear(args.att_d_model, args.gnn_output) + self.args = args + + def forward(self, g): + # Adj: sp adj + # x: bs * n * d_model * num_patch + + # print(edge_index) + device = self.parameters().__next__().device + g = g.to(device) + + x = g.graph_node + + # x, W_P_weight, W_P_bias= Mv2Samedevice([x, self.W_P.weight, self.W_P.bias]) + # self.W_P.weight = nn.Parameter(W_P_weight.to(x.dtype)) + # self.W_P.bias = nn.Parameter(W_P_bias.to(x.dtype)) + # print(self.W_P.dtype, x.dtype) + z = self.W_P(x) + if self.args.if_pos: + embeds = self.dropout(z + self.W_pos) + else: + embeds = self.dropout(z) + for gt in self.gtLayers: + embeds = gt(g, embeds) # bs * num_patch * n * d_model + # embeds, inverW_P_weight, inverW_P_bias = Mv2Samedevice([embeds, self.inverW_P.weight, self.inverW_P.bias]) + # self.inverW_P.weight = nn.Parameter(inverW_P_weight.to(embeds.dtype)) + # self.inverW_P.bias = nn.Parameter(inverW_P_bias.to(embeds.dtype)) + ret = self.inverW_P(embeds) + return ret +def Mv2Samedevice(vars): + return [var.to(vars[0].device) for var in vars] + +class GTLayer(nn.Module): + def __init__(self, args): + super(GTLayer, self).__init__() + self.qTrans = nn.Parameter(init(t.empty(args.att_d_model, args.att_d_model))) + self.kTrans = nn.Parameter(init(t.empty(args.att_d_model, args.att_d_model))) + self.vTrans = nn.Parameter(init(t.empty(args.att_d_model, args.att_d_model))) + if args.att_norm: + self.norm = nn.LayerNorm(args.att_d_model, eps=1e-6) + self.args = args + + + + def forward(self, g, embeds): + # Adj: adj + # x: n * d_model + rows, cols = g.edge_index + nvar, _ = embeds.shape + # print(rows) + # print(cols) + + rowEmbeds = embeds[rows, :] + colEmbeds = embeds[cols, :] + evar, _ = rowEmbeds.shape + + # rowEmbeds, qTrans, kTrans, vTrans = Mv2Samedevice([rowEmbeds, self.qTrans, self.kTrans, self.vTrans]) + # self.qTrans = nn.Parameter(qTrans.to(rowEmbeds.dtype)) + # self.kTrans = nn.Parameter(kTrans.to(rowEmbeds.dtype)) + # self.vTrans = nn.Parameter(vTrans.to(rowEmbeds.dtype)) + qEmbeds = (rowEmbeds @ self.qTrans).view([evar, self.args.head, self.args.att_d_model // self.args.head]) + kEmbeds = (colEmbeds @ self.kTrans).view([evar, self.args.head, self.args.att_d_model // self.args.head]) + vEmbeds = (colEmbeds @ self.vTrans).view([evar, self.args.head, self.args.att_d_model // self.args.head]) + + att = t.einsum('ehd, ehd -> eh', qEmbeds, kEmbeds) + att = t.clamp(att, -10.0, 10.0) + expAtt = t.exp(att) + + tem = t.zeros([nvar, self.args.head]).to(expAtt.device, dtype=expAtt.dtype) + # print(tem.device, expAtt.device, rows.device) + rows = rows.to(expAtt.device) + attNorm = (tem.index_add_(0, rows, expAtt))[rows, :] + att = expAtt / (attNorm + 1e-8) # bleh + + resEmbeds = t.einsum('eh, ehd -> ehd', att, vEmbeds).view([evar, self.args.att_d_model]) + tem = t.zeros([nvar, self.args.att_d_model]).to(resEmbeds.device, dtype=resEmbeds.dtype) + rows = rows.to(resEmbeds.device) + tem = tem.to(resEmbeds.dtype) + resEmbeds = tem.index_add_(0, rows, resEmbeds) # nd + resEmbeds = resEmbeds + embeds + if self.args.att_norm: + # resEmbeds, norm_weight, norm_bias = Mv2Samedevice([resEmbeds, self.norm.weight, self.norm.bias]) + # self.norm.weight = nn.Parameter(norm_weight.to(resEmbeds.dtype)) + # self.norm.bias = nn.Parameter(norm_bias.to(resEmbeds.dtype)) + resEmbeds = self.norm(resEmbeds) + + return resEmbeds \ No newline at end of file diff --git a/graphgpt/model/graph_layers/mpnn.py b/graphgpt/model/graph_layers/mpnn.py new file mode 100644 index 0000000..0b159a1 --- /dev/null +++ b/graphgpt/model/graph_layers/mpnn.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.utils import remove_self_loops, add_self_loops, degree +import json +import copy +from transformers import AutoTokenizer +import transformers +from transformers.configuration_utils import PretrainedConfig +import os + +def gcn_conv(h, edge_index): + # print(edge_index) + N, node_feas = h.shape + edge_index, _ = remove_self_loops(edge_index) + edge_index, _ = add_self_loops(edge_index, num_nodes=N) + + src, dst = edge_index + deg = degree(dst, num_nodes=N) + + deg_src = deg[src].pow(-0.5) + deg_src.masked_fill_(deg_src == float('inf'), 0) + deg_dst = deg[dst].pow(-0.5) + deg_dst.masked_fill_(deg_dst == float('inf'), 0) + edge_weight = deg_src * deg_dst + + a = torch.sparse_coo_tensor(edge_index, edge_weight, torch.Size([N, N])).t() + rows, cols = edge_index + edge_msg = h[rows, :] * torch.unsqueeze(edge_weight, dim=-1) + col_embeds = h[cols, :] + tem = torch.zeros([N, node_feas]).to(edge_msg.device) + rows = rows.to(edge_msg.device) + h_prime = tem.index_add_(0, rows, edge_msg) # nd + # h = h.float() + # h_prime = a @ h + # h_prime = h_prime.bfloat16() + return h_prime + +# Implementation of MPNN, which can become MLP or GCN depending on whether using message passing +class MPNN(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, **kwargs): + super(MPNN, self).__init__() + self.config = PretrainedConfig() + self.dropout = kwargs.get('dropout')# args.dropout + self.num_layers = kwargs.get('num_layers')# args.num_layers + self.ff_bias = True # Use bias for FF layers in default + + self.bns = nn.BatchNorm1d(hidden_channels, affine=False, track_running_stats=False) + self.activation = F.relu + self.if_param = kwargs.get('if_param') + + if self.if_param: + self.fcs = nn.ModuleList([]) + self.fcs.append(nn.Linear(in_channels, hidden_channels, bias=self.ff_bias)) + for _ in range(self.num_layers - 2): self.fcs.append(nn.Linear(hidden_channels, hidden_channels, bias=self.ff_bias)) #1s + self.fcs.append(nn.Linear(hidden_channels, out_channels, bias=self.ff_bias)) #1 + self.reset_parameters() + + + def reset_parameters(self): + for mlp in self.fcs: + nn.init.xavier_uniform_(mlp.weight, gain=1.414) + nn.init.zeros_(mlp.bias) + + def forward(self, g, use_conv=True): + + x = g.graph_node + edge_index = g.edge_index + try: + device = self.parameters().__next__().device + except: + device = x.device + x = x.to(device) + edge_index = edge_index.to(device) + for i in range(self.num_layers - 1): + if self.if_param: x = x @ self.fcs[i].weight.t() + if use_conv: x = gcn_conv(x, edge_index) # Optionally replace 'gcn_conv' with other conv functions in conv.py + if self.ff_bias and self.if_param: x = x + self.fcs[i].bias + try: + x = self.activation(self.bns(x)) + except: + x = self.activation((x)) + x = F.dropout(x, p=self.dropout, training=self.training) + + if self.if_param: x = x @ self.fcs[-1].weight.t() + if use_conv: x = gcn_conv(x, edge_index) + if self.ff_bias and self.if_param: x = x + self.fcs[-1].bias + return x diff --git a/graphgpt/model/graph_layers/simple_tokenizer.py b/graphgpt/model/graph_layers/simple_tokenizer.py new file mode 100644 index 0000000..0a66286 --- /dev/null +++ b/graphgpt/model/graph_layers/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/graphgpt/model/make_delta.py b/graphgpt/model/make_delta.py new file mode 100644 index 0000000..480ba8f --- /dev/null +++ b/graphgpt/model/make_delta.py @@ -0,0 +1,48 @@ +""" +Make the delta weights by subtracting base weights. + +Usage: +python3 -m fastchat.model.make_delta --base ~/model_weights/llama-13b --target ~/model_weights/vicuna-13b --delta ~/model_weights/vicuna-13b-delta --hub-repo-id lmsys/vicuna-13b-delta-v1.1 +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def make_delta(base_model_path, target_model_path, delta_path): + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the target model from {target_model_path}") + target = AutoModelForCausalLM.from_pretrained( + target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False) + + print("Calculating the delta") + for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): + assert name in base.state_dict() + param.data -= base.state_dict()[name] + + print(f"Saving the delta to {delta_path}") + if args.hub_repo_id: + kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id} + else: + kwargs = {} + target.save_pretrained(delta_path, **kwargs) + target_tokenizer.save_pretrained(delta_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument("--hub-repo-id", type=str) + args = parser.parse_args() + + make_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/graphgpt/model/model_adapter.py b/graphgpt/model/model_adapter.py new file mode 100644 index 0000000..bc11e5e --- /dev/null +++ b/graphgpt/model/model_adapter.py @@ -0,0 +1,551 @@ +"""Model adapter registration.""" + +import math +import sys +from typing import List, Optional +import warnings + +if sys.version_info >= (3, 9): + from functools import cache +else: + from functools import lru_cache as cache + +import psutil +import torch +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoTokenizer, + LlamaTokenizer, + LlamaForCausalLM, + T5Tokenizer, +) + +from fastchat.conversation import Conversation, get_conv_template +from fastchat.model.compression import load_compress_model +from fastchat.model.monkey_patch_non_inplace import ( + replace_llama_attn_with_non_inplace_operations, +) +from fastchat.utils import get_gpu_memory + + +class BaseAdapter: + """The base and the default model adapter.""" + + def match(self, model_path: str): + return True + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +# A global registry for all model adapters +model_adapters: List[BaseAdapter] = [] + + +def register_model_adapter(cls): + """Register a model adapter.""" + model_adapters.append(cls()) + + +@cache +def get_model_adapter(model_path: str) -> BaseAdapter: + """Get a model adapter for a model_path.""" + for adapter in model_adapters: + if adapter.match(model_path): + return adapter + raise ValueError(f"No valid model adapter for {model_path}") + + +def raise_warning_for_incompatible_cpu_offloading_configuration( + device: str, load_8bit: bool, cpu_offloading: bool +): + if cpu_offloading: + if not load_8bit: + warnings.warn( + "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n" + "Use '--load-8bit' to enable 8-bit-quantization\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + if not "linux" in sys.platform: + warnings.warn( + "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + if device != "cuda": + warnings.warn( + "CPU-offloading is only enabled when using CUDA-devices\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + return cpu_offloading + + +def load_model( + model_path: str, + device: str, + num_gpus: int, + max_gpu_memory: Optional[str] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + debug: bool = False, +): + """Load a model from Hugging Face.""" + + # Handle device mapping + cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( + device, load_8bit, cpu_offloading + ) + if device == "cpu": + kwargs = {"torch_dtype": torch.float32} + elif device == "cuda": + kwargs = {"torch_dtype": torch.float16} + if num_gpus != 1: + kwargs["device_map"] = "auto" + if max_gpu_memory is None: + kwargs[ + "device_map" + ] = "sequential" # This is important for not the same VRAM sizes + available_gpu_memory = get_gpu_memory(num_gpus) + kwargs["max_memory"] = { + i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" + for i in range(num_gpus) + } + else: + kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} + elif device == "mps": + kwargs = {"torch_dtype": torch.float16} + # Avoid bugs in mps backend by not using in-place operations. + replace_llama_attn_with_non_inplace_operations() + else: + raise ValueError(f"Invalid device: {device}") + + if cpu_offloading: + # raises an error on incompatible platforms + from transformers import BitsAndBytesConfig + + if "max_memory" in kwargs: + kwargs["max_memory"]["cpu"] = ( + str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" + ) + kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_8bit_fp32_cpu_offload=cpu_offloading + ) + kwargs["load_in_8bit"] = load_8bit + elif load_8bit: + if num_gpus != 1: + warnings.warn( + "8-bit quantization is not supported for multi-gpu inference." + ) + else: + return load_compress_model( + model_path=model_path, device=device, torch_dtype=kwargs["torch_dtype"] + ) + + # Load model + adapter = get_model_adapter(model_path) + model, tokenizer = adapter.load_model(model_path, kwargs) + + if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device == "mps": + model.to(device) + + if debug: + print(model) + + return model, tokenizer + + +def get_conversation_template(model_path: str) -> Conversation: + adapter = get_model_adapter(model_path) + return adapter.get_default_conv_template(model_path) + + +def add_model_args(parser): + parser.add_argument( + "--model-path", + type=str, + default="lmsys/fastchat-t5-3b-v1.0", + help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--device", + type=str, + choices=["cpu", "cuda", "mps"], + default="cuda", + help="The device type", + ) + parser.add_argument( + "--gpus", + type=str, + default=None, + help="A single GPU like 1 or multiple GPUs like 0,2", + ) + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--max-gpu-memory", + type=str, + help="The maximum memory per gpu. Use a string like '13Gib'", + ) + parser.add_argument( + "--load-8bit", action="store_true", help="Use 8-bit quantization" + ) + parser.add_argument( + "--cpu-offloading", + action="store_true", + help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", + ) + + +def remove_parent_directory_name(model_path): + """Remove parent directory name.""" + if model_path[-1] == "/": + model_path = model_path[:-1] + return model_path.split("/")[-1] + + +class VicunaAdapter(BaseAdapter): + "Model adapater for vicuna-v1.1" + + def match(self, model_path: str): + return "vicuna" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + self.raise_warning_for_old_weights(model) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "v0" in remove_parent_directory_name(model_path): + return get_conv_template("one_shot") + return get_conv_template("vicuna_v1.1") + + def raise_warning_for_old_weights(self, model): + if isinstance(model, LlamaForCausalLM) and model.model.vocab_size > 32000: + warnings.warn( + "\nYou are probably using the old Vicuna-v0 model, " + "which will generate unexpected results with the " + "current fastchat.\nYou can try one of the following methods:\n" + "1. Upgrade your weights to the new Vicuna-v1.1: https://github.com/lm-sys/FastChat#vicuna-weights.\n" + "2. Use the old conversation template by `python3 -m fastchat.serve.cli --model-path /path/to/vicuna-v0 --conv-template conv_one_shot`\n" + "3. Downgrade fschat to fschat==0.1.10 (Not recommonded).\n" + ) + + +class T5Adapter(BaseAdapter): + """The model adapter for lmsys/fastchat-t5-3b-v1.0""" + + def match(self, model_path: str): + return "t5" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = T5Tokenizer.from_pretrained(model_path, use_fast=False) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + return model, tokenizer + + +class KoalaAdapter(BaseAdapter): + """The model adapter for koala""" + + def match(self, model_path: str): + return "koala" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("koala_v1") + + +class AlpacaAdapter(BaseAdapter): + """The model adapter for alpaca.""" + + def match(self, model_path: str): + return "alpaca" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class ChatGLMAdapter(BaseAdapter): + """The model adapter for THUDM/chatglm-6b""" + + def match(self, model_path: str): + return "chatglm" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model = AutoModel.from_pretrained( + model_path, trust_remote_code=True, **from_pretrained_kwargs + ) + return model, tokenizer + + +class DollyV2Adapter(BaseAdapter): + """The model adapter for databricks/dolly-v2-12b""" + + def match(self, model_path: str): + return "dolly-v2" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + # 50277 means "### End" + tokenizer.eos_token_id = 50277 + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("dolly_v2") + + +class OasstPythiaAdapter(BaseAdapter): + """The model adapter for OpenAssistant/oasst-sft-1-pythia-12b""" + + def match(self, model_path: str): + return "oasst" in model_path and "pythia" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("oasst_pythia") + + +class StableLMAdapter(BaseAdapter): + """The model adapter for StabilityAI/stablelm-tuned-alpha-7b""" + + def match(self, model_path: str): + return "stablelm" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("stablelm") + + +class MPTAdapter(BaseAdapter): + """The model adapter for mosaicml/mpt-7b-chat""" + + def match(self, model_path: str): + return "mpt" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + max_seq_len=8192, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, use_fast=True + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("mpt") + + +class BaizeAdapter(BaseAdapter): + """The model adapter for project-baize/baize-lora-7B""" + + def match(self, model_path: str): + return "baize" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("baize") + + +class RwkvAdapter(BaseAdapter): + """The model adapter for BlinkDL/RWKV-4-Raven""" + + def match(self, model_path: str): + return "RWKV-4" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + from fastchat.model.rwkv_model import RwkvModel + + model = RwkvModel(model_path) + tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/pythia-160m", use_fast=True + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("rwkv") + + +class OpenBuddyAdapter(BaseAdapter): + """The model adapter for OpenBuddy/openbuddy-7b-v1.1-bf16-enc""" + + def match(self, model_path: str): + return "openbuddy" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + if "-bf16" in model_path: + from_pretrained_kwargs["torch_dtype"] = torch.bfloat16 + warnings.warn( + "## This is a bf16(bfloat16) variant of OpenBuddy. Please make sure your GPU supports bf16." + ) + model = LlamaForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + tokenizer = LlamaTokenizer.from_pretrained(model_path) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("openbuddy") + + +class PhoenixAdapter(BaseAdapter): + """The model adapter for FreedomIntelligence/phoenix-inst-chat-7b""" + + def match(self, model_path: str): + return "phoenix" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("phoenix") + + +class ChatGPTAdapter(BaseAdapter): + """The model adapter for ChatGPT.""" + + def match(self, model_path: str): + return model_path == "gpt-3.5-turbo" or model_path == "gpt-4" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chatgpt") + + +class ClaudeAdapter(BaseAdapter): + """The model adapter for Claude.""" + + def match(self, model_path: str): + return model_path in ["claude-v1", "claude-instant-v1"] + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("claude") + + +class BardAdapter(BaseAdapter): + """The model adapter for Bard.""" + + def match(self, model_path: str): + return model_path == "bard" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("bard") + + +class BiLLaAdapter(BaseAdapter): + """The model adapter for BiLLa.""" + + def match(self, model_path: str): + return "billa" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("billa") + + +class RedPajamaINCITEAdapter(BaseAdapter): + """The model adapter for RedPajama INCITE.""" + + def match(self, model_path: str): + return "redpajama-incite" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path) # no use_fast=False + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("redpajama-incite") + + +class H2OGPTAdapter(BaseAdapter): + """The model adapter for h2oGPT.""" + + def match(self, model_path: str): + return "h2ogpt" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("h2ogpt") + + +# Note: the registration order matters. +# The one registered earlier has a higher matching priority. +register_model_adapter(VicunaAdapter) +register_model_adapter(T5Adapter) +register_model_adapter(KoalaAdapter) +register_model_adapter(AlpacaAdapter) +register_model_adapter(ChatGLMAdapter) +register_model_adapter(DollyV2Adapter) +register_model_adapter(OasstPythiaAdapter) +register_model_adapter(StableLMAdapter) +register_model_adapter(BaizeAdapter) +register_model_adapter(RwkvAdapter) +register_model_adapter(OpenBuddyAdapter) +register_model_adapter(PhoenixAdapter) +register_model_adapter(BardAdapter) +register_model_adapter(ChatGPTAdapter) +register_model_adapter(ClaudeAdapter) +register_model_adapter(MPTAdapter) +register_model_adapter(BiLLaAdapter) +register_model_adapter(RedPajamaINCITEAdapter) +register_model_adapter(H2OGPTAdapter) + +# After all adapters, try the default base adapter. +register_model_adapter(BaseAdapter) diff --git a/graphgpt/model/model_registry.py b/graphgpt/model/model_registry.py new file mode 100644 index 0000000..f8df996 --- /dev/null +++ b/graphgpt/model/model_registry.py @@ -0,0 +1,141 @@ +"""Additional information of the models.""" +from collections import namedtuple +from typing import List + + +ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"]) + + +model_info = {} + + +def register_model_info( + full_names: List[str], simple_name: str, link: str, description: str +): + info = ModelInfo(simple_name, link, description) + + for full_name in full_names: + model_info[full_name] = info + + +def get_model_info(name: str) -> ModelInfo: + return model_info[name] + + +register_model_info( + ["gpt-4"], "ChatGPT-4", "https://openai.com/research/gpt-4", "ChatGPT-4 by OpenAI" +) +register_model_info( + ["gpt-3.5-turbo"], + "ChatGPT-3.5", + "https://openai.com/blog/chatgpt", + "ChatGPT-3.5 by OpenAI", +) +register_model_info( + ["claude-v1"], + "Claude", + "https://www.anthropic.com/index/introducing-claude", + "Claude by Anthropic", +) +register_model_info( + ["claude-instant-v1"], + "Claude Instant", + "https://www.anthropic.com/index/introducing-claude", + "Claude Instant by Anthropic", +) +register_model_info( + ["palm-2"], + "PaLM 2 Chat", + "https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023", + "PaLM 2 for Chat (chat-bison@001) by Google", +) +register_model_info( + ["vicuna-13b", "vicuna-7b"], + "Vicuna", + "https://lmsys.org/blog/2023-03-30-vicuna/", + "a chat assistant fine-tuned from LLaMA on user-shared conversations by LMSYS", +) +register_model_info( + ["koala-13b"], + "Koala", + "https://bair.berkeley.edu/blog/2023/04/03/koala", + "a dialogue model for academic research by BAIR", +) +register_model_info( + ["oasst-pythia-12b"], + "OpenAssistant (oasst)", + "https://open-assistant.io", + "an Open Assistant for everyone by LAION", +) +register_model_info( + ["RWKV-4-Raven-14B"], + "RWKV-4-Raven", + "https://huggingface.co/BlinkDL/rwkv-4-raven", + "an RNN with transformer-level LLM performance", +) +register_model_info( + ["alpaca-13b"], + "Alpaca", + "https://crfm.stanford.edu/2023/03/13/alpaca.html", + "a model fine-tuned from LLaMA on instruction-following demonstrations by Stanford", +) +register_model_info( + ["chatglm-6b"], + "ChatGLM", + "https://chatglm.cn/blog", + "an open bilingual dialogue language model by Tsinghua University", +) +register_model_info( + ["llama-13b"], + "LLaMA", + "https://arxiv.org/abs/2302.13971", + "open and efficient foundation language models by Meta", +) +register_model_info( + ["dolly-v2-12b"], + "Dolly", + "https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm", + "an instruction-tuned open large language model by Databricks", +) +register_model_info( + ["stablelm-tuned-alpha-7b"], + "StableLM", + "https://github.com/stability-AI/stableLM", + "Stability AI language models", +) +register_model_info( + ["fastchat-t5-3b"], + "FastChat-T5", + "https://huggingface.co/lmsys/fastchat-t5-3b-v1.0", + "a chat assistant fine-tuned from FLAN-T5 by LMSYS", +) +register_model_info( + ["phoenix-inst-chat-7b"], + "Phoenix-7B", + "https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b", + "a multilingual chat assistant fine-tuned from Bloomz to democratize ChatGPT across languages by CUHK(SZ)", +) +register_model_info( + ["mpt-7b-chat"], + "MPT-Chat", + "https://www.mosaicml.com/blog/mpt-7b", + "a chatbot fine-tuned from MPT-7B by MosaicML", +) +register_model_info( + ["billa-7b-sft"], + "BiLLa-7B-SFT", + "https://huggingface.co/Neutralzz/BiLLa-7B-SFT", + "an instruction-tuned bilingual LLaMA with enhanced reasoning ability by an independent researcher", +) +register_model_info( + ["h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2"], + "h2oGPT-GM-7b", + "https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2", + "an instruction-tuned OpenLLaMA with enhanced conversational ability by H2O.ai", +) +register_model_info( + ["baize-v2-7b", "baize-v2-13b"], + "Baize v2", + "https://github.com/project-baize/baize-chatbot#v2", + "A chatbot fine-tuned from LLaMA with ChatGPT self-chat data and Self-Disillation with Feedback (SDF) by UCSD and SYSU.", +) diff --git a/graphgpt/model/monkey_patch_non_inplace.py b/graphgpt/model/monkey_patch_non_inplace.py new file mode 100644 index 0000000..9661d70 --- /dev/null +++ b/graphgpt/model/monkey_patch_non_inplace.py @@ -0,0 +1,118 @@ +""" +Monkey patch the llama implementation in the huggingface/transformers library. +Avoid bugs in mps backend by not using in-place operations. +""" +import math +from typing import List, Optional, Tuple + +import torch +from torch import nn +import transformers + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2].clone() + x2 = x[..., x.shape[-1] // 2 :].clone() + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def replace_llama_attn_with_non_inplace_operations(): + """Avoid bugs in mps backend by not using in-place operations.""" + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/graphgpt/model/utils.py b/graphgpt/model/utils.py new file mode 100644 index 0000000..f7fa93d --- /dev/null +++ b/graphgpt/model/utils.py @@ -0,0 +1,26 @@ +import torch + +from transformers import AutoConfig, StoppingCriteria + + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords] + self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1] + self.tokenizer = tokenizer + self.start_len = None + self.input_ids = input_ids + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + if self.start_len is None: + self.start_len = self.input_ids.shape[1] + else: + for keyword_id in self.keyword_ids: + if output_ids[0, -1] == keyword_id: + return True + outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False \ No newline at end of file diff --git a/graphgpt/protocol/openai_api_protocol.py b/graphgpt/protocol/openai_api_protocol.py new file mode 100644 index 0000000..466375c --- /dev/null +++ b/graphgpt/protocol/openai_api_protocol.py @@ -0,0 +1,172 @@ +from typing import Literal, Optional, List, Dict, Any, Union + +import time + +import shortuuid +from pydantic import BaseModel, Field + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + code: int + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = True + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "fastchat" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = [] + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class ChatCompletionRequest(BaseModel): + model: str + messages: Union[str, List[Dict[str, str]]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + max_tokens: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Optional[Literal["stop", "length"]] + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + +class TokenCheckRequest(BaseModel): + model: str + prompt: str + max_tokens: int + +class TokenCheckResponse(BaseModel): + fits: bool + tokenCount: int + contextLength: int + +class EmbeddingsRequest(BaseModel): + model: Optional[str] = None + engine: Optional[str] = None + input: Union[str, List[Any]] + user: Optional[str] = None + + +class EmbeddingsResponse(BaseModel): + object: str = "list" + data: List[Dict[str, Any]] + model: str + usage: UsageInfo + + +class CompletionRequest(BaseModel): + model: str + prompt: Union[str, List[Any]] + suffix: Optional[str] = None + temperature: Optional[float] = 0.7 + n: Optional[int] = 1 + max_tokens: Optional[int] = 16 + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + top_p: Optional[float] = 1.0 + logprobs: Optional[int] = None + echo: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[int] = None + finish_reason: Optional[Literal["stop", "length"]] + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[float] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] diff --git a/graphgpt/serve/__init__.py b/graphgpt/serve/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/graphgpt/serve/__pycache__/__init__.cpython-311.pyc b/graphgpt/serve/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..626ea65 Binary files /dev/null and b/graphgpt/serve/__pycache__/__init__.cpython-311.pyc differ diff --git a/graphgpt/serve/__pycache__/controller_graph.cpython-311.pyc b/graphgpt/serve/__pycache__/controller_graph.cpython-311.pyc new file mode 100644 index 0000000..afeb574 Binary files /dev/null and b/graphgpt/serve/__pycache__/controller_graph.cpython-311.pyc differ diff --git a/graphgpt/serve/__pycache__/gradio_web_server.cpython-311.pyc b/graphgpt/serve/__pycache__/gradio_web_server.cpython-311.pyc new file mode 100644 index 0000000..3e83b8d Binary files /dev/null and b/graphgpt/serve/__pycache__/gradio_web_server.cpython-311.pyc differ diff --git a/graphgpt/serve/__pycache__/gradio_web_server_graph.cpython-311.pyc b/graphgpt/serve/__pycache__/gradio_web_server_graph.cpython-311.pyc new file mode 100644 index 0000000..d91fd87 Binary files /dev/null and b/graphgpt/serve/__pycache__/gradio_web_server_graph.cpython-311.pyc differ diff --git a/graphgpt/serve/api_provider.py b/graphgpt/serve/api_provider.py new file mode 100644 index 0000000..268b81d --- /dev/null +++ b/graphgpt/serve/api_provider.py @@ -0,0 +1,148 @@ +"""Call API providers.""" + +import os +import random +import time + +from fastchat.utils import build_logger + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + + +def openai_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens): + import openai + + # Make requests + gen_params = { + "model": model_name, + "prompt": messages, + "temperature": temperature, + "top_p": top_p, + } + logger.info(f"==== request ====\n{gen_params}") + + res = openai.ChatCompletion.create( + model=model_name, messages=messages, temperature=temperature, stream=True + ) + text = "" + for chunk in res: + text += chunk["choices"][0]["delta"].get("content", "") + data = { + "text": text, + "error_code": 0, + } + yield data + + +def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens): + import anthropic + + c = anthropic.Client(os.environ["ANTHROPIC_API_KEY"]) + + # Make requests + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "top_p": top_p, + } + logger.info(f"==== request ====\n{gen_params}") + + res = c.completion_stream( + prompt=prompt, + stop_sequences=[anthropic.HUMAN_PROMPT], + max_tokens_to_sample=max_new_tokens, + temperature=temperature, + top_p=top_p, + model=model_name, + stream=True, + ) + for chunk in res: + data = { + "text": chunk["completion"], + "error_code": 0, + } + yield data + + +def bard_api_stream_iter(state): + # TODO: we will use the official PaLM 2 API sooner or later, + # and we will update this function accordingly. So here we just hard code the + # Bard worker address. It is going to be deprecated anyway. + conv = state.conv + + # Make requests + gen_params = { + "model": "bard", + "prompt": state.messages, + } + logger.info(f"==== request ====\n{gen_params}") + + response = requests.post( + "http://localhost:18900/chat", + json={ + "content": conv.messages[-2][-1], + "state": state.bard_session_state, + }, + stream=False, + timeout=WORKER_API_TIMEOUT, + ) + resp_json = response.json() + state.bard_session_state = resp_json["state"] + content = resp_json["content"] + # The Bard Web API does not support streaming yet. Here we have to simulate + # the streaming behavior by adding some time.sleep(). + pos = 0 + while pos < len(content): + # This is a fancy way to simulate token generation latency combined + # with a Poisson process. + pos += random.randint(1, 5) + time.sleep(random.expovariate(50)) + data = { + "text": content[:pos], + "error_code": 0, + } + yield data + + +def init_palm_chat(model_name): + import vertexai # pip3 install google-cloud-aiplatform + from vertexai.preview.language_models import ChatModel + + project_id = os.environ["GCP_PROJECT_ID"] + location = "us-central1" + vertexai.init(project=project_id, location=location) + + chat_model = ChatModel.from_pretrained(model_name) + chat = chat_model.start_chat(examples=[]) + return chat + + +def palm_api_stream_iter(chat, message, temperature, top_p, max_new_tokens): + parameters = { + "temperature": temperature, + "top_p": top_p, + "max_output_tokens": max_new_tokens, + } + gen_params = { + "model": "bard", + "prompt": message, + } + gen_params.update(parameters) + logger.info(f"==== request ====\n{gen_params}") + + response = chat.send_message(message, **parameters) + content = response.text + + pos = 0 + while pos < len(content): + # This is a fancy way to simulate token generation latency combined + # with a Poisson process. + pos += random.randint(10, 20) + time.sleep(random.expovariate(50)) + data = { + "text": content[:pos], + "error_code": 0, + } + yield data diff --git a/graphgpt/serve/bard_worker.py b/graphgpt/serve/bard_worker.py new file mode 100644 index 0000000..dc524a4 --- /dev/null +++ b/graphgpt/serve/bard_worker.py @@ -0,0 +1,159 @@ +""" +Adapted from https://github.com/acheong08/Bard. +""" +import argparse +import json +import random +import re +import string + +from fastapi import FastAPI +import httpx +from pydantic import BaseModel, Field +from typing import List, Optional, Union +import uvicorn + + +class ConversationState(BaseModel): + conversation_id: str = "" + response_id: str = "" + choice_id: str = "" + req_id: int = 0 + + +class Message(BaseModel): + content: str + state: ConversationState = Field(default_factory=ConversationState) + + +class Response(BaseModel): + content: str + factualityQueries: Optional[List] + textQuery: Optional[Union[str, List]] + choices: List[dict] + state: ConversationState + + +class Chatbot: + """ + A class to interact with Google Bard. + Parameters + session_id: str + The __Secure-1PSID cookie. + """ + + def __init__(self, session_id): + headers = { + "Host": "bard.google.com", + "X-Same-Domain": "1", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36", + "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8", + "Origin": "https://bard.google.com", + "Referer": "https://bard.google.com/", + } + self.session = httpx.AsyncClient() + self.session.headers = headers + self.session.cookies.set("__Secure-1PSID", session_id) + self.SNlM0e = None + + async def _get_snlm0e(self): + resp = await self.session.get(url="https://bard.google.com/", timeout=10) + # Find "SNlM0e":"" + if resp.status_code != 200: + raise Exception("Could not get Google Bard") + SNlM0e = re.search(r"SNlM0e\":\"(.*?)\"", resp.text).group(1) + return SNlM0e + + async def ask(self, message: Message) -> Response: + """ + Send a message to Google Bard and return the response. + :param message: The message to send to Google Bard. + :return: A dict containing the response from Google Bard. + """ + if message.state.conversation_id == "": + message.state.req_id = int("".join(random.choices(string.digits, k=4))) + # url params + params = { + # "bl": "boq_assistant-bard-web-server_20230315.04_p2", + # This is a newer API version + "bl": "boq_assistant-bard-web-server_20230507.20_p2", + "_reqid": str(message.state.req_id), + "rt": "c", + } + + # message arr -> data["f.req"]. Message is double json stringified + message_struct = [ + [message.content], + None, + [ + message.state.conversation_id, + message.state.response_id, + message.state.choice_id, + ], + ] + data = { + "f.req": json.dumps([None, json.dumps(message_struct)]), + "at": self.SNlM0e, + } + + # do the request! + resp = await self.session.post( + "https://bard.google.com/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate", + params=params, + data=data, + timeout=60, + ) + + chat_data = json.loads(resp.content.splitlines()[3])[0][2] + if not chat_data: + return Response( + content=f"Google Bard encountered an error: {resp.content}.", + factualityQueries=[], + textQuery="", + choices=[], + state=message.state, + ) + json_chat_data = json.loads(chat_data) + conversation = ConversationState( + conversation_id=json_chat_data[1][0], + response_id=json_chat_data[1][1], + choice_id=json_chat_data[4][0][0], + req_id=message.state.req_id + 100000, + ) + return Response( + content=json_chat_data[0][0], + factualityQueries=json_chat_data[3], + textQuery=json_chat_data[2][0] if json_chat_data[2] is not None else "", + choices=[{"id": i[0], "content": i[1]} for i in json_chat_data[4]], + state=conversation, + ) + + +app = FastAPI() +chatbot = None + + +@app.on_event("startup") +async def startup_event(): + global chatbot + cookie = json.load(open("bard_cookie.json")) + chatbot = Chatbot(cookie["__Secure-1PSID"]) + chatbot.SNlM0e = await chatbot._get_snlm0e() + + +@app.post("/chat", response_model=Response) +async def chat(message: Message): + response = await chatbot.ask(message) + return response + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Google Bard worker") + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=18900) + parser.add_argument("--reload", action="store_true") + args = parser.parse_args() + uvicorn.run( + "bard_worker:app", host=args.host, port=args.port, log_level="info", + reload=args.reload + ) diff --git a/graphgpt/serve/cacheflow_worker.py b/graphgpt/serve/cacheflow_worker.py new file mode 100644 index 0000000..de3ed3b --- /dev/null +++ b/graphgpt/serve/cacheflow_worker.py @@ -0,0 +1,346 @@ +""" +A model worker executes the model based on Cacheflow. + +Install Cacheflow first. Then, assuming controller is live: +1. ray start --head +2. python3 -m fastchat.serve.cacheflow_worker --model-path path_to_vicuna + +launch Gradio: +3. python3 -m fastchat.serve.gradio_web_server --concurrency-count 10000 +""" +import argparse +import asyncio +import json +import threading +import time +import uuid +from typing import List, Dict + +import requests +import torch +import uvicorn +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse +from transformers import AutoTokenizer + +from cacheflow.master.server import Server, initialize_ray_cluster +from cacheflow.sampling_params import SamplingParams +from cacheflow.sequence import Sequence, SequenceGroup +from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL +from fastchat.utils import build_logger, pretty_print_semaphore + +GB = 1 << 30 +TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds + +worker_id = str(uuid.uuid4())[:6] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") +global_counter = 0 +seed = torch.cuda.current_device() + + +def heart_beat_worker(controller): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + controller.send_heart_beat() + + +class CacheFlowWorker: + def __init__( + self, + controller_addr, + worker_addr, + worker_id, + no_register, + model_path, + model_name, + block_size, + seed, + swap_space, + max_num_batched_tokens, + distributed_init_method, + all_stage_devices, + ): + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + self.model_name = model_name or model_path.split("/")[-1] + + logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") + self.block_size = block_size + + # FIXME(Hao): we need to pass the tokenizer into cacheflow because we need + # to detect the stopping criteria "###". + self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + self.seq_group_counter = Counter() + self.seq_counter = Counter() + # FIXME(Hao): hard code context len + self.context_len = 2048 + # pipeline_parallel_size = 1, + # tensor_parallel_size = 1, + # dtype = torch.float16 + remote_server_class = Server + self.server = remote_server_class( + model=self.model_name, + model_path=model_path, + pipeline_parallel_size=1, + tensor_parallel_size=1, + block_size=block_size, + dtype=torch.float16, + seed=seed, + swap_space=swap_space, + max_num_batched_tokens=max_num_batched_tokens, + num_nodes=1, + num_devices_per_node=4, + distributed_init_method=distributed_init_method, + all_stage_devices=all_stage_devices, + gpu_memory=get_gpu_memory(), + cpu_memory=get_cpu_memory(), + ) + self.running_seq_groups: Dict[int, SequenceGroup] = {} + self.sequence_group_events: Dict[int, asyncio.Event] = {} + self.is_server_running = False + + if not no_register: + time.sleep(30) # wait for model loading + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, args=(self,) + ) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status(), + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info( + f"Send heart beat. Models: {[self.model_name]}. " + f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " + f"global_counter: {global_counter}" + ) + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post( + url, + json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length(), + }, + timeout=5, + ) + exist = ret.json()["exist"] + break + except requests.exceptions.RequestException as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if ( + model_semaphore is None + or model_semaphore._value is None + or model_semaphore._waiters is None + ): + return 0 + else: + return ( + args.limit_model_concurrency + - model_semaphore._value + + len(model_semaphore._waiters) + ) + + def get_status(self): + return { + "model_names": [self.model_name], + "speed": 1, + "queue_length": self.get_queue_length(), + } + + async def server_step(self): + self.is_server_running = True + updated_seq_groups = self.server.step() + self.is_server_running = False + # Notify the waiting coroutines that there new outputs ready. + for seq_group in updated_seq_groups: + group_id = seq_group.group_id + self.running_seq_groups[group_id] = seq_group + self.sequence_group_events[group_id].set() + + async def generate_stream(self, params): + tokenizer = self.tokenizer + context = params["prompt"] + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) + stop_str = params.get("stop", None) + echo = params.get("echo", True) + stop_token_ids = params.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer.eos_token_id) + + input_ids = tokenizer(context).input_ids + max_src_len = self.context_len - max_new_tokens - 8 + input_ids = input_ids[-max_src_len:] + + # make sampling params in cacheflow + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + sampling_params = SamplingParams( + n=1, + temperature=temperature, + top_p=top_p, + use_beam_search=False, + stop_token_ids=stop_token_ids, + max_num_steps=max_new_tokens, + num_logprobs=0, + context_window_size=None, + ) + + if stop_str is not None: + sampling_params.stop_str = stop_str + # we might sample multiple sequences, but in chatbot, this is one + seqs: List[Sequence] = [] + for _ in range(sampling_params.n): + seq_id = next(self.seq_counter) + seq = Sequence(seq_id, input_ids, block_size=self.block_size) + seqs.append(seq) + + arrival_time = time.time() + group_id = next(self.seq_group_counter) + # logger.info(f"Group {group_id} arrives at {time.time()}") + seq_group = SequenceGroup(group_id, seqs, arrival_time) + group_event = asyncio.Event() + self.running_seq_groups[group_id] = seq_group + self.sequence_group_events[group_id] = group_event + self.server.add_sequence_groups([(seq_group, sampling_params)]) + while True: + if not self.is_server_running: + await self.server_step() + try: + await asyncio.wait_for( + group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK + ) + except: + pass + group_event.clear() + seq_group = self.running_seq_groups[group_id] + all_outputs = [] + for seq in seq_group.seqs: + token_ids = seq.get_token_ids() + if not echo: + token_ids = token_ids[len(input_ids) :] + output = self.tokenizer.decode(token_ids, skip_special_tokens=True) + if stop_str is not None: + if output.endswith(stop_str): + output = output[: -len(stop_str)] + all_outputs.append(output) + assert len(seq_group.seqs) == 1 + ret = { + "text": all_outputs[0], + "error_code": 0, + } + yield (json.dumps(ret) + "\0").encode("utf-8") + if seq_group.is_finished(): + del self.running_seq_groups[group_id] + del self.sequence_group_events[group_id] + break + + +app = FastAPI() +model_semaphore = None + + +def release_model_semaphore(): + model_semaphore.release() + + +@app.post("/worker_generate_stream") +async def generate_stream(request: Request): + global model_semaphore, global_counter + global_counter += 1 + params = await request.json() + + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) + await model_semaphore.acquire() + background_tasks = BackgroundTasks() + background_tasks.add_task(release_model_semaphore) + # return StreamingResponse(generator, background=background_tasks) + return StreamingResponse( + worker.generate_stream(params), background=background_tasks + ) + + +@app.post("/worker_get_status") +async def get_status(request: Request): + return worker.get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--model-path", type=str, default="/home/haozhang/weights/hf-llama-7b" + ) + parser.add_argument("--model-name", type=str) + parser.add_argument("--limit-model-concurrency", type=int, default=1024) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + # cacheflow specific params + parser.add_argument( + "--block-size", type=int, default=8, choices=[8, 16], help="token block size" + ) + parser.add_argument( + "--swap-space", type=int, default=20, help="CPU swap space size (GiB) per GPU" + ) + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=2560, + help="maximum number of batched tokens", + ) + args = parser.parse_args() + + ( + num_nodes, + num_devices_per_node, + distributed_init_method, + all_stage_devices, + ) = initialize_ray_cluster(pipeline_parallel_size=1, tensor_parallel_size=1) + + worker = CacheFlowWorker( + args.controller_address, + args.worker_address, + worker_id, + args.no_register, + args.model_path, + args.model_name, + args.block_size, + seed, + args.swap_space, + args.max_num_batched_tokens, + distributed_init_method, + all_stage_devices, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/graphgpt/serve/cli.py b/graphgpt/serve/cli.py new file mode 100644 index 0000000..3d97a17 --- /dev/null +++ b/graphgpt/serve/cli.py @@ -0,0 +1,200 @@ +""" +Chat with a model with command line interface. + +Usage: +python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0 +python3 -m fastchat.serve.cli --model ~/model_weights/vicuna-7b +""" +import argparse +import os +import re +import sys + +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.history import InMemoryHistory +from rich.console import Console +from rich.markdown import Markdown +from rich.live import Live + +from fastchat.model.model_adapter import add_model_args +from fastchat.serve.inference import chat_loop, ChatIO + + +class SimpleChatIO(ChatIO): + def prompt_for_input(self, role) -> str: + return input(f"{role}: ") + + def prompt_for_output(self, role: str): + print(f"{role}: ", end="", flush=True) + + def stream_output(self, output_stream): + pre = 0 + for outputs in output_stream: + output_text = outputs["text"] + output_text = output_text.strip().split(" ") + now = len(output_text) - 1 + if now > pre: + print(" ".join(output_text[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(output_text[pre:]), flush=True) + return " ".join(output_text) + + +class RichChatIO(ChatIO): + def __init__(self): + self._prompt_session = PromptSession(history=InMemoryHistory()) + self._completer = WordCompleter( + words=["!exit", "!reset"], pattern=re.compile("$") + ) + self._console = Console() + + def prompt_for_input(self, role) -> str: + self._console.print(f"[bold]{role}:") + # TODO(suquark): multiline input has some issues. fix it later. + prompt_input = self._prompt_session.prompt( + completer=self._completer, + multiline=False, + auto_suggest=AutoSuggestFromHistory(), + key_bindings=None, + ) + self._console.print() + return prompt_input + + def prompt_for_output(self, role: str): + self._console.print(f"[bold]{role}:") + + def stream_output(self, output_stream): + """Stream output from a role.""" + # TODO(suquark): the console flickers when there is a code block + # above it. We need to cut off "live" when a code block is done. + + # Create a Live context for updating the console output + with Live(console=self._console, refresh_per_second=4) as live: + # Read lines from the stream + for outputs in output_stream: + if not outputs: + continue + text = outputs["text"] + # Render the accumulated text as Markdown + # NOTE: this is a workaround for the rendering "unstandard markdown" + # in rich. The chatbots output treat "\n" as a new line for + # better compatibility with real-world text. However, rendering + # in markdown would break the format. It is because standard markdown + # treat a single "\n" in normal text as a space. + # Our workaround is adding two spaces at the end of each line. + # This is not a perfect solution, as it would + # introduce trailing spaces (only) in code block, but it works well + # especially for console output, because in general the console does not + # care about trailing spaces. + lines = [] + for line in text.splitlines(): + lines.append(line) + if line.startswith("```"): + # Code block marker - do not add trailing spaces, as it would + # break the syntax highlighting + lines.append("\n") + else: + lines.append(" \n") + markdown = Markdown("".join(lines)) + # Update the Live console output + live.update(markdown) + self._console.print() + return text + + +class ProgrammaticChatIO(ChatIO): + def prompt_for_input(self, role) -> str: + print(f"[!OP:{role}]: ", end="", flush=True) + contents = "" + # `end_sequence` is a randomly-generated, 16-digit number + # that signals the end of a message. It is unlikely to occur in + # message content. + end_sequence = "9745805894023423" + while True: + if len(contents) >= 16: + last_chars = contents[-16:] + if last_chars == end_sequence: + break + try: + char = sys.stdin.read(1) + contents = contents + char + except EOFError: + continue + return contents[:-16] + + def prompt_for_output(self, role: str): + print(f"[!OP:{role}]: ", end="", flush=True) + + def stream_output(self, output_stream): + pre = 0 + for outputs in output_stream: + output_text = outputs["text"] + output_text = output_text.strip().split(" ") + now = len(output_text) - 1 + if now > pre: + print(" ".join(output_text[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(output_text[pre:]), flush=True) + return " ".join(output_text) + + +def main(args): + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + if args.style == "simple": + chatio = SimpleChatIO() + elif args.style == "rich": + chatio = RichChatIO() + elif args.style == "programmatic": + chatio = ProgrammaticChatIO() + else: + raise ValueError(f"Invalid style for console: {args.style}") + try: + chat_loop( + args.model_path, + args.device, + args.num_gpus, + args.max_gpu_memory, + args.load_8bit, + args.cpu_offloading, + args.conv_template, + args.temperature, + args.repetition_penalty, + args.max_new_tokens, + chatio, + args.debug, + ) + except KeyboardInterrupt: + print("exit...") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_model_args(parser) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--repetition_penalty", type=float, default=1.0) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument( + "--style", + type=str, + default="simple", + choices=["simple", "rich", "programmatic"], + help="Display style.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Print useful debug information (e.g., prompts)", + ) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/graphgpt/serve/controller.py b/graphgpt/serve/controller.py new file mode 100644 index 0000000..3b1ecf6 --- /dev/null +++ b/graphgpt/serve/controller.py @@ -0,0 +1,361 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. +""" +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from fastchat.constants import (CONTROLLER_HEART_BEAT_EXPIRATION, ErrorCode, + SERVER_ERROR_MSG) +from fastchat.utils import build_logger + + +logger = build_logger("controller", "controller.log") + + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stable_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread( + target=heart_beat_controller, args=(self,) + ) + self.heart_beat_thread.start() + + logger.info("Init controller") + + def register_worker( + self, worker_name: str, check_heart_beat: bool, worker_status: dict + ): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo( + worker_status["model_names"], + worker_status["speed"], + worker_status["queue_length"], + check_heart_beat, + time.time(), + ) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker(w_name, w_info.check_heart_beat, None): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info( + f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}" + ) + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. {worker_name}") + return True + + def remove_stable_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def handle_no_worker(params): + logger.info(f"no worker: {params['model']}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_NO_WORKER, + } + return json.dumps(ret).encode() + b"\0" + + def handle_worker_timeout(worker_address): + logger.info(f"worker timeout: {worker_address}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT, + } + return json.dumps(ret).encode() + b"\0" + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + yield self.handle_no_worker(params) + + try: + response = requests.post( + worker_addr + "/worker_generate_stream", + json=params, + stream=True, + timeout=15, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + yield self.handle_worker_timeout(worker_addr) + + def worker_api_generate_completion(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + return self.handle_no_worker(params) + + try: + response = requests.post( + worker_addr + "/worker_generate_completion", + json=params, + timeout=15, + ) + return response.json() + except requests.exceptions.RequestException as e: + return self.handle_worker_timeout(worker_addr) + + def worker_api_embeddings(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + return self.handle_no_worker(params) + + try: + response = requests.post( + worker_addr + "/worker_get_embeddings", + json=params, + timeout=15, + ) + return response.json() + except requests.exceptions.RequestException as e: + return self.handle_worker_timeout(worker_addr) + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. + def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + return { + "model_names": list(model_names), + "speed": speed, + "queue_length": queue_length, + } + + +app = FastAPI() + + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker( + data["worker_name"], data["check_heart_beat"], data.get("worker_status", None) + ) + + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_generate_completion") +async def worker_api_generate_completion(request: Request): + params = await request.json() + output = controller.worker_api_generate_completion(params) + return output + + +@app.post("/worker_get_embeddings") +async def worker_api_embeddings(request: Request): + params = await request.json() + output = controller.worker_api_embeddings(params) + return output + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return controller.worker_api_get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21001) + parser.add_argument( + "--dispatch-method", + type=str, + choices=["lottery", "shortest_queue"], + default="shortest_queue", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + controller = Controller(args.dispatch_method) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/graphgpt/serve/controller_graph.py b/graphgpt/serve/controller_graph.py new file mode 100644 index 0000000..f5bfb9c --- /dev/null +++ b/graphgpt/serve/controller_graph.py @@ -0,0 +1,298 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. +""" +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from graphgpt.constants import CONTROLLER_HEART_BEAT_EXPIRATION +from graphgpt.utils import build_logger, server_error_msg + + +logger = build_logger("controller", "controller.log") + + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stable_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread( + target=heart_beat_controller, args=(self,)) + self.heart_beat_thread.start() + + logger.info("Init controller") + + def register_worker(self, worker_name: str, check_heart_beat: bool, + worker_status: dict): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo( + worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], + check_heart_beat, time.time()) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker(w_name, w_info.check_heart_beat, None): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), + p=worker_speeds) + worker_name = worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), + p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. {worker_name}") + return True + + def remove_stable_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + logger.info(f"no worker: {params['model']}") + ret = { + "text": server_error_msg, + "error_code": 2, + } + yield json.dumps(ret).encode() + b"\0" + + try: + response = requests.post(worker_addr + "/worker_generate_stream", + json=params, stream=True, timeout=5) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + logger.info(f"worker timeout: {worker_addr}") + ret = { + "text": server_error_msg, + "error_code": 3, + } + yield json.dumps(ret).encode() + b"\0" + + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. + def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + return { + "model_names": list(model_names), + "speed": speed, + "queue_length": queue_length, + } + + +app = FastAPI() + + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker( + data["worker_name"], data["check_heart_beat"], + data.get("worker_status", None)) + + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat( + data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return controller.worker_api_get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21001) + parser.add_argument("--dispatch-method", type=str, choices=[ + "lottery", "shortest_queue"], default="shortest_queue") + args = parser.parse_args() + logger.info(f"args: {args}") + + controller = Controller(args.dispatch_method) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") \ No newline at end of file diff --git a/graphgpt/serve/examples/extreme_ironing.jpg b/graphgpt/serve/examples/extreme_ironing.jpg new file mode 100644 index 0000000..638b078 Binary files /dev/null and b/graphgpt/serve/examples/extreme_ironing.jpg differ diff --git a/graphgpt/serve/examples/waterview.jpg b/graphgpt/serve/examples/waterview.jpg new file mode 100644 index 0000000..6f44eba Binary files /dev/null and b/graphgpt/serve/examples/waterview.jpg differ diff --git a/graphgpt/serve/gateway/README.md b/graphgpt/serve/gateway/README.md new file mode 100644 index 0000000..b3afaf1 --- /dev/null +++ b/graphgpt/serve/gateway/README.md @@ -0,0 +1,57 @@ +# fastchat Nginx Gateway + +## Purpose of the Gateway + +The Nginx gateway serves the following purposes: + +1. Protects Gradio servers by acting as a firewall. +2. Facilitates dynamic mounting and unmounting of Gradio servers. +3. Provides load balancing for Gradio servers. +4. Offers additional security features, such as total connection limit. +5. Reduces attack surface by requiring only a single public port to be exposed for serving. + +## Deployment and Updating of the Gateway + +### Installing Nginx + +On Debian-based distributions (e.g., Ubuntu): + +```bash +sudo apt update +sudo apt install nginx +``` +On Red Hat-based distributions (e.g., CentOS, Fedora): + +```bash +sudo yum install epel-release +sudo yum install nginx +``` + +### Deployment + +Copy `nginx.conf` to `/etc/nginx/nginx.conf` (need sudo permission). + +Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server. + +Modify `upstream websocket` to configure Gradio servers behind the gateway. + +Lastly, update Nginx. + + +### HTTPS Deployment with a Public Domain URL + +Make sure you obtain the HTTPS certificate and the private key used to generate the certificate. + +Fill the addresses to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields. + +If you have your own domain url to serve the chatbot, replace the chat.lmsys.org url with your own domain url. + +### Updating + +Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service: + +```bash +sudo nginx -t # check `/etc/nginx/nginx.conf` +sudo systemctl reload nginx # restart Nginx service to load the new config +sudo systemctl status nginx # check the status of the Nginx service. It should be active (running). +``` diff --git a/graphgpt/serve/gateway/nginx.conf b/graphgpt/serve/gateway/nginx.conf new file mode 100644 index 0000000..b88ca8c --- /dev/null +++ b/graphgpt/serve/gateway/nginx.conf @@ -0,0 +1,97 @@ +user www-data; +worker_processes auto; +pid /run/nginx.pid; +include /etc/nginx/modules-enabled/*.conf; + +events { + worker_connections 1024; # maximum number of connections that a worker process can handle concurrently + # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle + +} + +http { + ## + # Basic Settings + ## + + sendfile on; # enable sendfile for performance optimization + tcp_nopush on; # enable TCP no-pushing + tcp_nodelay on; # enable TCP no-delay + keepalive_timeout 65; # sets the timeout for keep-alive connections + types_hash_max_size 2048; # maximum size of the types hash table + # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security + + # server_names_hash_bucket_size 64; + # server_name_in_redirect off; + + include /etc/nginx/mime.types; # include MIME types file + default_type application/octet-stream; # default MIME type for unknown file types + + ## + # SSL Settings + ## + + ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use + ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers + + ## + # Logging Settings + ## + + access_log /var/log/nginx/access.log; # path to access log file + error_log /var/log/nginx/error.log; # path to error log file + + ## + # Gzip Settings + ## + gzip on; # enable Gzip compression + + ## + # Virtual Host Configs + ## + + include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory + include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files + + # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/ + map $http_upgrade $connection_upgrade { + default upgrade; + '' close; + } + + upstream websocket { + ip_hash; # load balancing by IP to guarantee session persistence + server localhost:7860; # The port should be the gradio web server port + # server localhost:7861; # extra gradio server if more than one + } + + limit_conn_status 429; + limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP + limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server + + server { + listen 443 ssl; # the listening port of our server + ssl_certificate [PATH_TO_SSL_CERT]; + ssl_certificate_key [PATH_TO_PRIVATE_KEY]; + server_name chat.lmsys.org; # replace the url with your own domain url + limit_conn perserver 1024; # connections per server + location / { + proxy_pass http://websocket; # proxy all requests to the defined upstream server + limit_conn perip 5; # connections per IP + proxy_set_header Host $host; # set the Host header for the upstream server + proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header + proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication + } + } + + # the following block routes all HTTP traffic to HTTPS via nginx + server { + listen 80; + server_name chat.lmsys.org; + return 301 https://chat.lmsys.org$request_uri; + } + +} diff --git a/graphgpt/serve/gradio_block_arena_anony.py b/graphgpt/serve/gradio_block_arena_anony.py new file mode 100644 index 0000000..217d9f1 --- /dev/null +++ b/graphgpt/serve/gradio_block_arena_anony.py @@ -0,0 +1,506 @@ +""" +Chatbot Arena (battle) tab. +Users chat with two anonymous models. +""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_LEN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_patch import Chatbot as grChatbot +from fastchat.serve.gradio_web_server import ( + State, + http_bot, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + learn_more_md, +) +from fastchat.utils import ( + build_logger, + violates_moderation, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_models = 2 +enable_moderation = False +anony_names = ["", ""] +models = [] + + +def set_global_vars_anony(enable_moderation_): + global enable_moderation + enable_moderation = enable_moderation_ + + +def load_demo_side_by_side_anony(models_, url_params): + global models + models = models_ + + states = (None,) * num_models + selector_updates = ( + gr.Markdown.update(visible=True), + gr.Markdown.update(visible=True), + ) + + return ( + states + + selector_updates + + (gr.Chatbot.update(visible=True),) * num_models + + ( + gr.Textbox.update(visible=True), + gr.Box.update(visible=True), + gr.Row.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True), + ) + ) + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + if ":" not in model_selectors[0]: + for i in range(15): + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + ("",) + (disable_btn,) * 4 + time.sleep(0.2) + else: + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + ("",) + (disable_btn,) * 4 + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (anony). ip: {request.client.host}") + for x in vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ): + yield x + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (anony). ip: {request.client.host}") + for x in vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ): + yield x + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (anony). ip: {request.client.host}") + for x in vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ): + yield x + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (anony). ip: {request.client.host}") + for x in vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ): + yield x + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (anony). ip: {request.client.host}") + states = [state0, state1] + for i in range(num_models): + states[i].conv.update_last_message(None) + return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (anony). ip: {request.client.host}") + return ( + [None] * num_models + + [None] * num_models + + anony_names + + [""] + + [disable_btn] * 6 + ) + + +def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request): + logger.info(f"share (anony). ip: {request.client.host}") + if state0 is not None and state1 is not None: + vote_last_response( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + + +DEFAULT_WEIGHTS = { + "gpt-4": 1.5, + "gpt-3.5-turbo": 1.5, + "claude-v1": 1.5, + "claude-instant-v1": 1.5, + "bard": 1.5, + "vicuna-13b": 1.5, + "koala-13b": 1.5, + "vicuna-7b": 1.2, + "mpt-7b-chat": 1.2, + "oasst-pythia-12b": 1.2, + "RWKV-4-Raven-14B": 1.2, + "fastchat-t5-3b": 1, + "alpaca-13b": 1, + "chatglm-6b": 1, + "stablelm-tuned-alpha-7b": 0.5, + "dolly-v2-12b": 0.5, + "llama-13b": 0.1, +} + + +def add_text( + state0, state1, model_selector0, model_selector1, text, request: gr.Request +): + logger.info(f"add_text (anony). ip: {request.client.host}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + if states[0] is None: + assert states[1] is None + weights = [DEFAULT_WEIGHTS.get(m, 1.0) for m in models] + if len(models) > 1: + weights = weights / np.sum(weights) + model_left, model_right = np.random.choice( + models, size=(2,), p=weights, replace=False + ) + else: + model_left = model_right = models[0] + + states = [ + State(model_left), + State(model_right), + ] + + if len(text) <= 0: + for i in range(num_models): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + no_change_btn, + ] + * 6 + ) + + if enable_moderation: + flagged = violates_moderation(text) + if flagged: + logger.info( + f"violate moderation (anony). ip: {request.client.host}. text: {text}" + ) + for i in range(num_models): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [MODERATION_MSG] + + [ + no_change_btn, + ] + * 6 + ) + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_LEN_LIMIT: + logger.info( + f"hit conversation length limit. ip: {request.client.host}. text: {text}" + ) + for i in range(num_models): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [CONVERSATION_LIMIT_MSG] + + [ + no_change_btn, + ] + * 6 + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_models): + states[i].conv.append_message(states[i].conv.roles[0], text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + disable_btn, + ] + * 6 + ) + + +def http_bot_all( + state0, + state1, + temperature, + top_p, + max_new_tokens, + request: gr.Request, +): + logger.info(f"http_bot_all (anony). ip: {request.client.host}") + + if state0.skip_next: + # This generate call is skipped due to invalid inputs + yield ( + state0, + state1, + state0.to_gradio_chatbot(), + state1.to_gradio_chatbot(), + ) + (no_change_btn,) * 6 + return + + states = [state0, state1] + gen = [] + for i in range(num_models): + gen.append( + http_bot( + states[i], + temperature, + top_p, + max_new_tokens, + request, + ) + ) + + chatbots = [None] * num_models + while True: + stop = True + for i in range(num_models): + try: + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] + stop = False + except StopIteration: + pass + yield states + chatbots + [disable_btn] * 6 + if stop: + break + + for i in range(10): + if i % 2 == 0: + yield states + chatbots + [disable_btn] * 4 + [enable_btn] * 2 + else: + yield states + chatbots + [enable_btn] * 6 + time.sleep(0.2) + + +def build_side_by_side_ui_anony(models): + notice_markdown = """ +# ⚔️ Chatbot Arena ⚔️ +### Rules +- Chat with two anonymous models side-by-side and vote for which one is better! +- You can do multiple rounds of conversations before voting. +- The names of the models will be revealed after your vote. Conversations with identity keywords (e.g., ChatGPT, Bard, Vicuna) or any votes after the names are revealed will not count towards the leaderboard. +- Click "Clear history" to start a new round. +- [[Blog](https://lmsys.org/blog/2023-05-03-arena/)] [[GitHub]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/KjdtsE9V) + +### Terms of use +By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. **The service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) license.** The demo works better on desktop devices with a wide screen. + +### Battle +Please scroll down and start chatting. You can view a leaderboard of participating models in the fourth tab above labeled 'Leaderboard' or by clicking [here](?leaderboard). The models include both closed-source models (e.g., ChatGPT) and open-source models (e.g., Vicuna). +""" + + states = [gr.State() for _ in range(num_models)] + model_selectors = [None] * num_models + chatbots = [None] * num_models + + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Box(elem_id="share-region-anony"): + with gr.Row(): + for i in range(num_models): + with gr.Column(): + model_selectors[i] = gr.Markdown(anony_names[i]) + + with gr.Row(): + for i in range(num_models): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = grChatbot( + label=label, elem_id=f"chatbot", visible=False + ).style(height=550) + + with gr.Box() as button_row: + with gr.Row(): + leftvote_btn = gr.Button(value="👈 A is better", interactive=False) + rightvote_btn = gr.Button(value="👉 B is better", interactive=False) + tie_btn = gr.Button(value="🤝 Tie", interactive=False) + bothbad_btn = gr.Button(value="👎 Both are bad", interactive=False) + + with gr.Row(): + with gr.Column(scale=20): + textbox = gr.Textbox( + show_label=False, + placeholder="Enter text and press ENTER", + visible=False, + ).style(container=False) + with gr.Column(scale=1, min_width=50): + send_btn = gr.Button(value="Send", visible=False) + + with gr.Row() as button_row2: + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False, visible=True) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=1024, + value=512, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(learn_more_md) + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + http_bot_all, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ) + clear_btn.click( + clear_history, None, states + chatbots + model_selectors + [textbox] + btn_list + ) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-anony'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], _js=share_js) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + http_bot_all, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ) + send_btn.click( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + http_bot_all, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ) + + return ( + states, + model_selectors, + chatbots, + textbox, + send_btn, + button_row, + button_row2, + parameter_row, + ) diff --git a/graphgpt/serve/gradio_block_arena_named.py b/graphgpt/serve/gradio_block_arena_named.py new file mode 100644 index 0000000..fc525eb --- /dev/null +++ b/graphgpt/serve/gradio_block_arena_named.py @@ -0,0 +1,468 @@ +""" +Chatbot Arena (side-by-side) tab. +Users chat with two chosen models. +""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_LEN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_patch import Chatbot as grChatbot +from fastchat.serve.gradio_web_server import ( + State, + http_bot, + get_conv_log_filename, + get_model_description_md, + no_change_btn, + enable_btn, + disable_btn, + learn_more_md, +) +from fastchat.utils import ( + build_logger, + violates_moderation, +) + + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_models = 2 +enable_moderation = False + + +def set_global_vars_named(enable_moderation_): + global enable_moderation + enable_moderation = enable_moderation_ + + +def load_demo_side_by_side_named(models, url_params): + states = (None,) * num_models + + model_left = models[0] if len(models) > 0 else "" + if len(models) > 1: + weights = ([8, 4, 2, 1] + [1] * 32)[: len(models) - 1] + weights = weights / np.sum(weights) + model_right = np.random.choice(models[1:], p=weights) + else: + model_right = model_left + + selector_updates = ( + gr.Dropdown.update(model_left, visible=True), + gr.Dropdown.update(model_right, visible=True), + ) + + return ( + states + + selector_updates + + (gr.Chatbot.update(visible=True),) * num_models + + ( + gr.Textbox.update(visible=True), + gr.Box.update(visible=True), + gr.Row.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True), + ) + ) + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (named). ip: {request.client.host}") + vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (named). ip: {request.client.host}") + vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (named). ip: {request.client.host}") + vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (named). ip: {request.client.host}") + vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (named). ip: {request.client.host}") + states = [state0, state1] + for i in range(num_models): + states[i].conv.update_last_message(None) + return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (named). ip: {request.client.host}") + return [None] * num_models + [None] * num_models + [""] + [disable_btn] * 6 + + +def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request): + logger.info(f"share (named). ip: {request.client.host}") + if state0 is not None and state1 is not None: + vote_last_response( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + + +def add_text( + state0, state1, model_selector0, model_selector1, text, request: gr.Request +): + logger.info(f"add_text (named). ip: {request.client.host}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + for i in range(num_models): + if states[i] is None: + states[i] = State(model_selectors[i]) + + if len(text) <= 0: + for i in range(num_models): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + no_change_btn, + ] + * 6 + ) + + if enable_moderation: + flagged = violates_moderation(text) + if flagged: + logger.info( + f"violate moderation (named). ip: {request.client.host}. text: {text}" + ) + for i in range(num_models): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [MODERATION_MSG] + + [ + no_change_btn, + ] + * 6 + ) + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_LEN_LIMIT: + logger.info( + f"hit conversation length limit. ip: {request.client.host}. text: {text}" + ) + for i in range(num_models): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [CONVERSATION_LIMIT_MSG] + + [ + no_change_btn, + ] + * 6 + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_models): + states[i].conv.append_message(states[i].conv.roles[0], text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + disable_btn, + ] + * 6 + ) + + +def http_bot_all( + state0, + state1, + temperature, + top_p, + max_new_tokens, + request: gr.Request, +): + logger.info(f"http_bot_all (named). ip: {request.client.host}") + + if state0.skip_next: + # This generate call is skipped due to invalid inputs + yield ( + state0, + state1, + state0.to_gradio_chatbot(), + state1.to_gradio_chatbot(), + ) + (no_change_btn,) * 6 + return + + states = [state0, state1] + gen = [] + for i in range(num_models): + gen.append( + http_bot( + states[i], + temperature, + top_p, + max_new_tokens, + request, + ) + ) + + chatbots = [None] * num_models + while True: + stop = True + for i in range(num_models): + try: + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] + stop = False + except StopIteration: + pass + yield states + chatbots + [disable_btn] * 6 + if stop: + break + + for i in range(10): + if i % 2 == 0: + yield states + chatbots + [disable_btn] * 4 + [enable_btn] * 2 + else: + yield states + chatbots + [enable_btn] * 6 + time.sleep(0.2) + + +def build_side_by_side_ui_named(models): + notice_markdown = """ +# ⚔️ Chatbot Arena ⚔️ +### Rules +- Chat with two models side-by-side and vote for which one is better! +- You pick the models you want to chat with. +- You can do multiple rounds of conversations before voting. +- Click "Clear history" to start a new round. +- [[Blog](https://lmsys.org/blog/2023-05-03-arena/)] [[GitHub]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/KjdtsE9V) + +### Terms of use +By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. **The service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) license.** The demo works better on desktop devices with a wide screen. + +### Choose two models to chat with (view [leaderboard](?leaderboard)) +""" + + states = [gr.State() for _ in range(num_models)] + model_selectors = [None] * num_models + chatbots = [None] * num_models + + model_description_md = get_model_description_md(models) + notice = gr.Markdown( + notice_markdown + model_description_md, elem_id="notice_markdown" + ) + + with gr.Box(elem_id="share-region-named"): + with gr.Row(): + for i in range(num_models): + with gr.Column(): + model_selectors[i] = gr.Dropdown( + choices=models, + value=models[i] if len(models) > i else "", + interactive=True, + show_label=False, + ).style(container=False) + + with gr.Row(): + for i in range(num_models): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = grChatbot( + label=label, elem_id=f"chatbot", visible=False + ).style(height=550) + + with gr.Box() as button_row: + with gr.Row(): + leftvote_btn = gr.Button(value="👈 A is better", interactive=False) + rightvote_btn = gr.Button(value="👉 B is better", interactive=False) + tie_btn = gr.Button(value="🤝 Tie", interactive=False) + bothbad_btn = gr.Button(value="👎 Both are bad", interactive=False) + + with gr.Row(): + with gr.Column(scale=20): + textbox = gr.Textbox( + show_label=False, + placeholder="Enter text and press ENTER", + visible=False, + ).style(container=False) + with gr.Column(scale=1, min_width=50): + send_btn = gr.Button(value="Send", visible=False) + + with gr.Row() as button_row2: + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False, visible=True) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=1024, + value=512, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(learn_more_md) + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + http_bot_all, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ) + clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-named'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], _js=share_js) + + for i in range(num_models): + model_selectors[i].change( + clear_history, None, states + chatbots + [textbox] + btn_list + ) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + http_bot_all, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ) + send_btn.click( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + http_bot_all, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ) + + return ( + states, + model_selectors, + chatbots, + textbox, + send_btn, + button_row, + button_row2, + parameter_row, + ) diff --git a/graphgpt/serve/gradio_css.py b/graphgpt/serve/gradio_css.py new file mode 100644 index 0000000..d64a8cf --- /dev/null +++ b/graphgpt/serve/gradio_css.py @@ -0,0 +1,77 @@ +code_highlight_css = """ +#chatbot .hll { background-color: #ffffcc } +#chatbot .c { color: #408080; font-style: italic } +#chatbot .err { border: 1px solid #FF0000 } +#chatbot .k { color: #008000; font-weight: bold } +#chatbot .o { color: #666666 } +#chatbot .ch { color: #408080; font-style: italic } +#chatbot .cm { color: #408080; font-style: italic } +#chatbot .cp { color: #BC7A00 } +#chatbot .cpf { color: #408080; font-style: italic } +#chatbot .c1 { color: #408080; font-style: italic } +#chatbot .cs { color: #408080; font-style: italic } +#chatbot .gd { color: #A00000 } +#chatbot .ge { font-style: italic } +#chatbot .gr { color: #FF0000 } +#chatbot .gh { color: #000080; font-weight: bold } +#chatbot .gi { color: #00A000 } +#chatbot .go { color: #888888 } +#chatbot .gp { color: #000080; font-weight: bold } +#chatbot .gs { font-weight: bold } +#chatbot .gu { color: #800080; font-weight: bold } +#chatbot .gt { color: #0044DD } +#chatbot .kc { color: #008000; font-weight: bold } +#chatbot .kd { color: #008000; font-weight: bold } +#chatbot .kn { color: #008000; font-weight: bold } +#chatbot .kp { color: #008000 } +#chatbot .kr { color: #008000; font-weight: bold } +#chatbot .kt { color: #B00040 } +#chatbot .m { color: #666666 } +#chatbot .s { color: #BA2121 } +#chatbot .na { color: #7D9029 } +#chatbot .nb { color: #008000 } +#chatbot .nc { color: #0000FF; font-weight: bold } +#chatbot .no { color: #880000 } +#chatbot .nd { color: #AA22FF } +#chatbot .ni { color: #999999; font-weight: bold } +#chatbot .ne { color: #D2413A; font-weight: bold } +#chatbot .nf { color: #0000FF } +#chatbot .nl { color: #A0A000 } +#chatbot .nn { color: #0000FF; font-weight: bold } +#chatbot .nt { color: #008000; font-weight: bold } +#chatbot .nv { color: #19177C } +#chatbot .ow { color: #AA22FF; font-weight: bold } +#chatbot .w { color: #bbbbbb } +#chatbot .mb { color: #666666 } +#chatbot .mf { color: #666666 } +#chatbot .mh { color: #666666 } +#chatbot .mi { color: #666666 } +#chatbot .mo { color: #666666 } +#chatbot .sa { color: #BA2121 } +#chatbot .sb { color: #BA2121 } +#chatbot .sc { color: #BA2121 } +#chatbot .dl { color: #BA2121 } +#chatbot .sd { color: #BA2121; font-style: italic } +#chatbot .s2 { color: #BA2121 } +#chatbot .se { color: #BB6622; font-weight: bold } +#chatbot .sh { color: #BA2121 } +#chatbot .si { color: #BB6688; font-weight: bold } +#chatbot .sx { color: #008000 } +#chatbot .sr { color: #BB6688 } +#chatbot .s1 { color: #BA2121 } +#chatbot .ss { color: #19177C } +#chatbot .bp { color: #008000 } +#chatbot .fm { color: #0000FF } +#chatbot .vc { color: #19177C } +#chatbot .vg { color: #19177C } +#chatbot .vi { color: #19177C } +#chatbot .vm { color: #19177C } +#chatbot .il { color: #666666 } +""" +# .highlight { background: #f8f8f8; } + +table_css = """ +table { + line-height: 0em +} +""" diff --git a/graphgpt/serve/gradio_patch.py b/graphgpt/serve/gradio_patch.py new file mode 100644 index 0000000..af8731d --- /dev/null +++ b/graphgpt/serve/gradio_patch.py @@ -0,0 +1,168 @@ +""" +Adopted from https://github.com/gradio-app/gradio/blob/main/gradio/components.py +Fix a markdown render problem. +""" +from __future__ import annotations + +from gradio.components import * +from markdown2 import Markdown +import nh3 + + +class _Keywords(Enum): + NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()` + FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state) + + +@document("style") +class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): + """ + Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images. + Preprocessing: this component does *not* accept input. + Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed. + + Demos: chatbot_simple, chatbot_multimodal + """ + + def __init__( + self, + value: List[Tuple[str | None, str | None]] | Callable | None = None, + color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style() + *, + label: str | None = None, + every: float | None = None, + show_label: bool = True, + visible: bool = True, + elem_id: str | None = None, + elem_classes: List[str] | str | None = None, + **kwargs, + ): + """ + Parameters: + value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component. + label: component name in interface. + every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. + show_label: if True, will display label. + visible: If False, component will be hidden. + elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. + elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles. + """ + if color_map is not None: + warnings.warn( + "The 'color_map' parameter has been deprecated.", + ) + # self.md = utils.get_markdown_parser() + self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"]) + self.select: EventListenerMethod + """ + Event listener for when the user selects message from Chatbot. + Uses event data gradio.SelectData to carry `value` referring to text of selected message, and `index` tuple to refer to [message, participant] index. + See EventData documentation on how to use this event data. + """ + + IOComponent.__init__( + self, + label=label, + every=every, + show_label=show_label, + visible=visible, + elem_id=elem_id, + elem_classes=elem_classes, + value=value, + **kwargs, + ) + + def get_config(self): + return { + "value": self.value, + "selectable": self.selectable, + **IOComponent.get_config(self), + } + + @staticmethod + def update( + value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, + label: str | None = None, + show_label: bool | None = None, + visible: bool | None = None, + ): + updated_config = { + "label": label, + "show_label": show_label, + "visible": visible, + "value": value, + "__type__": "update", + } + return updated_config + + def _process_chat_messages( + self, chat_message: str | Tuple | List | Dict | None + ) -> str | Dict | None: + if chat_message is None: + return None + elif isinstance(chat_message, (tuple, list)): + mime_type = processing_utils.get_mimetype(chat_message[0]) + return { + "name": chat_message[0], + "mime_type": mime_type, + "alt_text": chat_message[1] if len(chat_message) > 1 else None, + "data": None, # These last two fields are filled in by the frontend + "is_file": True, + } + elif isinstance( + chat_message, dict + ): # This happens for previously processed messages + return chat_message + elif isinstance(chat_message, str): + # return self.md.render(chat_message) + return str(self.md.convert(chat_message)) + else: + raise ValueError(f"Invalid message for Chatbot component: {chat_message}") + + def postprocess( + self, + y: List[ + Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None] + ], + ) -> List[Tuple[str | Dict | None, str | Dict | None]]: + """ + Parameters: + y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed. + Returns: + List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. + """ + if y is None: + return [] + processed_messages = [] + for message_pair in y: + assert isinstance( + message_pair, (tuple, list) + ), f"Expected a list of lists or list of tuples. Received: {message_pair}" + assert ( + len(message_pair) == 2 + ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}" + processed_messages.append( + ( + # self._process_chat_messages(message_pair[0]), + '
'
+                    + nh3.clean(message_pair[0])
+                    + "
", + self._process_chat_messages(message_pair[1]), + ) + ) + return processed_messages + + def style(self, height: int | None = None, **kwargs): + """ + This method can be used to change the appearance of the Chatbot component. + """ + if height is not None: + self._style["height"] = height + if kwargs.get("color_map") is not None: + warnings.warn("The 'color_map' parameter has been deprecated.") + + Component.style( + self, + **kwargs, + ) + return self diff --git a/graphgpt/serve/gradio_web_server.py b/graphgpt/serve/gradio_web_server.py new file mode 100644 index 0000000..4e0c0d3 --- /dev/null +++ b/graphgpt/serve/gradio_web_server.py @@ -0,0 +1,714 @@ +""" +The gradio demo server for chatting with a single model. +""" + +import argparse +from collections import defaultdict +import datetime +import json +import os +import random +import time +import uuid + +import gradio as gr +import requests + +from fastchat.conversation import SeparatorStyle +from fastchat.constants import ( + LOGDIR, + WORKER_API_TIMEOUT, + ErrorCode, + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SERVER_ERROR_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_LEN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.model.model_registry import model_info +from fastchat.serve.api_provider import ( + anthropic_api_stream_iter, + bard_api_stream_iter, + openai_api_stream_iter, + palm_api_stream_iter, + init_palm_chat, +) +from fastchat.serve.gradio_patch import Chatbot as grChatbot +from fastchat.serve.gradio_css import code_highlight_css +from fastchat.utils import ( + build_logger, + violates_moderation, + get_window_url_params_js, +) + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +headers = {"User-Agent": "fastchat Client"} + +no_change_btn = gr.Button.update() +enable_btn = gr.Button.update(interactive=True) +disable_btn = gr.Button.update(interactive=False) + +controller_url = None +enable_moderation = False + +learn_more_md = """ +### License +The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. +""" + + +class State: + def __init__(self, model_name): + self.conv = get_conversation_template(model_name) + self.conv_id = uuid.uuid4().hex + self.skip_next = False + self.model_name = model_name + + if model_name == "bard": + self.bard_session_state = { + "conversation_id": "", + "response_id": "", + "choice_id": "", + "req_id": 0, + } + # According to release note, "chat-bison@001" is PaLM 2 for chat. + # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023 + self.palm_chat = init_palm_chat("chat-bison@001") + + def to_gradio_chatbot(self): + return self.conv.to_gradio_chatbot() + + def dict(self): + base = self.conv.dict() + base.update( + { + "conv_id": self.conv_id, + "model_name": self.model_name, + } + ) + return base + + +def set_global_vars(controller_url_, enable_moderation_): + global controller_url, enable_moderation + controller_url = controller_url_ + enable_moderation = enable_moderation_ + + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + + +def get_model_list(controller_url): + ret = requests.post(controller_url + "/refresh_all_workers") + assert ret.status_code == 200 + ret = requests.post(controller_url + "/list_models") + models = ret.json()["models"] + priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)} + models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"Models: {models}") + return models + + +def load_demo_refresh_model_list(url_params): + models = get_model_list(controller_url) + selected_model = models[0] if len(models) > 0 else "" + if "model" in url_params: + model = url_params["model"] + if model in models: + selected_model = model + + dropdown_update = gr.Dropdown.update( + choices=models, value=selected_model, visible=True + ) + + state = None + return ( + state, + dropdown_update, + gr.Chatbot.update(visible=True), + gr.Textbox.update(visible=True), + gr.Button.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True), + ) + + +def load_demo_reload_model(url_params, request: gr.Request): + logger.info( + f"load_demo_reload_model. ip: {request.client.host}. params: {url_params}" + ) + return load_demo_refresh_model_list(url_params) + + +def load_demo_single(models, url_params): + dropdown_update = gr.Dropdown.update(visible=True) + if "model" in url_params: + model = url_params["model"] + if model in models: + dropdown_update = gr.Dropdown.update(value=model, visible=True) + + state = None + return ( + state, + dropdown_update, + gr.Chatbot.update(visible=True), + gr.Textbox.update(visible=True), + gr.Button.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True), + ) + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + return load_demo_single(models, url_params) + + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + +def upvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"upvote. ip: {request.client.host}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def downvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"downvote. ip: {request.client.host}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def flag_last_response(state, model_selector, request: gr.Request): + logger.info(f"flag. ip: {request.client.host}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def regenerate(state, request: gr.Request): + logger.info(f"regenerate. ip: {request.client.host}") + state.conv.update_last_message(None) + return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history. ip: {request.client.host}") + state = None + return (state, [], "") + (disable_btn,) * 5 + + +def add_text(state, model_selector, text, request: gr.Request): + logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") + + if state is None: + state = State(model_selector) + + if len(text) <= 0: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 + + if enable_moderation: + flagged = violates_moderation(text) + if flagged: + logger.info(f"violate moderation. ip: {request.client.host}. text: {text}") + state.skip_next = True + return (state, state.to_gradio_chatbot(), MODERATION_MSG) + ( + no_change_btn, + ) * 5 + + conv = state.conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_LEN_LIMIT: + logger.info( + f"hit conversation length limit. ip: {request.client.host}. text: {text}" + ) + state.skip_next = True + return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + ( + no_change_btn, + ) * 5 + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + conv.append_message(conv.roles[0], text) + conv.append_message(conv.roles[1], None) + return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + + +def post_process_code(code): + sep = "\n```" + if sep in code: + blocks = code.split(sep) + if len(blocks) % 2 == 1: + for i in range(1, len(blocks), 2): + blocks[i] = blocks[i].replace("\\_", "_") + code = sep.join(blocks) + return code + + +def model_worker_stream_iter( + conv, + model_name, + worker_addr, + prompt, + temperature, + repetition_penalty, + top_p, + max_new_tokens, +): + # Make requests + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + logger.info(f"==== request ====\n{gen_params}") + + # Stream output + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + stream=True, + timeout=WORKER_API_TIMEOUT, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + yield data + + +def http_bot(state, temperature, top_p, max_new_tokens, request: gr.Request): + logger.info(f"http_bot. ip: {request.client.host}") + start_tstamp = time.time() + temperature = float(temperature) + top_p = float(top_p) + max_new_tokens = int(max_new_tokens) + + if state.skip_next: + # This generate call is skipped due to invalid inputs + state.skip_next = False + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + conv, model_name = state.conv, state.model_name + if model_name == "gpt-3.5-turbo" or model_name == "gpt-4": + prompt = conv.to_openai_api_messages() + stream_iter = openai_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_name in ["claude-v1", "claude-instant-v1"]: + prompt = conv.get_prompt() + stream_iter = anthropic_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_name == "bard": + # stream_iter = bard_api_stream_iter(state) + stream_iter = palm_api_stream_iter( + state.palm_chat, conv.messages[-2][1], temperature, top_p, max_new_tokens + ) + else: + # Query worker address + ret = requests.post( + controller_url + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + conv.update_last_message(SERVER_ERROR_MSG) + yield ( + state, + state.to_gradio_chatbot(), + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + + # Construct prompt + if "chatglm" in model_name: + prompt = list(list(x) for x in conv.messages[conv.offset :]) + else: + prompt = conv.get_prompt() + + # Construct repetition_penalty + if "t5" in model_name: + repetition_penalty = 1.2 + else: + repetition_penalty = 1.0 + stream_iter = model_worker_stream_iter( + conv, + model_name, + worker_addr, + prompt, + temperature, + repetition_penalty, + top_p, + max_new_tokens, + ) + + conv.update_last_message("▌") + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + for data in stream_iter: + if data["error_code"] == 0: + output = data["text"].strip() + if "vicuna" in model_name: + output = post_process_code(output) + conv.update_last_message(output + "▌") + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f"\n\n(error_code: {data['error_code']})" + conv.update_last_message(output) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + time.sleep(0.02) + except requests.exceptions.RequestException as e: + conv.update_last_message( + f"{SERVER_ERROR_MSG}\n\n" + f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})" + ) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + except Exception as e: + conv.update_last_message( + f"{SERVER_ERROR_MSG}\n\n" + f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})" + ) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + + # Delete "▌" + conv.update_last_message(conv.messages[-1][-1][:-1]) + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": { + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + }, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + +block_css = ( + code_highlight_css + + """ +pre { + white-space: pre-wrap; /* Since CSS 2.1 */ + white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ + white-space: -pre-wrap; /* Opera 4-6 */ + white-space: -o-pre-wrap; /* Opera 7 */ + word-wrap: break-word; /* Internet Explorer 5.5+ */ +} +#notice_markdown th { + display: none; +} +""" +) + + +def get_model_description_md(models): + model_description_md = """ +| | | | +| ---- | ---- | ---- | +""" + ct = 0 + visited = set() + for i, name in enumerate(models): + if name in model_info: + minfo = model_info[name] + if minfo.simple_name in visited: + continue + visited.add(minfo.simple_name) + one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}" + else: + visited.add(name) + one_model_md = ( + f"[{name}](): Add the description at fastchat/model/model_registry.py" + ) + + if ct % 3 == 0: + model_description_md += "|" + model_description_md += f" {one_model_md} |" + if ct % 3 == 2: + model_description_md += "\n" + ct += 1 + return model_description_md + + +def build_single_model_ui(models): + notice_markdown = """ +# 🏔️ Chat with Open Large Language Models +- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog post]](https://lmsys.org/blog/2023-03-30-vicuna/) +- Koala: A Dialogue Model for Academic Research. [[Blog post]](https://bair.berkeley.edu/blog/2023/04/03/koala/) +- [[GitHub]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/KjdtsE9V) + +### Terms of use +By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. **The service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) license.** + +### Choose a model to chat with +""" + + state = gr.State() + model_description_md = get_model_description_md(models) + gr.Markdown(notice_markdown + model_description_md, elem_id="notice_markdown") + + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + ).style(container=False) + + chatbot = grChatbot( + elem_id="chatbot", label="Scroll down and start chatting", visible=False + ).style(height=550) + with gr.Row(): + with gr.Column(scale=20): + textbox = gr.Textbox( + show_label=False, + placeholder="Enter text and press ENTER", + visible=False, + ).style(container=False) + with gr.Column(scale=1, min_width=50): + send_btn = gr.Button(value="Send", visible=False) + + with gr.Row(visible=False) as button_row: + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + + with gr.Accordion("Parameters", open=False, visible=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=1024, + value=512, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(learn_more_md) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + http_bot, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) + + model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list) + + textbox.submit( + add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list + ).then( + http_bot, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + send_btn.click( + add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list + ).then( + http_bot, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + + return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row + + +def build_demo(models): + with gr.Blocks( + title="Chat with Open Large Language Models", + theme=gr.themes.Base(), + css=block_css, + ) as demo: + url_params = gr.JSON(visible=False) + + ( + state, + model_selector, + chatbot, + textbox, + send_btn, + button_row, + parameter_row, + ) = build_single_model_ui(models) + + if args.model_list_mode == "once": + demo.load( + load_demo, + [url_params], + [ + state, + model_selector, + chatbot, + textbox, + send_btn, + button_row, + parameter_row, + ], + _js=get_window_url_params_js, + ) + elif args.model_list_mode == "reload": + demo.load( + load_demo_reload_model, + [url_params], + [ + state, + model_selector, + chatbot, + textbox, + send_btn, + button_row, + parameter_row, + ], + _js=get_window_url_params_js, + ) + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str, default="http://localhost:21001") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument( + "--model-list-mode", + type=str, + default="once", + choices=["once", "reload"], + help="Whether to load the model list once or reload the model list every time.", + ) + parser.add_argument("--share", action="store_true") + parser.add_argument( + "--moderate", action="store_true", help="Enable content moderation" + ) + parser.add_argument( + "--add-chatgpt", + action="store_true", + help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)", + ) + parser.add_argument( + "--add-claude", + action="store_true", + help="Add Anthropic's Claude models (claude-v1, claude-instant-v1)", + ) + parser.add_argument( + "--add-bard", + action="store_true", + help="Add Google's Bard model (PaLM 2 for Chat: chat-bison@001)", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + set_global_vars(args.controller_url, args.moderate) + models = get_model_list(args.controller_url) + + if args.add_chatgpt: + models = ["gpt-3.5-turbo", "gpt-4"] + models + if args.add_claude: + models = ["claude-v1", "claude-instant-v1"] + models + if args.add_bard: + models = ["bard"] + models + + demo = build_demo(models) + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, server_port=args.port, share=args.share, max_threads=200 + ) diff --git a/graphgpt/serve/gradio_web_server_graph.py b/graphgpt/serve/gradio_web_server_graph.py new file mode 100644 index 0000000..e3e1004 --- /dev/null +++ b/graphgpt/serve/gradio_web_server_graph.py @@ -0,0 +1,420 @@ +import argparse +import datetime +import json +import os +import time + +import gradio as gr +import requests + +from graphgpt.conversation import (default_conversation, conv_templates, + SeparatorStyle) +from graphgpt.constants import LOGDIR +from graphgpt.utils import (build_logger, server_error_msg, + violates_moderation, moderation_msg) +import hashlib + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +headers = {"User-Agent": "GraphGPT Client"} + +no_change_btn = gr.Button.update() +enable_btn = gr.Button.update(interactive=True) +disable_btn = gr.Button.update(interactive=False) + +priority = { + "vicuna-13b": "aaaaaaa", + "koala-13b": "aaaaaab", +} + + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + + +def get_model_list(): + ret = requests.post(args.controller_url + "/refresh_all_workers") + assert ret.status_code == 200 + ret = requests.post(args.controller_url + "/list_models") + models = ret.json()["models"] + models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"Models: {models}") + return models + + +get_window_url_params = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log(url_params); + return url_params; + } +""" + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + + dropdown_update = gr.Dropdown.update(visible=True) + if "model" in url_params: + model = url_params["model"] + if model in models: + dropdown_update = gr.Dropdown.update( + value=model, visible=True) + + state = default_conversation.copy() + return state, dropdown_update + + +def load_demo_refresh_model_list(request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}") + models = get_model_list() + state = default_conversation.copy() + dropdown_update = gr.Dropdown.update( + choices=models, + value=models[0] if len(models) > 0 else "" + ) + return state, dropdown_update + + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + +def upvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"upvote. ip: {request.client.host}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def downvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"downvote. ip: {request.client.host}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def flag_last_response(state, model_selector, request: gr.Request): + logger.info(f"flag. ip: {request.client.host}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def regenerate(state, image_process_mode, request: gr.Request): + logger.info(f"regenerate. ip: {request.client.host}") + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history. ip: {request.client.host}") + state = default_conversation.copy() + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + + +def add_text(state, text, image, image_process_mode, request: gr.Request): + logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") + if len(text) <= 0 and image is None: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 + if args.moderate: + flagged = violates_moderation(text) + if flagged: + state.skip_next = True + return (state, state.to_gradio_chatbot(), moderation_msg, None) + ( + no_change_btn,) * 5 + + text = text[:1536] # Hard cut-off + if image is not None: + text = text[:1200] # Hard cut-off for images + if '' not in text: + # text = '' + text + text = text + '\n' + text = (text, image, image_process_mode) + if len(state.get_images(return_pil=True)) > 0: + state = default_conversation.copy() + state.append_message(state.roles[0], text) + state.append_message(state.roles[1], None) + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + + +def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request): + logger.info(f"http_bot. ip: {request.client.host}") + start_tstamp = time.time() + model_name = model_selector + + if state.skip_next: + # This generate call is skipped due to invalid inputs + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + if len(state.messages) == state.offset + 2: + # First round of conversation + if "llava" in model_name.lower(): + if 'llama-2' in model_name.lower(): + template_name = "llava_llama_2" + elif "v1" in model_name.lower(): + if 'mmtag' in model_name.lower(): + template_name = "v1_mmtag" + elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower(): + template_name = "v1_mmtag" + else: + template_name = "llava_v1" + elif "mpt" in model_name.lower(): + template_name = "mpt" + else: + if 'mmtag' in model_name.lower(): + template_name = "v0_mmtag" + elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower(): + template_name = "v0_mmtag" + else: + template_name = "llava_v0" + elif "mpt" in model_name: + template_name = "mpt_text" + elif "llama-2" in model_name: + template_name = "llama_2" + else: + template_name = "vicuna_v1" + new_state = conv_templates[template_name].copy() + new_state.append_message(new_state.roles[0], state.messages[-2][1]) + new_state.append_message(new_state.roles[1], None) + state = new_state + + # Query worker address + controller_url = args.controller_url + ret = requests.post(controller_url + "/get_worker_address", + json={"model": model_name}) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + state.messages[-1][-1] = server_error_msg + yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + + # Construct prompt + prompt = state.get_prompt() + + all_images = state.get_images(return_pil=True) + all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images] + for image, hash in zip(all_images, all_image_hash): + t = datetime.datetime.now() + filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg") + if not os.path.isfile(filename): + os.makedirs(os.path.dirname(filename), exist_ok=True) + image.save(filename) + + # Make requests + pload = { + "model": model_name, + "prompt": prompt, + "temperature": float(temperature), + "top_p": float(top_p), + "max_new_tokens": min(int(max_new_tokens), 1536), + "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2, + "images": f'List of {len(state.get_images())} images: {all_image_hash}', + } + logger.info(f"==== request ====\n{pload}") + + pload['images'] = state.get_images() + + state.messages[-1][-1] = "▌" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + # Stream output + response = requests.post(worker_addr + "/worker_generate_stream", + headers=headers, json=pload, stream=True, timeout=10) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + output = data["text"][len(prompt):].strip() + state.messages[-1][-1] = output + "▌" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1][-1] = output + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + time.sleep(0.03) + except requests.exceptions.RequestException as e: + state.messages[-1][-1] = server_error_msg + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + + state.messages[-1][-1] = state.messages[-1][-1][:-1] + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "start": round(start_tstamp, 4), + "finish": round(start_tstamp, 4), + "state": state.dict(), + "images": all_image_hash, + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + +title_markdown = (""" +# 🌋 LLaVA: Large Language and Vision Assistant +[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] +""") + +tos_markdown = (""" +### Terms of use +By using this service, users are required to agree to the following terms: +The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. +Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. +For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. +""") + + +learn_more_markdown = (""" +### License +The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. +""") + +block_css = """ + +#buttons button { + min-width: min(120px,100%); +} + +""" + +def build_demo(embed_mode): + textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) + with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo: + state = gr.State() + + if not embed_mode: + gr.Markdown(title_markdown) + + with gr.Row(): + with gr.Column(scale=3): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + container=False) + + imagebox = gr.Image(type="pil") + image_process_mode = gr.Radio( + ["Crop", "Resize", "Pad", "Default"], + value="Default", + label="Preprocess for non-square image", visible=False) + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + gr.Examples(examples=[ + [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"], + [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"], + ], inputs=[imagebox, textbox]) + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",) + top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",) + max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",) + + with gr.Column(scale=8): + chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550) + with gr.Row(): + with gr.Column(scale=8): + textbox.render() + with gr.Column(scale=1, min_width=50): + submit_btn = gr.Button(value="Send", variant="primary") + with gr.Row(elem_id="buttons") as button_row: + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear", interactive=False) + + if not embed_mode: + gr.Markdown(tos_markdown) + gr.Markdown(learn_more_markdown) + url_params = gr.JSON(visible=False) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click(upvote_last_response, + [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn]) + downvote_btn.click(downvote_last_response, + [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn]) + flag_btn.click(flag_last_response, + [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn]) + regenerate_btn.click(regenerate, [state, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list).then( + http_bot, [state, model_selector, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list) + clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list) + + textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list + ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list) + submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list + ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list) + + if args.model_list_mode == "once": + demo.load(load_demo, [url_params], [state, model_selector], + _js=get_window_url_params) + elif args.model_list_mode == "reload": + demo.load(load_demo_refresh_model_list, None, [state, model_selector]) + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str, default="http://localhost:21001") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--model-list-mode", type=str, default="once", + choices=["once", "reload"]) + parser.add_argument("--share", action="store_true") + parser.add_argument("--moderate", action="store_true") + parser.add_argument("--embed", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + models = get_model_list() + + logger.info(args) + demo = build_demo(args.embed) + demo.queue( + concurrency_count=args.concurrency_count, + api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share + ) \ No newline at end of file diff --git a/graphgpt/serve/gradio_web_server_multi.py b/graphgpt/serve/gradio_web_server_multi.py new file mode 100644 index 0000000..3be9243 --- /dev/null +++ b/graphgpt/serve/gradio_web_server_multi.py @@ -0,0 +1,219 @@ +""" +The gradio demo server with multiple tabs. +It supports chatting with a single model or chatting with two models side-by-side. +""" + +import argparse +import pickle + +import gradio as gr + +from fastchat.serve.gradio_block_arena_anony import ( + build_side_by_side_ui_anony, + load_demo_side_by_side_anony, + set_global_vars_anony, +) +from fastchat.serve.gradio_block_arena_named import ( + build_side_by_side_ui_named, + load_demo_side_by_side_named, + set_global_vars_named, +) +from fastchat.serve.gradio_patch import Chatbot as grChatbot +from fastchat.serve.gradio_web_server import ( + set_global_vars, + block_css, + build_single_model_ui, + get_model_list, + load_demo_single, +) +from fastchat.serve.monitor.monitor import build_leaderboard_tab +from fastchat.utils import build_logger, get_window_url_params_js + + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + selected = 0 + if "arena" in url_params: + selected = 1 + elif "compare" in url_params: + selected = 2 + elif "leaderboard" in url_params: + selected = 3 + single_updates = load_demo_single(models, url_params) + + models_anony = models + if args.anony_only_for_proprietary_model: + # Only enable these models in anony battles. + if args.add_chatgpt: + models_anony = ["gpt-4", "gpt-3.5-turbo"] + models_anony + if args.add_claude: + models_anony = ["claude-v1", "claude-instant-v1"] + models_anony + if args.add_bard: + models_anony = ["bard"] + models_anony + + side_by_side_anony_updates = load_demo_side_by_side_anony(models_anony, url_params) + side_by_side_named_updates = load_demo_side_by_side_named(models, url_params) + return ( + (gr.Tabs.update(selected=selected),) + + single_updates + + side_by_side_anony_updates + + side_by_side_named_updates + ) + + +def build_demo(models, elo_results_file): + with gr.Blocks( + title="Chat with Open Large Language Models", + theme=gr.themes.Base(), + css=block_css, + ) as demo: + with gr.Tabs() as tabs: + with gr.Tab("Single Model", id=0): + ( + a_state, + a_model_selector, + a_chatbot, + a_textbox, + a_send_btn, + a_button_row, + a_parameter_row, + ) = build_single_model_ui(models) + a_list = [ + a_state, + a_model_selector, + a_chatbot, + a_textbox, + a_send_btn, + a_button_row, + a_parameter_row, + ] + + with gr.Tab("Chatbot Arena (battle)", id=1): + ( + b_states, + b_model_selectors, + b_chatbots, + b_textbox, + b_send_btn, + b_button_row, + b_button_row2, + b_parameter_row, + ) = build_side_by_side_ui_anony(models) + b_list = ( + b_states + + b_model_selectors + + b_chatbots + + [ + b_textbox, + b_send_btn, + b_button_row, + b_button_row2, + b_parameter_row, + ] + ) + + with gr.Tab("Chatbot Arena (side-by-side)", id=2): + ( + c_states, + c_model_selectors, + c_chatbots, + c_textbox, + c_send_btn, + c_button_row, + c_button_row2, + c_parameter_row, + ) = build_side_by_side_ui_named(models) + c_list = ( + c_states + + c_model_selectors + + c_chatbots + + [ + c_textbox, + c_send_btn, + c_button_row, + c_button_row2, + c_parameter_row, + ] + ) + + if elo_results_file: + with gr.Tab("Leaderboard", id=3): + build_leaderboard_tab(elo_results_file) + + url_params = gr.JSON(visible=False) + + if args.model_list_mode == "once": + demo.load( + load_demo, + [url_params], + [tabs] + a_list + b_list + c_list, + _js=get_window_url_params_js, + ) + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str, default="http://localhost:21001") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument( + "--model-list-mode", + type=str, + default="once", + choices=["once"], + ) + parser.add_argument("--share", action="store_true") + parser.add_argument( + "--moderate", action="store_true", help="Enable content moderation" + ) + parser.add_argument( + "--add-chatgpt", + action="store_true", + help="Add OpenAI ChatGPT models (gpt-3.5-turbo, gpt-4)", + ) + parser.add_argument( + "--add-claude", + action="store_true", + help="Add Anthropic's Claude models (claude-v1, claude-instant-v1)", + ) + parser.add_argument( + "--add-bard", + action="store_true", + help="Add Google's Bard model (PaLM 2 for Chat: chat-bison@001)", + ) + parser.add_argument( + "--anony-only-for-proprietary-model", + action="store_true", + help="Only add ChatGPT, Claude, Bard under anony battle tab", + ) + parser.add_argument("--elo-results-file", type=str) + args = parser.parse_args() + logger.info(f"args: {args}") + + set_global_vars(args.controller_url, args.moderate) + set_global_vars_named(args.moderate) + set_global_vars_anony(args.moderate) + models = get_model_list(args.controller_url) + + if not args.anony_only_for_proprietary_model: + if args.add_chatgpt: + models = ["gpt-3.5-turbo", "gpt-4"] + models + if args.add_claude: + models = ["claude-v1", "claude-instant-v1"] + models + if args.add_bard: + models = ["bard"] + models + + demo = build_demo(models, args.elo_results_file) + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, server_port=args.port, share=args.share, max_threads=200 + ) diff --git a/graphgpt/serve/huggingface_api.py b/graphgpt/serve/huggingface_api.py new file mode 100644 index 0000000..9223f2e --- /dev/null +++ b/graphgpt/serve/huggingface_api.py @@ -0,0 +1,69 @@ +""" +Use FastChat with Hugging Face generation APIs. + +Usage: +python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0 +python3 -m fastchat.serve.huggingface_api --model ~/model_weights/vicuna-7b/ +""" +import argparse +import json + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM + +from fastchat.model import load_model, get_conversation_template, add_model_args + + +@torch.inference_mode() +def main(args): + model, tokenizer = load_model( + args.model_path, + args.device, + args.num_gpus, + args.max_gpu_memory, + args.load_8bit, + args.cpu_offloading, + debug=args.debug, + ) + + msg = args.message + + conv = get_conversation_template(args.model_path) + conv.append_message(conv.roles[0], msg) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids = tokenizer([prompt]).input_ids + + if "t5" in args.model_path and args.repetition_penalty == 1.0: + args.repetition_penalty = 1.2 + output_ids = model.generate( + torch.as_tensor(input_ids).cuda(), + do_sample=True, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + max_new_tokens=args.max_new_tokens, + ) + if model.config.is_encoder_decoder: + output_ids = output_ids[0] + else: + output_ids = output_ids[0][len(input_ids[0]) :] + outputs = tokenizer.decode( + output_ids, skip_special_tokens=True, spaces_between_special_tokens=False + ) + + print(f"{conv.roles[0]}: {msg}") + print(f"{conv.roles[1]}: {outputs}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_model_args(parser) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--repetition_penalty", type=float, default=1.0) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--debug", action="store_true") + parser.add_argument("--message", type=str, default="Hello! Who are you?") + args = parser.parse_args() + + main(args) diff --git a/graphgpt/serve/inference.py b/graphgpt/serve/inference.py new file mode 100644 index 0000000..489b379 --- /dev/null +++ b/graphgpt/serve/inference.py @@ -0,0 +1,314 @@ +"""Inference for FastChat models.""" +import abc +import gc +import math +from typing import Iterable, Optional +import sys +import warnings + +import psutil +import torch +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + LlamaForCausalLM, + AutoModel, + AutoModelForSeq2SeqLM, + T5Tokenizer, + AutoConfig, +) +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + +from fastchat.conversation import get_conv_template, SeparatorStyle +from fastchat.model.model_adapter import load_model, get_conversation_template +from fastchat.model.chatglm_model import chatglm_generate_stream + + +def prepare_logits_processor( + temperature: float, repetition_penalty: float, top_p: float, top_k: int +) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +def partial_stop(output, stop_str): + for i in range(0, min(len(output), len(stop_str))): + if stop_str.startswith(output[-i:]): + return True + return False + + +@torch.inference_mode() +def generate_stream( + model, tokenizer, params, device, context_len=2048, stream_interval=2 +): + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", -1)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 256)) + stop_str = params.get("stop", None) + echo = bool(params.get("echo", True)) + stop_token_ids = params.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer.eos_token_id) + + logits_processor = prepare_logits_processor( + temperature, repetition_penalty, top_p, top_k + ) + + input_ids = tokenizer(prompt).input_ids + input_echo_len = len(input_ids) + output_ids = list(input_ids) + + if model.config.is_encoder_decoder: + max_src_len = context_len + else: + max_src_len = context_len - max_new_tokens - 8 + + input_ids = input_ids[-max_src_len:] + + if model.config.is_encoder_decoder: + encoder_output = model.encoder( + input_ids=torch.as_tensor([input_ids], device=device) + )[0] + start_ids = torch.as_tensor( + [[model.generation_config.decoder_start_token_id]], + dtype=torch.int64, + device=device, + ) + + past_key_values = out = None + for i in range(max_new_tokens): + if i == 0: + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=start_ids, + encoder_hidden_states=encoder_output, + use_cache=True, + ) + logits = model.lm_head(out[0]) + else: + out = model(torch.as_tensor([input_ids], device=device), use_cache=True) + logits = out.logits + past_key_values = out.past_key_values + else: + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=torch.as_tensor([[token]], device=device), + encoder_hidden_states=encoder_output, + use_cache=True, + past_key_values=past_key_values, + ) + + logits = model.lm_head(out[0]) + else: + out = model( + input_ids=torch.as_tensor([[token]], device=device), + use_cache=True, + past_key_values=past_key_values, + ) + logits = out.logits + past_key_values = out.past_key_values + + if logits_processor: + if repetition_penalty > 1.0: + tmp_output_ids = torch.as_tensor([output_ids], device=logits.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] + else: + last_token_logits = logits[0, -1, :] + + if device == "mps": + # Switch to CPU by avoiding some bugs in mps backend. + last_token_logits = last_token_logits.float().to("cpu") + + if temperature < 1e-5 or top_p < 1e-8: # greedy + token = int(torch.argmax(last_token_logits)) + else: + probs = torch.softmax(last_token_logits, dim=-1) + token = int(torch.multinomial(probs, num_samples=1)) + + output_ids.append(token) + + if token in stop_token_ids: + stopped = True + else: + stopped = False + + if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len_prompt + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=True, + spaces_between_special_tokens=False, + ) + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + break + else: + partially_stopped = partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + + if stopped: + break + + # finish stream event, which contains finish reason + if i == max_new_tokens - 1: + finish_reason = "length" + elif stopped: + finish_reason = "stop" + else: + finish_reason = None + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # clean + del past_key_values, out + gc.collect() + torch.cuda.empty_cache() + + +class ChatIO(abc.ABC): + @abc.abstractmethod + def prompt_for_input(self, role: str) -> str: + """Prompt for input from a role.""" + + @abc.abstractmethod + def prompt_for_output(self, role: str): + """Prompt for output from a role.""" + + @abc.abstractmethod + def stream_output(self, output_stream): + """Stream output.""" + + +def chat_loop( + model_path: str, + device: str, + num_gpus: int, + max_gpu_memory: str, + load_8bit: bool, + cpu_offloading: bool, + conv_template: Optional[str], + temperature: float, + repetition_penalty: float, + max_new_tokens: int, + chatio: ChatIO, + debug: bool, +): + # Model + model, tokenizer = load_model( + model_path, device, num_gpus, max_gpu_memory, load_8bit, cpu_offloading, debug + ) + is_chatglm = "chatglm" in str(type(model)).lower() + is_fastchat_t5 = "t5" in str(type(model)).lower() + + # Hardcode T5 repetition penalty to be 1.2 + if is_fastchat_t5 and repetition_penalty == 1.0: + repetition_penalty = 1.2 + + # Chat + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + + while True: + try: + inp = chatio.prompt_for_input(conv.roles[0]) + except EOFError: + inp = "" + if not inp: + print("exit...") + break + + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + + if is_chatglm: + generate_stream_func = chatglm_generate_stream + prompt = conv.messages[conv.offset :] + else: + generate_stream_func = generate_stream + prompt = conv.get_prompt() + + gen_params = { + "model": model_path, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + + chatio.prompt_for_output(conv.roles[1]) + output_stream = generate_stream_func(model, tokenizer, gen_params, device) + outputs = chatio.stream_output(output_stream) + conv.update_last_message(outputs.strip()) + + if debug: + print("\n", {"prompt": prompt, "outputs": outputs}, "\n") diff --git a/graphgpt/serve/model_worker.py b/graphgpt/serve/model_worker.py new file mode 100644 index 0000000..8f9c2fa --- /dev/null +++ b/graphgpt/serve/model_worker.py @@ -0,0 +1,427 @@ +""" +A model worker executes the model. +""" +import argparse +import asyncio +import dataclasses +import logging +import json +import os +import time +from typing import List, Union +import threading +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +try: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + AutoModel, + ) +except ImportError: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LLaMATokenizer, + AutoModel, + ) +import torch +import torch.nn.functional as F +import uvicorn + +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG +from fastchat.model.model_adapter import load_model, add_model_args +from fastchat.model.chatglm_model import chatglm_generate_stream +from fastchat.serve.inference import generate_stream +from fastchat.utils import build_logger, pretty_print_semaphore + +GB = 1 << 30 + +worker_id = str(uuid.uuid4())[:6] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") +global_counter = 0 + +model_semaphore = None + + +def heart_beat_worker(controller): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + controller.send_heart_beat() + + +class ModelWorker: + def __init__( + self, + controller_addr, + worker_addr, + worker_id, + no_register, + model_path, + model_name, + device, + num_gpus, + max_gpu_memory, + load_8bit=False, + cpu_offloading=False, + ): + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + self.model_name = model_name or model_path.split("/")[-1] + self.device = device + + logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") + self.model, self.tokenizer = load_model( + model_path, device, num_gpus, max_gpu_memory, load_8bit, cpu_offloading + ) + if self.tokenizer.pad_token == None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + if hasattr(self.model.config, "max_sequence_length"): + self.context_len = self.model.config.max_sequence_length + elif hasattr(self.model.config, "max_position_embeddings"): + self.context_len = self.model.config.max_position_embeddings + else: + self.context_len = 2048 + + # generate_stream + is_chatglm = "chatglm" in str(type(self.model)).lower() + if is_chatglm: + self.generate_stream_func = chatglm_generate_stream + else: + self.generate_stream_func = generate_stream + + if not no_register: + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, args=(self,) + ) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status(), + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info( + f"Send heart beat. Models: {[self.model_name]}. " + f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " + f"global_counter: {global_counter}" + ) + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post( + url, + json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length(), + }, + timeout=5, + ) + exist = ret.json()["exist"] + break + except requests.exceptions.RequestException as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if ( + model_semaphore is None + or model_semaphore._value is None + or model_semaphore._waiters is None + ): + return 0 + else: + return ( + args.limit_model_concurrency + - model_semaphore._value + + len(model_semaphore._waiters) + ) + + def get_status(self): + return { + "model_names": [self.model_name], + "speed": 1, + "queue_length": self.get_queue_length(), + } + + def count_token(self, params): + prompt = params["prompt"] + input_ids = self.tokenizer(prompt).input_ids + input_echo_len = len(input_ids) + + ret = { + "count": input_echo_len, + "error_code": 0, + } + return ret + + def generate_stream_gate(self, params): + try: + for output in self.generate_stream_func( + self.model, + self.tokenizer, + params, + self.device, + self.context_len, + args.stream_interval, + ): + ret = { + "text": output["text"], + "error_code": 0, + } + if "usage" in output: + ret["usage"] = output["usage"] + if "finish_reason" in output: + ret["finish_reason"] = output["finish_reason"] + if "logprobs" in output: + ret["logprobs"] = output["logprobs"] + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + def generate_gate(self, params): + try: + ret = {"text": "", "error_code": 0} + for output in self.generate_stream_func( + self.model, + self.tokenizer, + params, + self.device, + self.context_len, + args.stream_interval, + ): + ret["text"] = output["text"] + if "usage" in output: + ret["usage"] = output["usage"] + if "finish_reason" in output: + ret["finish_reason"] = output["finish_reason"] + if "logprobs" in output: + ret["logprobs"] = output["logprobs"] + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return ret + + @torch.inference_mode() + def get_embeddings(self, params): + try: + tokenizer = self.tokenizer + is_llama = "llama" in str(type(self.model)) # vicuna support batch inference + is_chatglm = "chatglm" in str(type(self.model)) + is_t5 = "t5" in str(type(self.model)) + if is_llama: + encoding = tokenizer.batch_encode_plus( + params["input"], padding=True, return_tensors="pt" + ) + input_ids = encoding["input_ids"].to(self.device) + attention_mask = encoding["attention_mask"].to(self.device) + model_output = self.model( + input_ids, attention_mask, output_hidden_states=True + ) + data = model_output.hidden_states[-1] + mask = attention_mask.unsqueeze(-1).expand(data.size()).float() + masked_embeddings = data * mask + sum_embeddings = torch.sum(masked_embeddings, dim=1) + seq_length = torch.sum(mask, dim=1) + embedding = sum_embeddings / seq_length + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + ret = { + "embedding": normalized_embeddings.tolist(), + "token_num": torch.sum(attention_mask).item(), + } + else: + embedding = [] + token_num = 0 + for text in params["input"]: + input_ids = tokenizer.encode(text, return_tensors="pt").to( + self.device + ) + if is_t5: + model_output = self.model(input_ids, decoder_input_ids=input_ids) + else: + model_output = self.model(input_ids, output_hidden_states=True) + if is_chatglm: + data = (model_output.hidden_states[-1].transpose(0, 1))[0] + elif is_t5: + data = model_output.encoder_last_hidden_state[0] + else: + data = model_output.hidden_states[-1][0] + data = F.normalize(torch.mean(data, dim=0), p=2, dim=0) + embedding.append(data.tolist()) + token_num += len(input_ids[0]) + ret = { + "embedding": embedding, + "token_num": token_num, + } + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return ret + + +app = FastAPI() + + +def release_model_semaphore(): + model_semaphore.release() + + +def acquire_model_semaphore(): + global model_semaphore, global_counter + global_counter += 1 + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) + return model_semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_model_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_model_semaphore() + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_model_semaphore() + output = worker.generate_gate(params) + release_model_semaphore() + return JSONResponse(output) + + +@app.post("/worker_generate_completion_stream") +async def api_generate_completion_stream(request: Request): + params = await request.json() + await acquire_model_semaphore() + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate_completion") +async def api_generate_completion(request: Request): + params = await request.json() + await acquire_model_semaphore() + completion = worker.generate_gate(params) + background_tasks = create_background_tasks() + return JSONResponse(content=completion, background=background_tasks) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_model_semaphore() + embedding = worker.get_embeddings(params) + background_tasks = create_background_tasks() + return JSONResponse(content=embedding, background=background_tasks) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/model_details") +async def model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + parser.add_argument("--model-name", type=str, help="Optional display name") + parser.add_argument("--limit-model-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + worker = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + args.no_register, + args.model_path, + args.model_name, + args.device, + args.num_gpus, + args.max_gpu_memory, + args.load_8bit, + args.cpu_offloading, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/graphgpt/serve/model_worker_graph.py b/graphgpt/serve/model_worker_graph.py new file mode 100644 index 0000000..5e4b237 --- /dev/null +++ b/graphgpt/serve/model_worker_graph.py @@ -0,0 +1,285 @@ +""" +A model worker executes the model. +""" +import argparse +import asyncio +import json +import time +import threading +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse +import requests +import torch +import uvicorn +from functools import partial + +from graphgpt.constants import WORKER_HEART_BEAT_INTERVAL +from graphgpt.utils import (build_logger, server_error_msg, + pretty_print_semaphore) +from graphgpt.model.builder import load_pretrained_model +from graphgpt.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria +from graphgpt.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from transformers import TextIteratorStreamer +from threading import Thread + + +GB = 1 << 30 + +worker_id = str(uuid.uuid4())[:6] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") +global_counter = 0 + +model_semaphore = None + + +def heart_beat_worker(controller): + + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + controller.send_heart_beat() + + +class ModelWorker: + def __init__(self, controller_addr, worker_addr, + worker_id, no_register, + model_path, model_base, model_name, + load_8bit, load_4bit, device): + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + if model_name is None: + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + self.model_name = model_paths[-2] + "_" + model_paths[-1] + else: + self.model_name = model_paths[-1] + else: + self.model_name = model_name + + self.device = device + logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") + self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( + model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device) + self.is_multimodal = 'llava' in self.model_name.lower() + + if not no_register: + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, args=(self,)) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status() + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info(f"Send heart beat. Models: {[self.model_name]}. " + f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " + f"global_counter: {global_counter}") + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post(url, json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length()}, timeout=5) + exist = ret.json()["exist"] + break + except requests.exceptions.RequestException as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if model_semaphore is None: + return 0 + else: + return args.limit_model_concurrency - model_semaphore._value + (len( + model_semaphore._waiters) if model_semaphore._waiters is not None else 0) + + def get_status(self): + return { + "model_names": [self.model_name], + "speed": 1, + "queue_length": self.get_queue_length(), + } + + @torch.inference_mode() + def generate_stream(self, params): + tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor + + prompt = params["prompt"] + ori_prompt = prompt + images = params.get("images", None) + num_image_tokens = 0 + if images is not None and len(images) > 0 and self.is_multimodal: + if len(images) > 0: + if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): + raise ValueError("Number of images does not match number of tokens in prompt") + + images = [load_image_from_base64(image) for image in images] + images = process_images(images, image_processor, model.config) + + if type(images) is list: + images = [image.to(self.model.device, dtype=torch.float16) for image in images] + else: + images = images.to(self.model.device, dtype=torch.float16) + + replace_token = DEFAULT_IMAGE_TOKEN + if getattr(self.model.config, 'mm_use_im_start_end', False): + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) + + num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches + else: + images = None + image_args = {"images": images} + else: + images = None + image_args = {} + + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_context_length = getattr(model.config, 'max_position_embeddings', 2048) + max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) + stop_str = params.get("stop", None) + do_sample = True if temperature > 0.001 else False + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) + + max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) + + if max_new_tokens < 1: + yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" + return + + thread = Thread(target=model.generate, kwargs=dict( + inputs=input_ids, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + streamer=streamer, + stopping_criteria=[stopping_criteria], + use_cache=True, + **image_args + )) + thread.start() + + generated_text = ori_prompt + for new_text in streamer: + generated_text += new_text + if generated_text.endswith(stop_str): + generated_text = generated_text[:-len(stop_str)] + yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" + + def generate_stream_gate(self, params): + try: + for x in self.generate_stream(params): + yield x + except ValueError as e: + print("Caught ValueError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.CudaError as e: + print("Caught torch.cuda.CudaError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + print("Caught Unknown Error", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + + +app = FastAPI() + + +def release_model_semaphore(fn=None): + model_semaphore.release() + if fn is not None: + fn() + + +@app.post("/worker_generate_stream") +async def generate_stream(request: Request): + global model_semaphore, global_counter + global_counter += 1 + params = await request.json() + + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) + await model_semaphore.acquire() + worker.send_heart_beat() + generator = worker.generate_stream_gate(params) + background_tasks = BackgroundTasks() + background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_get_status") +async def get_status(request: Request): + return worker.get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, + default="http://localhost:21002") + parser.add_argument("--controller-address", type=str, + default="http://localhost:21001") + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--model-name", type=str) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") + parser.add_argument("--limit-model-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=1) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.multi_modal: + logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") + + worker = ModelWorker(args.controller_address, + args.worker_address, + worker_id, + args.no_register, + args.model_path, + args.model_base, + args.model_name, + args.load_8bit, + args.load_4bit, + args.device) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") \ No newline at end of file diff --git a/graphgpt/serve/monitor/basic_stats.py b/graphgpt/serve/monitor/basic_stats.py new file mode 100644 index 0000000..c910e12 --- /dev/null +++ b/graphgpt/serve/monitor/basic_stats.py @@ -0,0 +1,198 @@ +import argparse +import code +import datetime +import json +import os +from pytz import timezone +import time + +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from tqdm import tqdm + + +def get_log_files(max_num_files=None): + dates = [] + for month in [4, 5]: + for day in range(1, 32): + dates.append(f"2023-{month:02d}-{day:02d}") + + num_servers = 12 + filenames = [] + for d in dates: + for i in range(num_servers): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames + + +def load_log_files(log_files): + data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + + data.append( + dict( + type=row["type"], + tstamp=row["tstamp"], + model=row.get("model", ""), + models=row.get("models", ["", ""]), + ) + ) + + return data + + +def get_anony_vote_df(df): + anony_vote_df = df[ + df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"]) + ] + anony_vote_df = anony_vote_df[ + anony_vote_df["models"].apply(lambda x: x[0] == "") + ] + return anony_vote_df + + +def merge_counts(series, on, names): + ret = pd.merge(series[0], series[1], on=on) + for i in range(2, len(series)): + ret = pd.merge(ret, series[i], on=on) + ret = ret.reset_index() + old_names = list(ret.columns)[-len(series) :] + rename = {old_name: new_name for old_name, new_name in zip(old_names, names)} + ret = ret.rename(columns=rename) + return ret + + +def report_basic_stats(log_files): + df_all = load_log_files(log_files) + df_all = pd.DataFrame(df_all) + now_t = df_all["tstamp"].max() + df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)] + df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)] + anony_vote_df_all = get_anony_vote_df(df_all) + + # Chat trends + chat_dates = [ + datetime.datetime.fromtimestamp( + x, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d") + for x in df_all[df_all["type"] == "chat"]["tstamp"] + ] + chat_dates_counts = pd.value_counts(chat_dates) + vote_dates = [ + datetime.datetime.fromtimestamp( + x, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d") + for x in anony_vote_df_all["tstamp"] + ] + vote_dates_counts = pd.value_counts(vote_dates) + chat_dates_bar = go.Figure(data=[ + go.Bar(name="Anony. Vote", x=vote_dates_counts.index, y=vote_dates_counts, + text=[f"{val:.0f}" for val in vote_dates_counts], textposition="auto"), + go.Bar(name="Chat", x=chat_dates_counts.index, y=chat_dates_counts, + text=[f"{val:.0f}" for val in chat_dates_counts], textposition="auto"), + ]) + chat_dates_bar.update_layout( + barmode="stack", + xaxis_title="Dates", + yaxis_title="Count", + height=300, + width=1200, + ) + + # Model call counts + model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts() + model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts() + model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts() + model_hist = merge_counts( + [model_hist_all, model_hist_1_day, model_hist_1_hour], + on="model", + names=["All", "Last Day", "Last Hour"], + ) + model_hist_md = model_hist.to_markdown(index=False, tablefmt="github") + + # Action counts + action_hist_all = df_all["type"].value_counts() + action_hist_1_day = df_1_day["type"].value_counts() + action_hist_1_hour = df_1_hour["type"].value_counts() + action_hist = merge_counts( + [action_hist_all, action_hist_1_day, action_hist_1_hour], + on="type", + names=["All", "Last Day", "Last Hour"], + ) + action_hist_md = action_hist.to_markdown(index=False, tablefmt="github") + + # Anony vote counts + anony_vote_hist_all = anony_vote_df_all["type"].value_counts() + anony_vote_df_1_day = get_anony_vote_df(df_1_day) + anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts() + anony_vote_df_1_hour = get_anony_vote_df(df_1_hour) + anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts() + anony_vote_hist = merge_counts( + [anony_vote_hist_all, anony_vote_hist_1_day, anony_vote_hist_1_hour], + on="type", + names=["All", "Last Day", "Last Hour"], + ) + anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github") + + # Last 24 hours + chat_1_day = df_1_day[df_1_day["type"] == "chat"] + num_chats_last_24_hours = [] + base = df_1_day["tstamp"].min() + for i in range(24, 0, -1): + left = base + (i - 1) * 3600 + right = base + i * 3600 + num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum() + num_chats_last_24_hours.append(num) + times = [ + datetime.datetime.fromtimestamp( + base + i * 3600, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + for i in range(24, 0, -1) + ] + last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours}) + last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github") + + # Last update datetime + last_updated_tstamp = now_t + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + # code.interact(local=locals()) + + return { + "chat_dates_bar": chat_dates_bar, + "model_hist_md": model_hist_md, + "action_hist_md": action_hist_md, + "anony_vote_hist_md": anony_vote_hist_md, + "num_chats_last_24_hours": last_24_hours_md, + "last_updated_datetime": last_updated_datetime, + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + basic_stats = report_basic_stats(log_files) + + print(basic_stats["action_hist_md"] + "\n") + print(basic_stats["model_hist_md"] + "\n") + print(basic_stats["anony_vote_hist_md"] + "\n") + print(basic_stats["num_chats_last_24_hours"] + "\n") diff --git a/graphgpt/serve/monitor/clean_battle_data.py b/graphgpt/serve/monitor/clean_battle_data.py new file mode 100644 index 0000000..73c1a48 --- /dev/null +++ b/graphgpt/serve/monitor/clean_battle_data.py @@ -0,0 +1,195 @@ +import argparse +import datetime +import json +from pytz import timezone +import os +import time + +from tqdm import tqdm + +from fastchat.serve.monitor.basic_stats import get_log_files +from fastchat.utils import detect_language + + +VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"] +IDENTITY_WORDS = [ + "vicuna", + "lmsys", + "koala", + "uc berkeley", + "open assistant", + "laion", + "chatglm", + "chatgpt", + "openai", + "anthropic", + "claude", + "bard", + "palm", + "Lamda", + "google", + "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**", +] + + +def get_log_files(max_num_files=None): + dates = [] + for month in [4]: + for day in range(24, 32): + dates.append(f"2023-{month:02d}-{day:02d}") + for month in [5]: + for day in range(1, 24): + dates.append(f"2023-{month:02d}-{day:02d}") + cutoff_date = dates[-1].replace("-", "") + + num_servers = 12 + filenames = [] + for d in dates: + for i in range(num_servers): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames, cutoff_date + + +def remove_html(raw): + if raw.startswith("

"): + return raw[raw.find(": ") + 2 : -len("

\n")] + return raw + + +def clean_battle_data(log_files): + data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + if row["type"] in VOTES: + data.append(row) + + convert_type = { + "leftvote": "model_a", + "rightvote": "model_b", + "tievote": "tie", + "bothbad_vote": "tie (bothbad)", + } + + all_models = set() + ct_annoy = 0 + ct_invalid = 0 + ct_leaked_identity = 0 + battles = [] + for row in data: + # Resolve model names + models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])] + if "model_name" in row["states"][0]: + models_hidden = [ + row["states"][0]["model_name"], + row["states"][1]["model_name"], + ] + if models_hidden[0] is None: + models_hidden = models_public + else: + models_hidden = models_public + + if (models_public[0] == "" and models_public[1] != "") or ( + models_public[1] == "" and models_public[0] != "" + ): + ct_invalid += 1 + continue + + if models_public[0] == "" or models_public[0] == "Model A": + anony = True + models = models_hidden + ct_annoy += 1 + else: + anony = False + models = models_public + if not models_public == models_hidden: + ct_invalid += 1 + continue + + # Detect langauge + state = row["states"][0] + if state["offset"] >= len(state["messages"]): + ct_invalid += 1 + continue + lang_code = detect_language(state["messages"][state["offset"]][1]) + rounds = (len(state["messages"]) - state["offset"]) // 2 + + # Drop conversations if the model names are leaked + leaked_identity = False + messages = "" + for i in range(2): + state = row["states"][i] + for role, msg in state["messages"][state["offset"] :]: + if msg: + messages += msg.lower() + for word in IDENTITY_WORDS: + if word in messages: + leaked_identity = True + break + + if leaked_identity: + ct_leaked_identity += 1 + continue + + # Replace bard with palm + models = [m.replace("bard", "palm-2") for m in models] + + # Keep the result + battles.append( + dict( + model_a=models[0], + model_b=models[1], + win=convert_type[row["type"]], + anony=anony, + rounds=rounds, + language=lang_code, + tstamp=row["tstamp"], + ) + ) + + all_models.update(models_hidden) + battles.sort(key=lambda x: x["tstamp"]) + last_updated_tstamp = battles[-1]["tstamp"] + + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + print( + f"#votes: {len(data)}, #invalid votes: {ct_invalid}, " + f"#leaked_identity: {ct_leaked_identity}" + ) + print(f"#battles: {len(battles)}, #annoy: {ct_annoy}") + print(f"#models: {len(all_models)}, {all_models}") + print(f"last-updated: {last_updated_datetime}") + + return battles + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files, cutoff_date = get_log_files(args.max_num_files) + battles = clean_battle_data(log_files) + + print("Samples:") + for i in range(4): + print(battles[i]) + + output = f"clean_battle_{cutoff_date}.json" + with open(output, "w") as fout: + json.dump(battles, fout, indent=2) + print(f"Write cleaned data to {output}") diff --git a/graphgpt/serve/monitor/elo_analysis.py b/graphgpt/serve/monitor/elo_analysis.py new file mode 100644 index 0000000..ea0307c --- /dev/null +++ b/graphgpt/serve/monitor/elo_analysis.py @@ -0,0 +1,283 @@ +import argparse +from collections import defaultdict +import datetime +import json +import math +import pickle +from pytz import timezone + +import gdown +import numpy as np +import pandas as pd +import plotly.express as px +from tqdm import tqdm + +from fastchat.model.model_registry import get_model_info +from fastchat.serve.monitor.basic_stats import get_log_files +from fastchat.serve.monitor.clean_battle_data import clean_battle_data + + +pd.options.display.float_format = "{:.2f}".format + + +def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000): + rating = defaultdict(lambda: INIT_RATING) + + for rd, model_a, model_b, win in battles[ + ["model_a", "model_b", "win"] + ].itertuples(): + ra = rating[model_a] + rb = rating[model_b] + ea = 1 / (1 + BASE ** ((rb - ra) / SCALE)) + eb = 1 / (1 + BASE ** ((ra - rb) / SCALE)) + if win == "model_a": + sa = 1 + elif win == "model_b": + sa = 0 + elif win == "tie" or win == "tie (bothbad)": + sa = 0.5 + else: + raise Exception(f"unexpected vote {win}") + rating[model_a] += K * (sa - ea) + rating[model_b] += K * (1 - sa - eb) + + return dict(rating) + + +def get_bootstrap_result(battles, func_compute_elo, num_round=1000): + rows = [] + for i in tqdm(range(num_round), desc="bootstrap"): + tmp_battles = battles.sample(frac=1.0, replace=True) + # tmp_battles = tmp_battles.sort_values(ascending=True, by=["tstamp"]) + rows.append(func_compute_elo(tmp_battles)) + df = pd.DataFrame(rows) + return df[df.median().sort_values(ascending=False).index] + + +def get_elo_from_bootstrap(bootstrap_df): + return dict(bootstrap_df.quantile(0.5)) + + +def compute_pairwise_win_fraction(battles, model_order): + # Times each model wins as Model A + a_win_ptbl = pd.pivot_table( + battles[battles["win"] == "model_a"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + + # Table counting times each model wins as Model B + b_win_ptbl = pd.pivot_table( + battles[battles["win"] == "model_b"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + + # Table counting number of A-B pairs + num_battles_ptbl = pd.pivot_table( + battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 + ) + + # Computing the proportion of wins for each model as A and as B + # against all other models + row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / ( + num_battles_ptbl + num_battles_ptbl.T + ) + + if model_order is None: + prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False) + model_order = list(prop_wins.keys()) + + # Arrange ordering according to proprition of wins + row_beats_col = row_beats_col_freq.loc[model_order, model_order] + return row_beats_col + + +def visualize_leaderboard_table(rating): + models = list(rating.keys()) + models.sort(key=lambda k: -rating[k]) + + emoji_dict = { + 1: "🥇", + 2: "🥈", + 3: "🥉", + } + + md = "" + md += "| Rank | Model | Elo Rating | Description |\n" + md += "| --- | --- | --- | --- |\n" + for i, model in enumerate(models): + rank = i + 1 + minfo = get_model_info(model) + emoji = emoji_dict.get(rank, "") + md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n" + + return md + + +def visualize_pairwise_win_fraction(battles, model_order): + row_beats_col = compute_pairwise_win_fraction(battles, model_order) + fig = px.imshow( + row_beats_col, + color_continuous_scale="RdBu", + text_auto=".2f", + height=600, + width=600, + ) + fig.update_layout( + xaxis_title="Model B", + yaxis_title="Model A", + xaxis_side="top", + title_y=0.07, + title_x=0.5, + ) + fig.update_traces( + hovertemplate="Model A: %{y}
Model B: %{x}
Fraction of A Wins: %{z}" + ) + + return fig + + +def visualize_battle_count(battles, model_order): + ptbl = pd.pivot_table( + battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 + ) + battle_counts = ptbl + ptbl.T + fig = px.imshow( + battle_counts.loc[model_order, model_order], + text_auto=True, + height=600, + width=600, + ) + fig.update_layout( + xaxis_title="Model B", + yaxis_title="Model A", + xaxis_side="top", + title_y=0.07, + title_x=0.5, + ) + fig.update_traces( + hovertemplate="Model A: %{y}
Model B: %{x}
Count: %{z}" + ) + return fig + + +def visualize_average_win_rate(battles): + row_beats_col_freq = compute_pairwise_win_fraction(battles, None) + fig = px.bar( + row_beats_col_freq.mean(axis=1).sort_values(ascending=False), + text_auto=".2f", + height=400, + width=600, + ) + fig.update_layout( + yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False + ) + return fig + + +def visualize_bootstrap_elo_rating(df): + bars = ( + pd.DataFrame( + dict( + lower=df.quantile(0.025), + rating=df.quantile(0.5), + upper=df.quantile(0.975), + ) + ) + .reset_index(names="model") + .sort_values("rating", ascending=False) + ) + bars["error_y"] = bars["upper"] - bars["rating"] + bars["error_y_minus"] = bars["rating"] - bars["lower"] + bars["rating_rounded"] = np.round(bars["rating"], 2) + fig = px.scatter( + bars, + x="model", + y="rating", + error_y="error_y", + error_y_minus="error_y_minus", + text="rating_rounded", + height=400, + width=600, + ) + fig.update_layout(xaxis_title="Model", yaxis_title="Rating") + return fig + + +def report_elo_analysis_results(battles_json): + battles = pd.DataFrame(battles_json) + battles = battles.sort_values(ascending=True, by=["tstamp"]) + # Only use anonymous votes + battles = battles[battles["anony"]].reset_index(drop=True) + battles_no_ties = battles[~battles["win"].str.contains("tie")] + + # Online update + elo_rating_online = compute_elo(battles) + + # Bootstrap + bootstrap_df = get_bootstrap_result(battles, compute_elo) + elo_rating_median = get_elo_from_bootstrap(bootstrap_df) + elo_rating_median = {k: int(v + 0.5) for k, v in elo_rating_median.items()} + model_order = list(elo_rating_online.keys()) + model_order.sort(key=lambda k: -elo_rating_online[k]) + + # Plots + leaderboard_table = visualize_leaderboard_table(elo_rating_online) + win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order) + battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order) + average_win_rate_bar = visualize_average_win_rate(battles_no_ties) + bootstrap_elo_rating = visualize_bootstrap_elo_rating(bootstrap_df) + + last_updated_tstamp = battles["tstamp"].max() + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + return { + "elo_rating_online": elo_rating_online, + "elo_rating_median": elo_rating_median, + "leaderboard_table": leaderboard_table, + "win_fraction_heatmap": win_fraction_heatmap, + "battle_count_heatmap": battle_count_heatmap, + "average_win_rate_bar": average_win_rate_bar, + "bootstrap_elo_rating": bootstrap_elo_rating, + "last_updated_datetime": last_updated_datetime, + } + + +def pretty_print_elo_rating(rating): + model_order = list(rating.keys()) + model_order.sort(key=lambda k: -rating[k]) + for i, model in enumerate(model_order): + print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--clean-battle-file", type=str) + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + if args.clean_battle_file: + # Read data from a cleaned battle files + battles = pd.read_json(args.clean_battle_file) + else: + # Read data from all log files + log_files = get_log_files(args.max_num_files) + battles = clean_battle_data(log_files) + + results = report_elo_analysis_results(battles) + + print("# Online") + pretty_print_elo_rating(results["elo_rating_online"]) + print("# Median") + pretty_print_elo_rating(results["elo_rating_median"]) + print(f"last update : {results['last_updated_datetime']}") + + with open("elo_results.pkl", "wb") as fout: + pickle.dump(results, fout) diff --git a/graphgpt/serve/monitor/hf_space_leaderboard_app.py b/graphgpt/serve/monitor/hf_space_leaderboard_app.py new file mode 100644 index 0000000..44ae4a8 --- /dev/null +++ b/graphgpt/serve/monitor/hf_space_leaderboard_app.py @@ -0,0 +1,86 @@ +"""A gradio app that renders a static leaderboard. This is used for Hugging Face Space.""" +import argparse +import pickle + +import gradio as gr + + +notebook_url = "https://colab.research.google.com/drive/17L9uCiAivzWfzOxo2Tb9RMauT7vS6nVU?usp=sharing" + + +def make_leaderboard_md(elo_results): + leaderboard_md = f""" +# Leaderboard +[[Blog](https://lmsys.org/blog/2023-05-03-arena/)] [[Vote](https://arena.lmsys.org/)] [[Github]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/KjdtsE9V) + +We use the Elo rating system to calculate the relative performance of the models. You can view the voting data, basic analyses, and calculation procedure in this [notebook]({notebook_url}). We will periodically release new leaderboards. If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model). +Last updated: {elo_results["last_updated_datetime"]} +{elo_results["leaderboard_table"]} +""" + return leaderboard_md + + +def build_leaderboard_tab(elo_results_file): + if elo_results_file is not None: + with open(elo_results_file, "rb") as fin: + elo_results = pickle.load(fin) + + md = make_leaderboard_md(elo_results) + p1 = elo_results["win_fraction_heatmap"] + p2 = elo_results["battle_count_heatmap"] + p3 = elo_results["average_win_rate_bar"] + p4 = elo_results["bootstrap_elo_rating"] + else: + md = "Loading ..." + p1 = p2 = p3 = p4 = None + + md_1 = gr.Markdown(md) + gr.Markdown( + f"""## More Statistics\n +We added some additional figures to show more statistics. The code for generating them is also included in this [notebook]({notebook_url}). +Please note that you may see different orders from different ranking methods. This is expected for models that perform similarly, as demonstrated by the confidence interval in the bootstrap figure. Going forward, we prefer the classical Elo calculation because of its scalability and interpretability. You can find more discussions in this blog [post](https://lmsys.org/blog/2023-05-03-arena/). +""" + ) + + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles" + ) + plot_1 = gr.Plot(p1, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 2: Battle Count for Each Combination of Models (without Ties)" + ) + plot_2 = gr.Plot(p2, show_label=False) + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 3: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)" + ) + plot_3 = gr.Plot(p3, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 4: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)" + ) + plot_4 = gr.Plot(p4, show_label=False) + return [md_1, plot_1, plot_2, plot_3, plot_4] + + +def build_demo(elo_results_file): + with gr.Blocks( + title="Chatbot Arena Leaderboard", + theme=gr.themes.Base(), + ) as demo: + leader_components = build_leaderboard_tab(elo_results_file) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--share", action="store_true") + args = parser.parse_args() + + demo = build_demo("elo_results_20230508.pkl") + demo.launch(share=args.share) diff --git a/graphgpt/serve/monitor/monitor.py b/graphgpt/serve/monitor/monitor.py new file mode 100644 index 0000000..ac9d26e --- /dev/null +++ b/graphgpt/serve/monitor/monitor.py @@ -0,0 +1,209 @@ +# sudo apt install pkg-config libicu-dev +# pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate + +import argparse +import pickle +import os +import threading +import time + +import gradio as gr + +from fastchat.serve.monitor.basic_stats import report_basic_stats, get_log_files +from fastchat.serve.monitor.clean_battle_data import clean_battle_data +from fastchat.serve.monitor.elo_analysis import report_elo_analysis_results +from fastchat.utils import build_logger, get_window_url_params_js + + +notebook_url = "https://colab.research.google.com/drive/17L9uCiAivzWfzOxo2Tb9RMauT7vS6nVU?usp=sharing" + + +logger = build_logger("monitor", "monitor.log") + + +basic_component_values = [None] * 6 +leader_component_values = [None] * 5 + + +def make_leaderboard_md(elo_results): + leaderboard_md = f""" +# Leaderboard +[[Blog](https://lmsys.org/blog/2023-05-03-arena/)] [[GitHub]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/KjdtsE9V) + +We use the Elo rating system to calculate the relative performance of the models. You can view the voting data, basic analyses, and calculation procedure in this [notebook]({notebook_url}). We will periodically release new leaderboards. If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model). +Last updated: {elo_results["last_updated_datetime"]} +{elo_results["leaderboard_table"]} +""" + return leaderboard_md + + +def update_elo_components(max_num_files, elo_results_file): + log_files = get_log_files(max_num_files) + + # Leaderboard + if elo_results_file is None: + battles = clean_battle_data(log_files) + elo_results = report_elo_analysis_results(battles) + + leader_component_values[0] = make_leaderboard_md(elo_results) + leader_component_values[1] = elo_results["win_fraction_heatmap"] + leader_component_values[2] = elo_results["battle_count_heatmap"] + leader_component_values[3] = elo_results["average_win_rate_bar"] + leader_component_values[4] = elo_results["bootstrap_elo_rating"] + + # Basic stats + basic_stats = report_basic_stats(log_files) + md0 = f"Last updated: {basic_stats['last_updated_datetime']}" + + md1 = "### Action Histogram\n" + md1 += basic_stats["action_hist_md"] + "\n" + + md2 = "### Anony. Vote Histogram\n" + md2 += basic_stats["anony_vote_hist_md"] + "\n" + + md3 = "### Model Call Histogram\n" + md3 += basic_stats["model_hist_md"] + "\n" + + md4 = "### Model Call (Last 24 Hours)\n" + md4 += basic_stats["num_chats_last_24_hours"] + "\n" + + basic_component_values[0] = md0 + basic_component_values[1] = basic_stats["chat_dates_bar"] + basic_component_values[2] = md1 + basic_component_values[3] = md2 + basic_component_values[4] = md3 + basic_component_values[5] = md4 + + +def update_worker(max_num_files, interval, elo_results_file): + while True: + tic = time.time() + update_elo_components(max_num_files, elo_results_file) + durtaion = time.time() - tic + print(f"update duration: {durtaion:.2f} s") + time.sleep(max(interval - durtaion, 0)) + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + return basic_component_values + leader_component_values + + +def build_basic_stats_tab(): + empty = "Loading ..." + basic_component_values[:] = [empty, None, empty, empty, empty, empty] + + md0 = gr.Markdown(empty) + gr.Markdown( + "#### Figure 1: Number of model calls and votes" + ) + plot_1 = gr.Plot(show_label=False) + with gr.Row(): + with gr.Column(): + md1 = gr.Markdown(empty) + with gr.Column(): + md2 = gr.Markdown(empty) + with gr.Row(): + with gr.Column(): + md3 = gr.Markdown(empty) + with gr.Column(): + md4 = gr.Markdown(empty) + return [md0, plot_1, md1, md2, md3, md4] + + +def build_leaderboard_tab(elo_results_file): + if elo_results_file is not None: + with open(elo_results_file, "rb") as fin: + elo_results = pickle.load(fin) + + md = make_leaderboard_md(elo_results) + p1 = elo_results["win_fraction_heatmap"] + p2 = elo_results["battle_count_heatmap"] + p3 = elo_results["average_win_rate_bar"] + p4 = elo_results["bootstrap_elo_rating"] + else: + md = "Loading ..." + p1 = p2 = p3 = p4 = None + + leader_component_values[:] = [md, p1, p2, p3, p4] + + md_1 = gr.Markdown(md) + gr.Markdown( + f"""## More Statistics\n +We added some additional figures to show more statistics. The code for generating them is also included in this [notebook]({notebook_url}). +Please note that you may see different orders from different ranking methods. This is expected for models that perform similarly, as demonstrated by the confidence interval in the bootstrap figure. Going forward, we prefer the classical Elo calculation because of its scalability and interpretability. You can find more discussions in this blog [post](https://lmsys.org/blog/2023-05-03-arena/). +""" + ) + + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles" + ) + plot_1 = gr.Plot(p1, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 2: Battle Count for Each Combination of Models (without Ties)" + ) + plot_2 = gr.Plot(p2, show_label=False) + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 3: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)" + ) + plot_3 = gr.Plot(p3, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 4: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)" + ) + plot_4 = gr.Plot(p4, show_label=False) + return [md_1, plot_1, plot_2, plot_3, plot_4] + + +def build_demo(elo_results_file): + with gr.Blocks( + title="Monitor", + theme=gr.themes.Base(), + ) as demo: + with gr.Tabs() as tabs: + with gr.Tab("Leaderboard", id=0): + leader_components = build_leaderboard_tab(elo_results_file) + + with gr.Tab("Basic Stats", id=1): + basic_components = build_basic_stats_tab() + + url_params = gr.JSON(visible=False) + demo.load( + load_demo, + [url_params], + basic_components + leader_components, + _js=get_window_url_params_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--share", action="store_true") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--update-interval", type=int, default=300) + parser.add_argument("--max-num-files", type=int) + parser.add_argument("--elo-results-file", type=str) + args = parser.parse_args() + logger.info(f"args: {args}") + + update_thread = threading.Thread( + target=update_worker, + args=(args.max_num_files, args.update_interval, args.elo_results_file), + ) + update_thread.start() + + demo = build_demo(args.elo_results_file) + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, server_port=args.port, share=args.share, max_threads=200 + ) diff --git a/graphgpt/serve/openai_api_server.py b/graphgpt/serve/openai_api_server.py new file mode 100644 index 0000000..6e3099a --- /dev/null +++ b/graphgpt/serve/openai_api_server.py @@ -0,0 +1,722 @@ +"""A server that provides OpenAI-compatible RESTful APIs. It supports: + +- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat) +- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions) +- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings) + +Usage: +python3 -m fastchat.serve.openai_api_server +""" +import asyncio + +import argparse +import asyncio +import json +import logging + +import os +from typing import Generator, Optional, Union, Dict, List, Any + +import fastapi +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse +import httpx +from pydantic import BaseSettings +import shortuuid +import tiktoken +import uvicorn + +from fastchat.constants import WORKER_API_TIMEOUT, WORKER_API_EMBEDDING_BATCH_SIZE, ErrorCode +from fastchat.model.model_adapter import get_conversation_template +from fastapi.exceptions import RequestValidationError +from fastchat.protocol.openai_api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + DeltaMessage, + CompletionResponseStreamChoice, + CompletionStreamResponse, + EmbeddingsRequest, + EmbeddingsResponse, + ErrorResponse, + ModelCard, + ModelList, + ModelPermission, + TokenCheckRequest, + TokenCheckResponse, + UsageInfo, +) + +logger = logging.getLogger(__name__) + + +class AppSettings(BaseSettings): + # The address of the model controller. + controller_address: str = "http://localhost:21001" + + +app_settings = AppSettings() + +app = fastapi.FastAPI() +headers = {"User-Agent": "FastChat API Server"} + + +def create_error_response(code: int, message: str) -> JSONResponse: + return JSONResponse( + ErrorResponse(message=message, code=code).dict(), status_code=400 + ) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc)) + + +async def check_model(request) -> Optional[JSONResponse]: + controller_address = app_settings.controller_address + ret = None + async with httpx.AsyncClient() as client: + try: + _worker_addr = await _get_worker_address(request.model, client) + except: + models_ret = await client.post(controller_address + "/list_models") + models = models_ret.json()["models"] + ret = create_error_response( + ErrorCode.INVALID_MODEL, + f"Only {'&&'.join(models)} allowed now, your model {request.model}", + ) + return ret + + +async def check_length(request, prompt, max_tokens): + async with httpx.AsyncClient() as client: + worker_addr = await _get_worker_address(request.model, client) + + response = await client.post( + worker_addr + "/model_details", + headers=headers, + json={}, + timeout=WORKER_API_TIMEOUT, + ) + context_len = response.json()["context_length"] + + response = await client.post( + worker_addr + "/count_token", + headers=headers, + json={"prompt": prompt}, + timeout=WORKER_API_TIMEOUT, + ) + token_num = response.json()["count"] + + if token_num + max_tokens > context_len: + return create_error_response( + ErrorCode.CONTEXT_OVERFLOW, + f"This model's maximum context length is {context_len} tokens. " + f"However, you requested {max_tokens + token_num} tokens " + f"({token_num} in the messages, " + f"{max_tokens} in the completion). " + f"Please reduce the length of the messages or completion.", + ) + else: + return None + + +def check_requests(request) -> Optional[JSONResponse]: + # Check all params + if request.max_tokens is not None and request.max_tokens <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'", + ) + if request.n is not None and request.n <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.n} is less than the minimum of 1 - 'n'", + ) + if request.temperature is not None and request.temperature < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is less than the minimum of 0 - 'temperature'", + ) + if request.temperature is not None and request.temperature > 2: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is greater than the maximum of 2 - 'temperature'", + ) + if request.top_p is not None and request.top_p < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is less than the minimum of 0 - 'top_p'", + ) + if request.top_p is not None and request.top_p > 1: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is greater than the maximum of 1 - 'temperature'", + ) + if request.stop is not None and ( + not isinstance(request.stop, str) and not isinstance(request.stop, list) + ): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.stop} is not valid under any of the given schemas - 'stop'", + ) + + return None + + +def process_input(model_name, input): + if isinstance(input, str): + input = [input] + elif isinstance(input, list): + if isinstance(input[0], int): + decoding = tiktoken.model.encoding_for_model(model_name) + input = [decoding.decode(input)] + elif isinstance(input[0], list): + decoding = tiktoken.model.encoding_for_model(model_name) + input = [decoding.decode(text) for text in input] + + return input + + +def get_gen_params( + model_name: str, + messages: Union[str, List[Dict[str, str]]], + *, + temperature: float, + top_p: float, + max_tokens: Optional[int], + echo: Optional[bool], + stream: Optional[bool], + stop: Optional[Union[str, List[str]]], +) -> Dict[str, Any]: + conv = get_conversation_template(model_name) + + if isinstance(messages, str): + prompt = messages + else: + for message in messages: + msg_role = message["role"] + if msg_role == "system": + conv.system = message["content"] + elif msg_role == "user": + conv.append_message(conv.roles[0], message["content"]) + elif msg_role == "assistant": + conv.append_message(conv.roles[1], message["content"]) + else: + raise ValueError(f"Unknown role: {msg_role}") + + # Add a blank message for the assistant. + conv.append_message(conv.roles[1], None) + + is_chatglm = "chatglm" in model_name.lower() + if is_chatglm: + prompt = conv.messages[conv.offset :] + else: + prompt = conv.get_prompt() + + if max_tokens is None: + max_tokens = 512 + + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_tokens, + "echo": echo, + "stream": stream, + } + + if stop is None: + gen_params.update( + {"stop": conv.stop_str, "stop_token_ids": conv.stop_token_ids} + ) + else: + gen_params.update({"stop": stop}) + + logger.debug(f"==== request ====\n{gen_params}") + return gen_params + + +async def _get_worker_address(model_name: str, client: httpx.AsyncClient) -> str: + """ + Get worker address based on the requested model + + :param model_name: The worker's model name + :param client: The httpx client to use + :return: Worker address from the controller + :raises: :class:`ValueError`: No available worker for requested model + """ + controller_address = app_settings.controller_address + + ret = await client.post( + controller_address + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + # No available worker + if worker_addr == "": + raise ValueError(f"No available worker for {model_name}") + + logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}") + return worker_addr + + +@app.get("/v1/models") +async def show_available_models(): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + ret = await client.post(controller_address + "/refresh_all_workers") + ret = await client.post(controller_address + "/list_models") + models = ret.json()["models"] + models.sort() + # TODO: return real model permission details + model_cards = [] + for m in models: + model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()])) + return ModelList(data=model_cards) + + +# TODO: Have check_length and count_tokens share code. +@app.post("/v1/token_check") +async def count_tokens(request: TokenCheckRequest): + """ + Checks the token count against your message + This is not part of the OpenAI API spec. + """ + async with httpx.AsyncClient() as client: + worker_addr = await _get_worker_address(request.model, client) + + response = await client.post( + worker_addr + "/model_details", + headers=headers, + json={}, + timeout=WORKER_API_TIMEOUT, + ) + context_len = response.json()["context_length"] + + response = await client.post( + worker_addr + "/count_token", + headers=headers, + json={"prompt": request.prompt}, + timeout=WORKER_API_TIMEOUT, + ) + token_num = response.json()["count"] + + can_fit = True + if token_num + request.max_tokens > context_len: + can_fit = False + + return TokenCheckResponse(fits=can_fit, contextLength=context_len, tokenCount=token_num) + + +@app.post("/v1/chat/completions") +async def create_chat_completion(request: ChatCompletionRequest): + """Creates a completion for the chat message""" + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + gen_params = get_gen_params( + request.model, + request.messages, + temperature=request.temperature, + top_p=request.top_p, + max_tokens=request.max_tokens, + echo=False, + stream=request.stream, + stop=request.stop, + ) + error_check_ret = await check_length( + request, gen_params["prompt"], gen_params["max_new_tokens"] + ) + if error_check_ret is not None: + return error_check_ret + + if request.stream: + generator = chat_completion_stream_generator( + request.model, gen_params, request.n + ) + return StreamingResponse(generator, media_type="text/event-stream") + + choices = [] + # TODO: batch the requests. maybe not necessary if using CacheFlow worker + chat_completions = [] + for i in range(request.n): + content = asyncio.create_task(chat_completion(request.model, gen_params)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + +async def chat_completion_stream_generator( + model_name: str, gen_params: Dict[str, Any], n: int +) -> Generator[str, Any, None]: + """ + Event stream format: + https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format + """ + id = f"chatcmpl-{shortuuid.random()}" + finish_stream_events = [] + for i in range(n): + # First chunk with role + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + + previous_text = "" + async for content in chat_completion_stream(model_name, gen_params): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = decoded_unicode + + if len(delta_text) == 0: + delta_text = None + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=content.get("finish_reason", None), + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + if delta_text is None: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + +async def chat_completion_stream(model_name: str, gen_params: Dict[str, Any]): + controller_url = app_settings.controller_address + async with httpx.AsyncClient() as client: + worker_addr = await _get_worker_address(model_name, client) + delimiter = b"\0" + async with client.stream( + "POST", + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + timeout=WORKER_API_TIMEOUT, + ) as response: + # content = await response.aread() + async for raw_chunk in response.aiter_raw(): + for chunk in raw_chunk.split(delimiter): + if not chunk: + continue + data = json.loads(chunk.decode()) + yield data + + +async def chat_completion( + model_name: str, gen_params: Dict[str, Any] +) -> Optional[Dict[str, Any]]: + async with httpx.AsyncClient() as client: + worker_addr = await _get_worker_address(model_name, client) + + output = None + delimiter = b"\0" + + async with client.stream( + "POST", + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + timeout=WORKER_API_TIMEOUT, + ) as response: + content = await response.aread() + + for chunk in content.split(delimiter): + if not chunk: + continue + data = json.loads(chunk.decode()) + output = data + + return output + + +@app.post("/v1/completions") +async def create_completion(request: CompletionRequest): + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + request.prompt = process_input(request.model, request.prompt) + + for text in request.prompt: + error_check_ret = await check_length(request, text, request.max_tokens) + if error_check_ret is not None: + return error_check_ret + + if request.stream: + generator = generate_completion_stream_generator(request, request.n) + return StreamingResponse(generator, media_type="text/event-stream") + else: + text_completions = [] + for text in request.prompt: + payload = get_gen_params( + request.model, + text, + temperature=request.temperature, + top_p=request.top_p, + max_tokens=request.max_tokens, + echo=request.echo, + stream=request.stream, + stop=request.stop, + ) + for i in range(request.n): + content = asyncio.create_task(generate_completion(payload)) + text_completions.append(content) + + try: + all_tasks = await asyncio.gather(*text_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + + choices = [] + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + CompletionResponseChoice( + index=i, + text=content["text"], + logprobs=content.get("logprobs", None), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return CompletionResponse( + model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage) + ) + + +async def generate_completion_stream_generator(request: CompletionRequest, n: int): + model_name = request.model + id = f"cmpl-{shortuuid.random()}" + finish_stream_events = [] + for text in request.prompt: + for i in range(n): + previous_text = "" + payload = get_gen_params( + request.model, + text, + temperature=request.temperature, + top_p=request.top_p, + max_tokens=request.max_tokens, + echo=request.echo, + stream=request.stream, + stop=request.stop, + ) + async for content in generate_completion_stream(payload): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = decoded_unicode + # todo: index is not apparent + choice_data = CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=content.get("logprobs", None), + finish_reason=content.get("finish_reason", None), + ) + chunk = CompletionStreamResponse( + id=id, + object="text_completion", + choices=[choice_data], + model=model_name, + ) + if len(delta_text) == 0: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + +async def generate_completion_stream(payload: Dict[str, Any]): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + worker_addr = await _get_worker_address(payload["model"], client) + + delimiter = b"\0" + async with client.stream( + "POST", + worker_addr + "/worker_generate_completion_stream", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, + ) as response: + # content = await response.aread() + async for raw_chunk in response.aiter_raw(): + for chunk in raw_chunk.split(delimiter): + if not chunk: + continue + data = json.loads(chunk.decode()) + yield data + + +async def generate_completion(payload: Dict[str, Any]): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + worker_addr = await _get_worker_address(payload["model"], client) + + response = await client.post( + worker_addr + "/worker_generate_completion", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, + ) + completion = response.json() + return completion + + +@app.post("/v1/embeddings") +@app.post("/v1/engines/{model_name}/embeddings") +async def create_embeddings(request: EmbeddingsRequest, model_name: str = None): + """Creates embeddings for the text""" + if request.model is None: + request.model = model_name + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + request.input = process_input(request.model, request.input) + + data = [] + token_num = 0 + batch_size = WORKER_API_EMBEDDING_BATCH_SIZE + batches = [ + request.input[i : min(i + batch_size, len(request.input))] + for i in range(0, len(request.input), batch_size) + ] + for num_batch, batch in enumerate(batches): + payload = { + "model": request.model, + "input": batch, + } + embedding = await get_embedding(payload) + data += [ + { + "object": "embedding", + "embedding": emb, + "index": num_batch * batch_size + i, + } + for i, emb in enumerate(embedding["embedding"]) + ] + token_num += embedding["token_num"] + return EmbeddingsResponse( + data=data, + model=request.model, + usage=UsageInfo( + prompt_tokens=token_num, + total_tokens=token_num, + completion_tokens=None, + ), + ).dict(exclude_none=True) + + +async def get_embedding(payload: Dict[str, Any]): + controller_address = app_settings.controller_address + model_name = payload["model"] + async with httpx.AsyncClient() as client: + worker_addr = await _get_worker_address(model_name, client) + + response = await client.post( + worker_addr + "/worker_get_embeddings", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, + ) + embedding = response.json() + return embedding + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="FastChat ChatGPT-Compatible RESTful API server." + ) + parser.add_argument("--host", type=str, default="localhost", help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" + ) + parser.add_argument( + "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" + ) + parser.add_argument( + "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" + ) + parser.add_argument( + "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" + ) + args = parser.parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + app_settings.controller_address = args.controller_address + + logger.info(f"args: {args}") + + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/graphgpt/serve/register_worker.py b/graphgpt/serve/register_worker.py new file mode 100644 index 0000000..2c2c402 --- /dev/null +++ b/graphgpt/serve/register_worker.py @@ -0,0 +1,26 @@ +""" +Manually register workers. + +Usage: +python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--controller-address", type=str) + parser.add_argument("--worker-name", type=str) + parser.add_argument("--check-heart-beat", action="store_true") + args = parser.parse_args() + + url = args.controller_address + "/register_worker" + data = { + "worker_name": args.worker_name, + "check_heart_beat": args.check_heart_beat, + "worker_status": None, + } + r = requests.post(url, json=data) + assert r.status_code == 200 diff --git a/graphgpt/serve/test_message.py b/graphgpt/serve/test_message.py new file mode 100644 index 0000000..203a449 --- /dev/null +++ b/graphgpt/serve/test_message.py @@ -0,0 +1,81 @@ +"""Send a test message.""" +import argparse +import json + +import requests + +from fastchat.model.model_adapter import get_conversation_template + + +def main(): + model_name = args.model_name + + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + print(f"No available workers for {model_name}") + return + + conv = get_conversation_template(model_name) + conv.append_message(conv.roles[0], args.message) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + headers = {"User-Agent": "FastChat Client"} + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": args.temperature, + "max_new_tokens": args.max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + stream=True, + ) + + print(f"{conv.roles[0]}: {args.message}") + print(f"{conv.roles[1]}: ", end="") + prev = 0 + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) + print("") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, required=True) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--max-new-tokens", type=int, default=32) + parser.add_argument( + "--message", type=str, default="Tell me a story with more than 1000 words." + ) + args = parser.parse_args() + + main() diff --git a/graphgpt/serve/test_throughput.py b/graphgpt/serve/test_throughput.py new file mode 100644 index 0000000..9cc5f45 --- /dev/null +++ b/graphgpt/serve/test_throughput.py @@ -0,0 +1,115 @@ +"""Benchmarking script to test the throughput of serving workers.""" +import argparse +import json + +import requests +import threading +import time + +from fastchat.conversation import default_conversation + + +def main(): + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": args.model_name} + ) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + return + + conv = default_conversation.copy() + conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words") + prompt_template = conv.get_prompt() + prompts = [prompt_template for _ in range(args.n_thread)] + + headers = {"User-Agent": "fastchat Client"} + ploads = [ + { + "model": args.model_name, + "prompt": prompts[i], + "max_new_tokens": args.max_new_tokens, + "temperature": 0.0, + # "stop": conv.sep, + } + for i in range(len(prompts)) + ] + + def send_request(results, i): + if args.test_dispatch: + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": args.model_name} + ) + thread_worker_addr = ret.json()["address"] + else: + thread_worker_addr = worker_addr + print(f"thread {i} goes to {thread_worker_addr}") + response = requests.post( + thread_worker_addr + "/worker_generate_stream", + headers=headers, + json=ploads[i], + stream=False, + ) + k = list( + response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0") + ) + # print(k) + response_new_words = json.loads(k[-2].decode("utf-8"))["text"] + error_code = json.loads(k[-2].decode("utf-8"))["error_code"] + # print(f"=== Thread {i} ===, words: {1}, error code: {error_code}") + results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" ")) + + # use N threads to prompt the backend + tik = time.time() + threads = [] + results = [None] * args.n_thread + for i in range(args.n_thread): + t = threading.Thread(target=send_request, args=(results, i)) + t.start() + # time.sleep(0.5) + threads.append(t) + + for t in threads: + t.join() + + print(f"Time (POST): {time.time() - tik} s") + # n_words = 0 + # for i, response in enumerate(results): + # # print(prompt[i].replace(conv.sep, "\n"), end="") + # # make sure the streaming finishes at EOS or stopping criteria + # k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")) + # response_new_words = json.loads(k[-2].decode("utf-8"))["text"] + # # print(response_new_words) + # n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" ")) + n_words = sum(results) + time_seconds = time.time() - tik + print( + f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, " + f"throughput: {n_words / time_seconds} words/s." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, default="vicuna") + parser.add_argument("--max-new-tokens", type=int, default=2048) + parser.add_argument("--n-thread", type=int, default=8) + parser.add_argument("--test-dispatch", action="store_true") + args = parser.parse_args() + + main() diff --git a/graphgpt/train/graphchat_trainer.py b/graphgpt/train/graphchat_trainer.py new file mode 100644 index 0000000..11e4b62 --- /dev/null +++ b/graphgpt/train/graphchat_trainer.py @@ -0,0 +1,49 @@ +import os +import torch +import torch.nn as nn + +from transformers import Trainer +from typing import Dict, Optional, Sequence + + +def unwrap_model(model: nn.Module) -> nn.Module: + """ + Recursively unwraps a model from potential containers (as used in distributed training). + + Args: + model (`torch.nn.Module`): The model to unwrap. + """ + # since there could be multiple levels of wrapping, unwrap recursively + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model + + +class GraphChatTrainer(Trainer): + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + if getattr(self.args, 'tune_graph_mlp_adapter', False): + # Save the model + _state_dict = state_dict + if _state_dict is None: + # Only save the model itself if we are using distributed training + model_to_save = unwrap_model(self.model) + _state_dict = model_to_save.state_dict() + + weight_to_save = {} + keys_to_match = ['graph_projector', 'embed_tokens', 'embed_in'] + for k, v in _state_dict.items(): + if any(key_match in k for key_match in keys_to_match): + weight_to_save[k] = v + + current_folder = output_dir.split('/')[-1] + parent_folder = os.path.dirname(output_dir) + if current_folder.startswith('checkpoint-'): + mm_projector_folder = os.path.join(parent_folder, "graph_projector") + os.makedirs(mm_projector_folder, exist_ok=True) + torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) + else: + torch.save(weight_to_save, os.path.join(output_dir, f'graph_projector.bin')) + + super(GraphChatTrainer, self)._save(output_dir, state_dict) diff --git a/graphgpt/train/llama_flash_attn_monkey_patch.py b/graphgpt/train/llama_flash_attn_monkey_patch.py new file mode 100644 index 0000000..00fc39e --- /dev/null +++ b/graphgpt/train/llama_flash_attn_monkey_patch.py @@ -0,0 +1,114 @@ +from typing import List, Optional, Tuple + +import torch +from torch import nn + +import transformers +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + +from einops import rearrange + +from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel + + attention_mask: [bsz, q_len] + """ + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + # [bsz, q_len, nh, hd] + # [bsz, nh, q_len, hd] + + kv_seq_len = key_states.shape[-2] + assert past_key_value is None, "past_key_value is not supported" + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + assert not output_attentions, "output_attentions is not supported" + assert not use_cache, "use_cache is not supported" + + # Flash attention codes from + # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py + + # transform the data into the format required by flash attention + qkv = torch.stack( + [query_states, key_states, value_states], dim=2 + ) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + # We have disabled _prepare_decoder_attention_mask in LlamaModel + # the attention_mask should be the same as the key_padding_mask + key_padding_mask = attention_mask + + if key_padding_mask is None: + qkv = rearrange(qkv, "b s ... -> (b s) ...") + max_s = q_len + cu_q_lens = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device + ) + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output = rearrange(output, "(b s) ... -> b s ...", b=bsz) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, "b s three h d -> b s (three h d)") + x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange( + x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads + ) + output_unpad = flash_attn_unpadded_qkvpacked_func( + x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output = rearrange( + pad_input( + rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len + ), + "b s (h d) -> b s h d", + h=nheads, + ) + return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( + _prepare_decoder_attention_mask + ) + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/graphgpt/train/train_graph.py b/graphgpt/train/train_graph.py new file mode 100644 index 0000000..00d0681 --- /dev/null +++ b/graphgpt/train/train_graph.py @@ -0,0 +1,966 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import copy +from dataclasses import dataclass, field +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List + +import torch + +import transformers +from torch.utils.data import Dataset +from graphgpt.train.graphchat_trainer import GraphChatTrainer + +from graphgpt import conversation as conversation_lib +from graphgpt.model import * + +from PIL import Image +import torch.nn as nn +from torch_geometric.data import Data + +# TODO: import and use code from ../data/dataset.py + +IGNORE_INDEX = -100 +DEFAULT_PAD_TOKEN = "[PAD]" +DEFAULT_EOS_TOKEN = "" +DEFAULT_BOS_TOKEN = "" +DEFAULT_UNK_TOKEN = "" +DEFAULT_GRAPH_TOKEN = "" +DEFAULT_GRAPH_PATCH_TOKEN = "" +DEFAULT_G_START_TOKEN = "" +DEFAULT_G_END_TOKEN = "" + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_graph_mlp_adapter: bool = field(default=False) + graph_tower: Optional[str] = field(default=None) + graph_select_layer: Optional[int] = field(default=-1) # default to the last layer + pretrain_graph_mlp_adapter: Optional[str] = field(default=None) + use_graph_start_end: bool = field(default=False) + + +@dataclass +class DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + lazy_preprocess: bool = False + is_graph: bool = False + sep_graph_conv_front: bool = False + graph_token_len: int = 0 + graph_content: Optional[str] = field(default=None) + graph_data_path: Optional[str] = field(default=None) + image_aspect_ratio: str = 'square' + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + freeze_graph_mlp_adapter: bool = field(default=False) + force_fsdp: bool = field(default=False) + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + disable_tqdm: bool =False + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + for name, module in model.named_modules(): + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, + output_dir: str): + """Collects the state dict and dump to disk.""" + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = { + key: value.cpu() + for key, value in state_dict.items() + } + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn(strings: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = 'unknown' + sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + + sentence["value"] + END_SIGNAL) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_graph( + sources: Sequence[str], + graph_cfg: dict, + cur_token_len: int, +) -> Dict: + is_graph = graph_cfg['is_graph'] + # image_token_len = multimodal_cfg['image_token_len'] + graph_token_len = cur_token_len + if not is_graph: + return sources + + for source in sources: + if graph_cfg['sep_graph_conv_front']: + assert DEFAULT_GRAPH_TOKEN in source[0]['value'] + source[0]['value'] = source[0]['value'].replace(DEFAULT_GRAPH_TOKEN, '').strip() + source[0]['value'] = DEFAULT_GRAPH_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value'] + for sentence in source: + replace_token = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + if graph_cfg['use_graph_start_end']: + replace_token = DEFAULT_G_START_TOKEN + replace_token + DEFAULT_G_END_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_GRAPH_TOKEN, replace_token) + + return sources + +def preprocess_graph_LP( + sources: Sequence[str], + graph_cfg: dict, + cur_token_len_1: int, + cur_token_len_2: int, +) -> Dict: + is_graph = graph_cfg['is_graph'] + # image_token_len = multimodal_cfg['image_token_len'] + graph_token_len_1 = cur_token_len_1 + graph_token_len_2 = cur_token_len_2 + + if not is_graph: + return sources + + for source in sources: + if graph_cfg['sep_graph_conv_front']: + assert DEFAULT_GRAPH_TOKEN in source[0]['value'] + source[0]['value'] = source[0]['value'].replace(DEFAULT_GRAPH_TOKEN, '').strip() + source[0]['value'] = DEFAULT_GRAPH_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value'] + for sentence in source: + replace_token_1 = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len_1 + replace_token_2 = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len_2 + if graph_cfg['use_graph_start_end']: + replace_token_1 = DEFAULT_G_START_TOKEN + replace_token_1 + DEFAULT_G_END_TOKEN + replace_token_2 = DEFAULT_G_START_TOKEN + replace_token_2 + DEFAULT_G_END_TOKEN + + if DEFAULT_GRAPH_TOKEN in sentence["value"]: + first_index = sentence["value"].find(DEFAULT_GRAPH_TOKEN) + sentence["value"] = sentence["value"][:first_index] + replace_token_1 + sentence["value"][first_index+len(DEFAULT_GRAPH_TOKEN):] + + # 替换第二个为B + second_index = sentence["value"].find(DEFAULT_GRAPH_TOKEN) + sentence["value"] = sentence["value"][:second_index] + replace_token_2 + sentence["value"][second_index+len(DEFAULT_GRAPH_TOKEN):] + + + # sentence["value"] = sentence["value"].replace(DEFAULT_GRAPH_TOKEN, replace_token) + + # print(sources) + + return sources + + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + +def preprocess_mpt( + sources, + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt + cur_len = 0 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + round_len = len(tokenizer(rou).input_ids) + len(tokenizer(conv.sep).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + if conversation_lib.default_conversation.version == "v1": + return preprocess_v1(sources, tokenizer) + if conversation_lib.default_conversation.version == "mpt": + return preprocess_mpt(sources, tokenizer) + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{conversation_lib.default_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + # tokenize conversations + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], + tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer): + super(SupervisedDataset, self).__init__() + logging.warning("Loading data...") + list_data_dict = json.load(open(data_path, "r")) + + logging.warning("Formatting inputs...") + sources = [example["conversations"] for example in list_data_dict] + data_dict = preprocess(sources, tokenizer) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict(input_ids=self.input_ids[i], labels=self.labels[i]) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + graph_cfg: dict, + **kwargs,): + super(LazySupervisedDataset, self).__init__() + logging.warning("Loading data...") + list_data_dict = json.load(open(data_path, "r")) + + logging.warning("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.graph_cfg = graph_cfg + graph_data_path = kwargs.get('graph_data_path') + self.graph_data_all = torch.load(graph_data_path) + + def __len__(self): + return len(self.list_data_dict) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + + task_type = self.list_data_dict[i]['id'].split("_")[-1] + if task_type != 'LP': + if 'graph' in sources[0]: + graph_dict = self.list_data_dict[i]['graph'] + graph_edge_index = torch.Tensor(copy.deepcopy(graph_dict['edge_index'])).long() + graph_node_list = copy.deepcopy(graph_dict['node_list']) + target_node = copy.deepcopy(graph_dict['node_idx']) + graph_type = copy.deepcopy(self.list_data_dict[i]['id']).split('_')[0] + graph_node_rep = self.graph_data_all[graph_type].x[graph_node_list] ## + + cur_token_len = len(graph_node_rep) # FIXME: 14 is hardcoded patch size + sources = preprocess_graph( + copy.deepcopy([e["conversations"] for e in sources]), + self.graph_cfg, cur_token_len) + else: + sources = copy.deepcopy([e["conversations"] for e in sources]) + else: + if 'graph' in sources[0]: + graph_dict = self.list_data_dict[i]['graph'] + graph_edge_index_1 = torch.Tensor(copy.deepcopy(graph_dict['edge_index_1'])).long() + graph_node_list_1 = copy.deepcopy(graph_dict['node_list_1']) + target_node_1 = copy.deepcopy(graph_dict['node_idx_1']) + graph_type = copy.deepcopy(self.list_data_dict[i]['id']).split('_')[0] + graph_node_rep_1 = self.graph_data_all[graph_type].x[graph_node_list_1] ## + + cur_token_len_1 = len(graph_node_rep_1) # FIXME: 14 is hardcoded patch size + + graph_edge_index_2 = torch.Tensor(copy.deepcopy(graph_dict['edge_index_2'])).long() + graph_node_list_2 = copy.deepcopy(graph_dict['node_list_2']) + target_node_2 = copy.deepcopy(graph_dict['node_idx_2']) + graph_node_rep_2 = self.graph_data_all[graph_type].x[graph_node_list_2] ## + + cur_token_len_2 = len(graph_node_rep_2) # FIXME: 14 is hardcoded patch size + sources = preprocess_graph_LP( + copy.deepcopy([e["conversations"] for e in sources]), + self.graph_cfg, cur_token_len_1, cur_token_len_2) + else: + sources = copy.deepcopy([e["conversations"] for e in sources]) + data_dict = preprocess( + sources, + self.tokenizer) + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], + labels=data_dict["labels"][0]) + + # image exist in the data + if task_type != 'LP': + if 'graph' in self.list_data_dict[i]: + # data_dict['graph_node'] = graph_node_rep + # data_dict['graph_edge'] = graph_edge_index + # data_dict['target_node'] = target_node + data_dict['graph_data'] = Data(graph_node = graph_node_rep, edge_index=graph_edge_index, target_node = torch.tensor([target_node])) + + elif self.graph_cfg['is_graph']: + # image does not exist in the data, but the model is multimodal + node_feas = self.graph_cfg['graph_processor'].node_feas + data_dict['graph_data'] = Data(graph_node = torch.zeros(3, node_feas), edge_index=torch.zeros(2, 3), target_node = torch.tensor([0])) + else: + if 'graph' in self.list_data_dict[i]: + # data_dict['graph_node'] = graph_node_rep + # data_dict['graph_edge'] = graph_edge_index + # data_dict['target_node'] = target_node + data_dict['graph_data'] = { + 'graph_1': Data(graph_node = graph_node_rep_1, edge_index=graph_edge_index_1, target_node = torch.tensor([target_node_1])), + 'graph_2': Data(graph_node = graph_node_rep_2, edge_index=graph_edge_index_2, target_node = torch.tensor([target_node_2])) + } + + elif self.graph_cfg['is_graph']: + # image does not exist in the data, but the model is multimodal + node_feas = self.graph_cfg['graph_processor'].node_feas + data_dict['graph_data'] = Data(graph_node = torch.zeros(3, node_feas), edge_index=torch.zeros(2, 3), target_node = torch.tensor([0])) + return data_dict + +class LazySupervisedDataset_back(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + graph_cfg: dict, + **kwargs,): + super(LazySupervisedDataset, self).__init__() + logging.warning("Loading data...") + list_data_dict = json.load(open(data_path, "r")) + + logging.warning("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.graph_cfg = graph_cfg + graph_data_path = kwargs.get('graph_data_path') + self.graph_data_all = torch.load(graph_data_path) + + def __len__(self): + return len(self.list_data_dict) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + if 'graph' in sources[0]: + graph_dict = self.list_data_dict[i]['graph'] + graph_edge_index = torch.Tensor(copy.deepcopy(graph_dict['edge_index'])).long() + graph_node_list = copy.deepcopy(graph_dict['node_list']) + target_node = copy.deepcopy(graph_dict['node_idx']) + graph_type = copy.deepcopy(self.list_data_dict[i]['id']).split('_')[0] + graph_node_rep = self.graph_data_all[graph_type].x[graph_node_list] ## + + cur_token_len = len(graph_node_rep) # FIXME: 14 is hardcoded patch size + sources = preprocess_graph( + copy.deepcopy([e["conversations"] for e in sources]), + self.graph_cfg, cur_token_len) + else: + sources = copy.deepcopy([e["conversations"] for e in sources]) + data_dict = preprocess( + sources, + self.tokenizer) + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], + labels=data_dict["labels"][0]) + + # image exist in the data + if 'graph' in self.list_data_dict[i]: + # data_dict['graph_node'] = graph_node_rep + # data_dict['graph_edge'] = graph_edge_index + # data_dict['target_node'] = target_node + data_dict['graph_data'] = Data(graph_node = graph_node_rep, edge_index=graph_edge_index, target_node = torch.tensor([target_node])) + + elif self.graph_cfg['is_graph']: + # image does not exist in the data, but the model is multimodal + node_feas = self.graph_cfg['graph_processor'].node_feas + data_dict['graph_data'] = Data(graph_node = torch.zeros(3, node_feas), edge_index=torch.zeros(2, 3), target_node = torch.tensor([0])) + return data_dict + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] + for key in ("input_ids", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, + batch_first=True, + padding_value=IGNORE_INDEX) + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + if 'graph_data' in instances[0]: + # graph_node_reps = [instance['graph_node'] for instance in instances] + # edge_index_reps = [instance['graph_edge'] for instance in instances] + # target_node_reps = [instance['target_node'] for instance in instances] + graph_data_batch = [instance['graph_data'] for instance in instances] + # if all(x is not None and x.shape == images[0].shape for x in images): + # batch['images'] = torch.stack(images) + # else: + # batch['images'] = images + # batch['graph_node_reps'] = graph_node_reps + # batch['edge_index_reps'] = edge_index_reps + # batch['edge_index_reps'] = target_node_reps + batch['graph_data'] = graph_data_batch + + return batch + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, + data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + dataset_cls = (LazySupervisedDataset + if data_args.lazy_preprocess else SupervisedDataset) + train_dataset = dataset_cls(tokenizer=tokenizer, + data_path=data_args.data_path, + graph_cfg=dict( + is_graph=data_args.is_graph, + sep_graph_conv_front=data_args.sep_graph_conv_front, + graph_token_len=data_args.graph_token_len, + graph_content=data_args.graph_content, + use_graph_start_end=getattr(data_args, 'use_graph_start_end', False) + ), + graph_data_path = data_args.graph_data_path) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, + eval_dataset=None, + data_collator=data_collator) + +def train(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + + compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + + bnb_model_from_pretrained_args = {} + + ## load 4 8 bit + if training_args.bits in [4, 8]: + from transformers import BitsAndBytesConfig + from peft import prepare_model_for_int8_training + bnb_model_from_pretrained_args.update(dict( + device_map={"": training_args.device}, + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + quantization_config=BitsAndBytesConfig( + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=training_args.double_quant, + bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} + ) + )) + + if model_args.graph_tower is not None: + model = GraphLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + **bnb_model_from_pretrained_args + ) ## TODO: add real Graph Llama model + else: + model = transformers.LlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + **bnb_model_from_pretrained_args + ) + model.config.pretrain_graph_model_path = model.config.pretrain_graph_model_path + model_args.graph_tower + model.config.use_cache = False + + if model_args.freeze_backbone: + model.model.requires_grad_(False) + + if training_args.bits in [4, 8]: + model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + model = prepare_model_for_int8_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) + + if training_args.gradient_checkpointing and model_args.graph_tower is None: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if training_args.lora_enable: + from peft import LoraConfig, get_peft_model + lora_config = LoraConfig( + r=training_args.lora_r, + lora_alpha=training_args.lora_alpha, + target_modules=find_all_linear_names(model), + lora_dropout=training_args.lora_dropout, + bias=training_args.lora_bias, + task_type="CAUSAL_LM", + ) + if training_args.bits == 16: + if training_args.bf16: + model.to(torch.bfloat16) + if training_args.fp16: + model.to(torch.float16) + logging.warning("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + if model_args.version == "v0": + if tokenizer.pad_token is None: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), + tokenizer=tokenizer, + model=model, + ) + if "llama" in model_args.model_name_or_path: + tokenizer.add_special_tokens({ + "eos_token": DEFAULT_EOS_TOKEN, + "bos_token": DEFAULT_BOS_TOKEN, + "unk_token": DEFAULT_UNK_TOKEN, + }) + else: + tokenizer.pad_token = tokenizer.unk_token + conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"] + + if model_args.graph_tower is not None: + model_graph_dict = model.get_model().initialize_graph_modules( + graph_tower=model_args.graph_tower, + graph_select_layer=model_args.graph_select_layer, + pretrain_graph_mlp_adapter=model_args.pretrain_graph_mlp_adapter, + fsdp=training_args.fsdp + ) + model.get_graph_tower().to(dtype=torch.float16, device=training_args.device) + # graph_config = model_graph_dict['graph_config'] + + # data_args.graph_token_len = model_graph_dict['graph_token_len'] + # data_args.graph_processor = model_graph_dict['graph_processor'] + data_args.is_graph = True + + model.config.tune_graph_mlp_adapter = training_args.tune_graph_mlp_adapter = model_args.tune_graph_mlp_adapter + if model_args.tune_graph_mlp_adapter: + model.requires_grad_(False) + for p in model.get_model().graph_projector.parameters(): + p.requires_grad = True + + model.config.freeze_graph_mlp_adapter = training_args.freeze_graph_mlp_adapter + if training_args.freeze_graph_mlp_adapter: + for p in model.get_model().graph_projector.parameters(): + p.requires_grad = False + + if training_args.bits in [4, 8]: + model.get_model().graph_projector.to(dtype=compute_dtype, device=training_args.device) + + model.config.use_graph_start_end = data_args.use_graph_start_end = model_args.use_graph_start_end + # graph_config.use_graph_start_end = training_args.use_graph_start_end = model_args.use_graph_start_end + training_args.use_graph_start_end = model_args.use_graph_start_end + model.config.sep_graph_conv_front = data_args.sep_graph_conv_front + model.initialize_graph_tokenizer(use_graph_start_end=model_args.use_graph_start_end, tokenizer=tokenizer, device=training_args.device, + tune_graph_mlp_adapter=model_args.tune_graph_mlp_adapter, pretrain_graph_mlp_adapter=model_args.pretrain_graph_mlp_adapter) + + params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad] + if len(params_no_grad) > 0: + if training_args.fsdp is not None and len(training_args.fsdp) > 0: + if len(params_no_grad) < 10: + print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad)) + else: + print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10]))) + print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.") + print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining") + + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + def patch_FSDP_use_orig_params(func): + def wrap_func(*args, **kwargs): + use_orig_params = kwargs.pop('use_orig_params', True) + return func(*args, **kwargs, use_orig_params=use_orig_params) + return wrap_func + + FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__) + + if training_args.bits in [4, 8]: + from peft.tuners.lora import LoraLayer + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if training_args.bf16: + module = module.to(torch.bfloat16) + if 'norm' in name: + module = module.to(torch.float32) + if 'lm_head' in name or 'embed_tokens' in name: + if hasattr(module, 'weight'): + if training_args.bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + data_module = make_supervised_data_module(tokenizer=tokenizer, + data_args=data_args) + trainer = GraphChatTrainer(model=model, + tokenizer=tokenizer, + args=training_args, + **data_module) + + print('************************** parameters: #', sum(p.numel() for p in model.parameters() if p.requires_grad)) + tuned_params = [] + for name, param in model.named_parameters(): + if param.requires_grad: + tuned_params.append(name) + print(tuned_params) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + if training_args.lora_enable: + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), training_args.lora_bias + ) + non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( + model.named_parameters() + ) + if training_args.local_rank == 0 or training_args.local_rank == -1: + model.config.save_pretrained(training_args.output_dir) + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) + else: + safe_save_model_for_hf_trainer(trainer=trainer, + output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/graphgpt/train/train_lora.py b/graphgpt/train/train_lora.py new file mode 100644 index 0000000..ac52f81 --- /dev/null +++ b/graphgpt/train/train_lora.py @@ -0,0 +1,157 @@ +# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG> + +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +import logging +import pathlib +import typing + +from deepspeed import zero +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from peft import LoraConfig, get_peft_model +import transformers +from transformers import Trainer + +from fastchat.train.train import ( + DataArguments, + ModelArguments, + TrainingArguments, + make_supervised_data_module, +) + +from fastchat.train.llama_flash_attn_monkey_patch import ( + replace_llama_attn_with_flash_attn, +) + +replace_llama_attn_with_flash_attn() + + +@dataclass +class LoraArguments: + lora_r: int = 8 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_target_modules: typing.List[str] = field( + default_factory=lambda: ["q_proj", "v_proj"] + ) + lora_weight_path: str = "" + lora_bias: str = "none" + + +def maybe_zero_3(param): + if hasattr(param, "ds_id"): + assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v) for k, v in to_return.items()} + return to_return + + +def train(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments, LoraArguments) + ) + ( + model_args, + data_args, + training_args, + lora_args, + ) = parser.parse_args_into_dataclasses() + + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + ) + lora_config = LoraConfig( + r=lora_args.lora_r, + lora_alpha=lora_args.lora_alpha, + target_modules=lora_args.lora_target_modules, + lora_dropout=lora_args.lora_dropout, + bias=lora_args.lora_bias, + task_type="CAUSAL_LM", + ) + model = get_peft_model(model, lora_config) + if training_args.deepspeed is not None and training_args.local_rank == 0: + model.print_trainable_parameters() + + if training_args.gradient_checkpointing: + logging.warning( + "gradient checkpointing with lora makes requires_grad " + "incorrect and needs a monkey patch in Trainer or the " + "wrapped model's forward. ref: " + "https://github.com/lm-sys/FastChat/pull/138#issuecomment-1509172198" + ) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + tokenizer.pad_token = tokenizer.unk_token + + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + model.config.use_cache = False + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + # Save states. Weights might be a placeholder in zero3 and need a gather + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), lora_args.lora_bias + ) + if training_args.local_rank == 0: + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + + +if __name__ == "__main__": + train() diff --git a/graphgpt/train/train_mem.py b/graphgpt/train/train_mem.py new file mode 100644 index 0000000..bfe35ea --- /dev/null +++ b/graphgpt/train/train_mem.py @@ -0,0 +1,13 @@ +# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. + +# Need to call this before importing transformers. +from graphgpt.train.llama_flash_attn_monkey_patch import ( + replace_llama_attn_with_flash_attn, +) + +replace_llama_attn_with_flash_attn() + +from graphgpt.train.train_graph import train + +if __name__ == "__main__": + train() diff --git a/graphgpt/utils.py b/graphgpt/utils.py new file mode 100644 index 0000000..75188b0 --- /dev/null +++ b/graphgpt/utils.py @@ -0,0 +1,240 @@ +from asyncio import AbstractEventLoop +import json +import logging +import logging.handlers +import os +import platform +import sys +from typing import AsyncGenerator, Generator +import warnings + +import requests +import torch + +from fastchat.constants import LOGDIR + + +handler = None +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + if sys.version_info[1] >= 9: + # This is for windows + logging.basicConfig(level=logging.INFO, encoding="utf-8") + else: + if platform.system() == "Windows": + warnings.warn( + "If you are running on Windows, " + "we recommend you use Python >= 3.9 for UTF-8 encoding." + ) + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when="D", utc=True, encoding="utf-8" + ) + handler.setFormatter(formatter) + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = "" + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = "" + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == "\n": + encoded_message = line.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != "": + encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + self.linebuf = "" + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def get_gpu_memory(max_gpus=None): + """Get available memory for each GPU.""" + gpu_memory = [] + num_gpus = ( + torch.cuda.device_count() + if max_gpus is None + else min(max_gpus, torch.cuda.device_count()) + ) + + for gpu_id in range(num_gpus): + with torch.cuda.device(gpu_id): + device = torch.cuda.current_device() + gpu_properties = torch.cuda.get_device_properties(device) + total_memory = gpu_properties.total_memory / (1024**3) + allocated_memory = torch.cuda.memory_allocated() / (1024**3) + available_memory = total_memory - allocated_memory + gpu_memory.append(available_memory) + return gpu_memory + + +def violates_moderation(text): + """ + Check whether the text violates OpenAI moderation API. + """ + url = "https://api.openai.com/v1/moderations" + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"], + } + text = text.replace("\n", "") + data = "{" + '"input": ' + f'"{text}"' + "}" + data = data.encode("utf-8") + try: + ret = requests.post(url, headers=headers, data=data, timeout=5) + flagged = ret.json()["results"][0]["flagged"] + except requests.exceptions.RequestException as e: + flagged = False + except KeyError as e: + flagged = False + + return flagged + + +# Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings, +# Use this function to make sure it can be correctly loaded. +def clean_flant5_ckpt(ckpt_path): + index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + index_json = json.load(open(index_file, "r")) + + weightmap = index_json["weight_map"] + + share_weight_file = weightmap["shared.weight"] + share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[ + "shared.weight" + ] + + for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]: + weight_file = weightmap[weight_name] + weight = torch.load(os.path.join(ckpt_path, weight_file)) + weight[weight_name] = share_weight + torch.save(weight, os.path.join(ckpt_path, weight_file)) + + +def pretty_print_semaphore(semaphore): + """Print a semaphore in better format.""" + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + + +"""A javascript function to get url parameters for the gradio web server.""" +get_window_url_params_js = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log("url_params", url_params); + return url_params; + } +""" + + +def iter_over_async( + async_gen: AsyncGenerator, event_loop: AbstractEventLoop +) -> Generator: + """ + Convert async generator to sync generator + + :param async_gen: the AsyncGenerator to convert + :param event_loop: the event loop to run on + :returns: Sync generator + """ + ait = async_gen.__aiter__() + + async def get_next(): + try: + obj = await ait.__anext__() + return False, obj + except StopAsyncIteration: + return True, None + + while True: + done, obj = event_loop.run_until_complete(get_next()) + if done: + break + yield obj + + +def detect_language(text: str) -> str: + """Detect the langauge of a string.""" + import polyglot # pip3 install polyglot pyicu pycld2 + from polyglot.detect import Detector + from polyglot.detect.base import logger as polyglot_logger + import pycld2 + + polyglot_logger.setLevel("ERROR") + + try: + lang_code = Detector(text).language.name + except (pycld2.error, polyglot.detect.base.UnknownLanguage): + lang_code = "unknown" + return lang_code diff --git a/images/.DS_Store b/images/.DS_Store new file mode 100644 index 0000000..58595eb Binary files /dev/null and b/images/.DS_Store differ diff --git a/images/graphgpt.png b/images/graphgpt.png new file mode 100644 index 0000000..a73ef65 Binary files /dev/null and b/images/graphgpt.png differ diff --git a/playground/inspect_conv.py b/playground/inspect_conv.py new file mode 100644 index 0000000..99306e5 --- /dev/null +++ b/playground/inspect_conv.py @@ -0,0 +1,87 @@ +import argparse +import code +import datetime +import json +import os +from pytz import timezone +import time + +import pandas as pd +from tqdm import tqdm + + +def get_log_files(max_num_files=None): + dates = [] + for month in [4, 5]: + for day in range(1, 32): + dates.append(f"2023-{month:02d}-{day:02d}") + + num_servers = 12 + filenames = [] + for d in dates: + for i in range(num_servers): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames + + +def pretty_print_conversation(messages): + for role, msg in messages: + print(f"[[{role}]]: {msg}") + + +def inspect_convs(log_files): + data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + + if "states" not in row: + continue + if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]: + continue + + model_names = row["states"][0]["model_name"], row["states"][1]["model_name"] + if row["type"] == "leftvote": + winner, loser = model_names[0], model_names[1] + winner_conv, loser_conv = row["states"][0], row["states"][1] + elif row["type"] == "rightvote": + loser, winner = model_names[0], model_names[1] + loser_conv, winner_conv = row["states"][0], row["states"][1] + + if loser == "bard" and winner == "vicuna-13b": + print("=" * 20) + print(f"Winner: {winner}") + pretty_print_conversation(winner_conv["messages"]) + print(f"Loser: {loser}") + pretty_print_conversation(loser_conv["messages"]) + print("=" * 20) + input() + + # if row["type"] == "bothbad_vote" and "gpt-4" in model_names: + # print("=" * 20) + # print(f"Model A: {model_names[0]}") + # pretty_print_conversation(row["states"][0]["messages"]) + # print(f"Model B: {model_names[1]}") + # pretty_print_conversation(row["states"][1]["messages"]) + # print("=" * 20) + # input() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + inspect_convs(log_files) diff --git a/playground/test_embedding/README.md b/playground/test_embedding/README.md new file mode 100644 index 0000000..57ac73c --- /dev/null +++ b/playground/test_embedding/README.md @@ -0,0 +1,15 @@ +## Machine Learning with Embeddings +You can use embeddings to +- Evaluate text similarity, see [test_sentence_similarity.py](test_sentence_similarity.py) +- Build your own classifier, see [test_classification.py](test_classification.py) +- Search relative texts, see [test_semantic_search.py](test_semantic_search.py) + +To these tests, you need to download the data [here](https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews). You also need an OpenAI API key for comparison. + +Run with: +```bash +cd playground/test_embedding +python3 test_classification.py +``` + +The script will train classifiers based on `vicuna-7b`, `text-similarity-ada-001` and `text-embedding-ada-002` and report the accuracy of each classifier. diff --git a/playground/test_embedding/test_classification.py b/playground/test_embedding/test_classification.py new file mode 100644 index 0000000..393827b --- /dev/null +++ b/playground/test_embedding/test_classification.py @@ -0,0 +1,83 @@ +import json +import os + +import numpy as np +import openai +import pandas as pd +import requests +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split +from sklearn.metrics import classification_report, accuracy_score + + +np.set_printoptions(threshold=10000) + + +def get_embedding_from_api(word, model="vicuna-7b-v1.1"): + if "ada" in model: + resp = openai.Embedding.create( + model=model, + input=word, + ) + embedding = np.array(resp["data"][0]["embedding"]) + return embedding + + url = "http://localhost:8000/v1/embeddings" + headers = {"Content-Type": "application/json"} + data = json.dumps({"model": model, "input": word}) + + response = requests.post(url, headers=headers, data=data) + if response.status_code == 200: + embedding = np.array(response.json()["data"][0]["embedding"]) + return embedding + else: + print(f"Error: {response.status_code} - {response.text}") + return None + + +def create_embedding_data_frame(data_path, model, max_tokens=500): + df = pd.read_csv(data_path, index_col=0) + df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]] + df = df.dropna() + df["combined"] = ( + "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip() + ) + top_n = 1000 + df = df.sort_values("Time").tail(top_n * 2) + df.drop("Time", axis=1, inplace=True) + + df["n_tokens"] = df.combined.apply(lambda x: len(x)) + df = df[df.n_tokens <= max_tokens].tail(top_n) + df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model)) + return df + + +def train_random_forest(df): + X_train, X_test, y_train, y_test = train_test_split( + list(df.embedding.values), df.Score, test_size=0.2, random_state=42 + ) + + clf = RandomForestClassifier(n_estimators=100) + clf.fit(X_train, y_train) + preds = clf.predict(X_test) + + report = classification_report(y_test, preds) + accuracy = accuracy_score(y_test, preds) + return clf, accuracy, report + + +input_datapath = "amazon_fine_food_review.csv" +if not os.path.exists(input_datapath): + raise Exception( + f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews" + ) + +df = create_embedding_data_frame(input_datapath, "vicuna-7b-v1.1") +clf, accuracy, report = train_random_forest(df) +print(f"Vicuna-7b-v1.1 accuracy:{accuracy}") +df = create_embedding_data_frame(input_datapath, "text-similarity-ada-001") +clf, accuracy, report = train_random_forest(df) +print(f"text-similarity-ada-001 accuracy:{accuracy}") +df = create_embedding_data_frame(input_datapath, "text-embedding-ada-002") +clf, accuracy, report = train_random_forest(df) +print(f"text-embedding-ada-002 accuracy:{accuracy}") diff --git a/playground/test_embedding/test_semantic_search.py b/playground/test_embedding/test_semantic_search.py new file mode 100644 index 0000000..879b240 --- /dev/null +++ b/playground/test_embedding/test_semantic_search.py @@ -0,0 +1,99 @@ +import json +import os + +import numpy as np +import openai +import pandas as pd +import requests +from scipy.spatial.distance import cosine + + +def cosine_similarity(vec1, vec2): + try: + return 1 - cosine(vec1, vec2) + except: + print(vec1.shape, vec2.shape) + + +def get_embedding_from_api(word, model="vicuna-7b-v1.1"): + if "ada" in model: + resp = openai.Embedding.create( + model=model, + input=word, + ) + embedding = np.array(resp["data"][0]["embedding"]) + return embedding + + url = "http://localhost:8000/v1/embeddings" + headers = {"Content-Type": "application/json"} + data = json.dumps({"model": model, "input": word}) + + response = requests.post(url, headers=headers, data=data) + if response.status_code == 200: + embedding = np.array(response.json()["data"][0]["embedding"]) + return embedding + else: + print(f"Error: {response.status_code} - {response.text}") + return None + + +def create_embedding_data_frame(data_path, model, max_tokens=500): + df = pd.read_csv(data_path, index_col=0) + df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]] + df = df.dropna() + df["combined"] = ( + "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip() + ) + top_n = 1000 + df = df.sort_values("Time").tail(top_n * 2) + df.drop("Time", axis=1, inplace=True) + + df["n_tokens"] = df.combined.apply(lambda x: len(x)) + df = df[df.n_tokens <= max_tokens].tail(top_n) + df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model)) + return df + + +def search_reviews(df, product_description, n=3, pprint=False, model="vicuna-7b-v1.1"): + product_embedding = get_embedding_from_api(product_description, model=model) + df["similarity"] = df.embedding.apply( + lambda x: cosine_similarity(x, product_embedding) + ) + + results = ( + df.sort_values("similarity", ascending=False) + .head(n) + .combined.str.replace("Title: ", "") + .str.replace("; Content:", ": ") + ) + if pprint: + for r in results: + print(r[:200]) + print() + return results + + +def print_model_search(input_path, model): + print(f"Model: {model}") + df = create_embedding_data_frame(input_path, model) + print("search: delicious beans") + results = search_reviews(df, "delicious beans", n=5, model=model) + print(results) + print("search: whole wheat pasta") + results = search_reviews(df, "whole wheat pasta", n=5, model=model) + print(results) + print("search: bad delivery") + results = search_reviews(df, "bad delivery", n=5, model=model) + print(results) + + +input_datapath = "amazon_fine_food_review.csv" +if not os.path.exists(input_datapath): + raise Exception( + f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews" + ) + + +print_model_search(input_datapath, "vicuna-7b-v1.1") +print_model_search(input_datapath, "text-similarity-ada-001") +print_model_search(input_datapath, "text-embedding-ada-002") diff --git a/playground/test_embedding/test_sentence_similarity.py b/playground/test_embedding/test_sentence_similarity.py new file mode 100644 index 0000000..0b9a540 --- /dev/null +++ b/playground/test_embedding/test_sentence_similarity.py @@ -0,0 +1,67 @@ +import json +import os + +import numpy as np +import openai +import requests +from scipy.spatial.distance import cosine + + +def get_embedding_from_api(word, model="vicuna-7b-v1.1"): + if "ada" in model: + resp = openai.Embedding.create( + model=model, + input=word, + ) + embedding = np.array(resp["data"][0]["embedding"]) + return embedding + + url = "http://localhost:8000/v1/embeddings" + headers = {"Content-Type": "application/json"} + data = json.dumps({"model": model, "input": word}) + + response = requests.post(url, headers=headers, data=data) + if response.status_code == 200: + embedding = np.array(response.json()["data"][0]["embedding"]) + return embedding + else: + print(f"Error: {response.status_code} - {response.text}") + return None + + +def cosine_similarity(vec1, vec2): + return 1 - cosine(vec1, vec2) + + +def print_cosine_similarity(embeddings, texts): + for i in range(len(texts)): + for j in range(i + 1, len(texts)): + sim = cosine_similarity(embeddings[texts[i]], embeddings[texts[j]]) + print(f"Cosine similarity between '{texts[i]}' and '{texts[j]}': {sim:.2f}") + + +texts = [ + "The quick brown fox", + "The quick brown dog", + "The fast brown fox", + "A completely different sentence", +] + +embeddings = {} +for text in texts: + embeddings[text] = get_embedding_from_api(text) + +print("Vicuna-7B:") +print_cosine_similarity(embeddings, texts) + +for text in texts: + embeddings[text] = get_embedding_from_api(text, model="text-similarity-ada-001") + +print("text-similarity-ada-001:") +print_cosine_similarity(embeddings, texts) + +for text in texts: + embeddings[text] = get_embedding_from_api(text, model="text-embedding-ada-002") + +print("text-embedding-ada-002:") +print_cosine_similarity(embeddings, texts) diff --git a/playground/test_openai_api/anthropic_api.py b/playground/test_openai_api/anthropic_api.py new file mode 100644 index 0000000..c60a701 --- /dev/null +++ b/playground/test_openai_api/anthropic_api.py @@ -0,0 +1,27 @@ +import os + +from fastchat.model import get_conversation_template + + +def claude(): + import anthropic + c = anthropic.Client(os.environ["ANTHROPIC_API_KEY"]) + + model = "claude-v1" + conv = get_conversation_template(model) + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + response = c.completion_stream( + prompt=prompt, + stop_sequences=[anthropic.HUMAN_PROMPT], + max_tokens_to_sample=256, + model=model, + stream=True, + ) + for data in response: + print(data["completion"]) + + +claude() diff --git a/playground/test_openai_api/openai_api.py b/playground/test_openai_api/openai_api.py new file mode 100644 index 0000000..a36c680 --- /dev/null +++ b/playground/test_openai_api/openai_api.py @@ -0,0 +1,26 @@ +import os + +from fastchat.model import get_conversation_template + +def chatgpt(): + import openai + model = "gpt-3.5-turbo" + conv = get_conversation_template(model) + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], None) + + messages = conv.to_openai_api_messages() + print(messages) + + res = openai.ChatCompletion.create(model=model, messages=messages) + msg = res["choices"][0]["message"]["content"] + print(msg) + + res = openai.ChatCompletion.create(model=model, messages=messages, stream=True) + msg = "" + for chunk in res: + msg += chunk["choices"][0]["delta"].get("content", "") + print(msg) + + +chatgpt() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a6b1f63 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,209 @@ +absl-py==1.4.0 +accelerate==0.21.0 +aiofiles==23.1.0 +aiohttp==3.8.4 +aiosignal==1.3.1 +altair==5.0.1 +anthropic==0.3.9 +anyio==3.7.0 +appdirs==1.4.4 +aspy.yaml==1.3.0 +astor==0.8.1 +astroid==2.5.1 +asttokens==2.2.1 +async-timeout==4.0.2 +attrs==20.3.0 +backcall==0.2.0 +bitsandbytes==0.39.1 +black==23.3.0 +blessed==1.20.0 +blinker==1.6.2 +cached-property==1.5.2 +cachetools==5.3.1 +certifi==2020.12.5 +cfgv==3.2.0 +chardet==4.0.0 +charset-normalizer==3.1.0 +click==8.1.3 +contourpy==1.1.0 +cpm-kernels==1.0.11 +cycler==0.11.0 +datasets==2.10.1 +decorator==4.4.2 +deepspeed==0.10.0 +dill==0.3.4 +distlib==0.3.1 +distro==1.8.0 +docker-pycreds==0.4.0 +einops==0.6.1 +evaluate==0.4.0 +exceptiongroup==1.1.1 +executing==1.2.0 +fastapi==0.98.0 +ffmpy==0.3.0 +filelock==3.0.12 +fire==0.5.0 +flash-attn==1.0.4 +Flask==2.3.2 +Flask-Cors==4.0.0 +fonttools==4.40.0 +frozenlist==1.3.3 +fsspec==2023.6.0 +ftfy==6.1.1 +gast==0.4.0 +gitdb==4.0.10 +GitPython==3.1.31 +google-auth==2.22.0 +google-auth-oauthlib==1.0.0 +gpustat==1.1 +gradio==3.23.0 +grpcio==1.56.2 +h11==0.14.0 +hjson==3.1.0 +httpcore==0.17.2 +httpx==0.24.1 +huggingface-hub==0.15.1 +icetk==0.0.7 +identify==2.2.0 +idna==2.10 +importlib-metadata==6.8.0 +importlib-resources==5.12.0 +iniconfig==1.1.1 +ipykernel==4.6.0 +ipython==8.13.0 +ipython-genutils==0.2.0 +isort==5.8.0 +itsdangerous==2.1.2 +jedi==0.18.2 +Jinja2==3.1.2 +joblib==1.3.1 +jsonschema==4.17.3 +jupyter-client==6.1.12 +jupyter-core==4.7.1 +kiwisolver==1.4.4 +lazy-object-proxy==1.5.2 +linkify-it-py==2.0.2 +loguru==0.7.0 +loralib==0.1.1 +Markdown==3.4.4 +markdown-it-py==2.2.0 +markdown2==2.4.9 +MarkupSafe==2.1.3 +matplotlib==3.7.1 +matplotlib-inline==0.1.6 +mccabe==0.6.1 +mdit-py-plugins==0.3.3 +mdurl==0.1.2 +mpi4py==3.1.4 +msgpack==1.0.5 +multidict==6.0.4 +multiprocess==0.70.12.2 +mypy-extensions==1.0.0 +nh3==0.2.13 +ninja==1.11.1 +nodeenv==1.5.0 +numpy==1.24.2 +nvidia-ml-py==12.535.77 +oauthlib==3.2.2 +openai==0.27.8 +orjson==3.9.1 +packaging==23.1 +pandas==2.0.3 +parso==0.8.3 +pathspec==0.11.1 +pathtools==0.1.2 +peft==0.4.0 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==8.1.2 +pkgutil_resolve_name==1.3.10 +platformdirs==3.8.0 +pluggy==0.13.1 +pre-commit==1.10.4 +prompt-toolkit==3.0.38 +protobuf==3.20.0 +psutil==5.9.5 +ptyprocess==0.7.0 +pure-eval==0.2.2 +py==1.10.0 +py-cpuinfo==9.0.0 +pyarrow==12.0.1 +pyasn1==0.5.0 +pyasn1-modules==0.3.0 +pydantic==1.10.9 +pydub==0.25.1 +pyg-lib==0.2.0 +Pygments==2.15.1 +PyGObject==3.26.1 +pylint==2.7.2 +pyparsing==2.4.7 +pyrsistent==0.19.3 +pytest==6.2.2 +python-apt==1.6.5+ubuntu0.7 +python-dateutil==2.8.2 +python-multipart==0.0.6 +pytz==2023.3 +PyYAML==5.4.1 +pyzmq==22.0.3 +ray==2.6.1 +regex==2023.6.3 +requests==2.31.0 +requests-oauthlib==1.3.1 +responses==0.18.0 +rich==13.4.2 +rsa==4.9 +safetensors==0.3.1 +scikit-learn==1.2.2 +scipy==1.10.1 +semantic-version==2.10.0 +sentencepiece==0.1.99 +sentry-sdk==1.26.0 +setproctitle==1.3.2 +shortuuid==1.0.11 +simplegeneric==0.8.1 +six==1.15.0 +smmap==5.0.0 +sniffio==1.3.0 +ssh-import-id==5.7 +stack-data==0.6.2 +starlette==0.27.0 +svgwrite==1.4.3 +tensorboard==2.13.0 +tensorboard-data-server==0.7.1 +tensorboardX==2.6.1 +termcolor==2.3.0 +threadpoolctl==3.2.0 +tiktoken==0.4.0 +tokenize-rt==5.1.0 +tokenizers==0.13.3 +toml==0.10.2 +tomli==2.0.1 +toolz +# torch==1.13.0 +torch-cluster==1.6.1 +torch-geometric==2.3.1 +torch-scatter==2.1.0 +torch-sparse==0.6.16 +torch-spline-conv==1.2.2 +# torchvision==0.14.0 +tornado==6.1 +tqdm +traitlets==5.0.5 +transformers==4.31.0 +trl +typing_extensions==4.7.0 +tzdata==2023.3 +uc-micro-py==1.0.2 +unattended-upgrades==0.1 +urllib3 +uvicorn==0.22.0 +virtualenv==20.4.3 +wandb==0.15.11 +wavedrom==2.0.3.post3 +wcwidth==0.2.5 +websockets==11.0.3 +Werkzeug==2.3.6 +wrapt==1.12.1 +xxhash==3.2.0 +yarl==1.9.2 +zipp==3.15.0 diff --git a/scripts/eval_script/graphgpt_eval.sh b/scripts/eval_script/graphgpt_eval.sh new file mode 100644 index 0000000..412fcf6 --- /dev/null +++ b/scripts/eval_script/graphgpt_eval.sh @@ -0,0 +1,10 @@ +# to fill in the following path to extract projector for the second tuning stage! +output_model= +datapath= +graph_data_path= +res_path= +start_id= +end_id= +num_gpus= + +python3.8 ./graphgpt/eval/run_graphgpt.py --model-name ${output_model} --prompting_file ${datapath} --graph_data_path ${graph_data_path} --output_res_path ${res_path} --start_id ${start_id} --end_id ${end_id} --num_gpus ${num_gpus} \ No newline at end of file diff --git a/scripts/extract_graph_projector.py b/scripts/extract_graph_projector.py new file mode 100644 index 0000000..6060a2f --- /dev/null +++ b/scripts/extract_graph_projector.py @@ -0,0 +1,42 @@ +import os +import argparse +import torch +import json +from collections import defaultdict + + +def parse_args(): + parser = argparse.ArgumentParser(description='Extract MMProjector weights') + parser.add_argument('--model_name_or_path', type=str, help='model folder') + parser.add_argument('--output', type=str, help='output file') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + + keys_to_match = ['graph_projector', 'embed_tokens', 'transformer.wte'] + ckpt_to_key = defaultdict(list) + try: + model_indices = json.load(open(os.path.join(args.model_name_or_path, 'pytorch_model.bin.index.json'))) + for k, v in model_indices['weight_map'].items(): + if any(key_match in k for key_match in keys_to_match): + ckpt_to_key[v].append(k) + except FileNotFoundError: + # Smaller models or model checkpoints saved by DeepSpeed. + v = 'pytorch_model.bin' + for k in torch.load(os.path.join(args.model_name_or_path, v), map_location='cpu').keys(): + if any(key_match in k for key_match in keys_to_match): + ckpt_to_key[v].append(k) + + loaded_weights = {} + + for ckpt_name, weight_keys in ckpt_to_key.items(): + ckpt = torch.load(os.path.join(args.model_name_or_path, ckpt_name), map_location='cpu') + for k in weight_keys: + loaded_weights[k] = ckpt[k] + + print(loaded_weights.keys()) + + torch.save(loaded_weights, args.output) diff --git a/scripts/serving/controller.yaml b/scripts/serving/controller.yaml new file mode 100644 index 0000000..35c3a88 --- /dev/null +++ b/scripts/serving/controller.yaml @@ -0,0 +1,29 @@ +resources: + cloud: gcp + region: us-central1 + +num_nodes: 1 + +workdir: . + +file_mounts: + ~/chatlogs: + name: skypilot-chatbot-logs + store: gcs + mode: MOUNT + +setup: | + conda activate chatbot + if [ $? -eq 0 ]; then + echo 'conda env exists' + else + # Setup the environment + conda create -n chatbot python=3.10 -y + conda activate chatbot + pip3 install -e . + fi + +run: | + conda activate chatbot + python3 -m fastchat.serve.controller --host 0.0.0.0 --port 21001 & + python3 -m fastchat.serve.gradio_web_server --share diff --git a/scripts/serving/model_worker.yaml b/scripts/serving/model_worker.yaml new file mode 100644 index 0000000..d241e74 --- /dev/null +++ b/scripts/serving/model_worker.yaml @@ -0,0 +1,57 @@ +resources: + accelerators: A100:1 + cloud: gcp + region: us-central1 + +num_nodes: 1 + +workdir: . + +file_mounts: + /artifacts: + name: skypilot-chatbot + store: gcs + mode: MOUNT + + ~/chatlogs: + name: skypilot-chatbot-logs + store: gcs + mode: MOUNT + +setup: | + conda activate chatbot + if [ $? -eq 0 ]; then + echo 'conda env exists' + else + # Setup the environment + conda create -n chatbot python=3.10 -y + conda activate chatbot + + pip3 install -e . + + # Install pytorch + pip install torch==1.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 + + # Install huggingface with the LLaMA commit + pip install git+https://github.com/huggingface/transformers + + # Install alpaca + git clone https://github.com/tatsu-lab/stanford_alpaca.git + cd stanford_alpaca + pip install -r requirements.txt + cd - + fi + + ln -s /artifacts/chatbot/13b/ckpt/ ~/alpaca-13b + +run: | + conda activate chatbot + WORKER_IP=$(hostname -I | cut -d' ' -f1) + CONTROLLER_PORT=21001 + WORKER_PORT=21002 + python3 -m fastchat.serve.model_worker \ + --model ~/alpaca-13b \ + --controller-address http://${CONTROLLER_IP}:${CONTROLLER_PORT} \ + --worker-address http://${WORKER_IP}:${WORKER_PORT} \ + --host 0.0.0.0 \ + --port ${WORKER_PORT} diff --git a/scripts/tune_script/extract_projector.sh b/scripts/tune_script/extract_projector.sh new file mode 100644 index 0000000..6012a76 --- /dev/null +++ b/scripts/tune_script/extract_projector.sh @@ -0,0 +1,7 @@ +# to fill in the following path to extract projector for the second tuning stage! +src_model= +output_proj= + +python3.8 ./scripts/extract_graph_projector.py \ + --model_name_or_path ${src_model} \ + --output ${output_proj} \ No newline at end of file diff --git a/scripts/tune_script/graphgpt_stage1.sh b/scripts/tune_script/graphgpt_stage1.sh new file mode 100644 index 0000000..075f179 --- /dev/null +++ b/scripts/tune_script/graphgpt_stage1.sh @@ -0,0 +1,39 @@ +# to fill in the following path to run the first stage of our GraphGPT! +model_path= +instruct_ds= +graph_data_path= +pretra_gnn= +output_model= + +wandb offline +python -m torch.distributed.run --nnodes=1 --nproc_per_node=4 --master_port=20001 \ + graphgpt/train/train_mem.py \ + --model_name_or_path ${model_path} \ + --version v1 \ + --data_path ${instruct_ds} \ + --graph_content ./arxiv_ti_ab.json \ + --graph_data_path ${graph_data_path} \ + --graph_tower ${pretra_gnn} \ + --tune_graph_mlp_adapter True \ + --graph_select_layer -2 \ + --use_graph_start_end \ + --bf16 True \ + --output_dir ${output_model} \ + --num_train_epochs 3 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 2400 \ + --save_total_limit 1 \ + --learning_rate 2e-3 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True \ + --report_to wandb diff --git a/scripts/tune_script/graphgpt_stage2.sh b/scripts/tune_script/graphgpt_stage2.sh new file mode 100644 index 0000000..5bc3a9a --- /dev/null +++ b/scripts/tune_script/graphgpt_stage2.sh @@ -0,0 +1,42 @@ +# to fill in the following path to run the second stage of our GraphGPT! +model_path= +instruct_ds= +graph_data_path= +pretra_gnn= +tuned_proj= +output_model= + +wandb offline +python -m torch.distributed.run --nnodes=1 --nproc_per_node=4 --master_port=20001 \ + graphgpt/train/train_mem.py \ + --model_name_or_path ${model_path} \ + --version v1 \ + --data_path ${instruct_ds} \ + --graph_content ./arxiv_ti_ab.json \ + --graph_data_path ${graph_data_path} \ + --graph_tower ${pretra_gnn} \ + --pretrain_graph_mlp_adapter ${tuned_proj} \ + --tune_graph_mlp_adapter True \ + --graph_select_layer -2 \ + --use_graph_start_end True\ + --bf16 True \ + --output_dir ${output_model} \ + --num_train_epochs 2 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb diff --git a/tests/test_openai_curl.sh b/tests/test_openai_curl.sh new file mode 100644 index 0000000..48766f3 --- /dev/null +++ b/tests/test_openai_curl.sh @@ -0,0 +1,32 @@ +set -x + +curl http://localhost:8000/v1/models + +echo + +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.1", + "messages": [{"role": "user", "content": "Hello! What is your name?"}] + }' + +echo + +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.1", + "prompt": "Once upon a time", + "max_tokens": 41, + "temperature": 0.5 + }' + +echo + +curl http://localhost:8000/v1/embeddings \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.1", + "input": "Hello world!" + }' diff --git a/tests/test_openai_langchain.py b/tests/test_openai_langchain.py new file mode 100644 index 0000000..78b7f9f --- /dev/null +++ b/tests/test_openai_langchain.py @@ -0,0 +1,38 @@ +# export OPENAI_API_BASE=http://localhost:8000/v1 +# export OPENAI_API_KEY=EMPTY + +from langchain import OpenAI, LLMChain, PromptTemplate +from langchain.memory import ConversationBufferWindowMemory +from langchain.embeddings import OpenAIEmbeddings +import numpy as np + +template = """{history} +Human: {human_input} +Assistant:""" + +def test_embedding(): + embeddings = OpenAIEmbeddings() + texts = ["Why does the chicken cross the road", "To be honest", "Long time ago"] + query_result = embeddings.embed_query(texts[0]) + doc_result = embeddings.embed_documents(texts) + assert np.allclose(query_result, doc_result[0], atol=1e-3) + +def test_chain(): + + prompt = PromptTemplate( + input_variables=["history", "human_input"], + template=template + ) + chain = LLMChain( + llm=OpenAI(model="text-embedding-ada-002", temperature=1), + prompt=prompt, + verbose=True, + memory=ConversationBufferWindowMemory(k=2), + ) + output = chain.predict(human_input="ls ~") + print(output) + +if __name__ == "__main__": + test_embedding() + test_chain() + diff --git a/tests/test_openai_sdk.py b/tests/test_openai_sdk.py new file mode 100644 index 0000000..197838b --- /dev/null +++ b/tests/test_openai_sdk.py @@ -0,0 +1,46 @@ +import openai + +openai.api_key = "EMPTY" # Not support yet +openai.api_base = "http://localhost:8000/v1" + +model = "vicuna-7b-v1.1" + + +def test_list_models(): + model_list = openai.Model.list() + print(model_list["data"][0]["id"]) + + +def test_completion(): + prompt = "Once upon a time" + completion = openai.Completion.create(model=model, prompt=prompt, max_tokens=64) + print(prompt + completion.choices[0].text) + + +def test_embedding(): + embedding = openai.Embedding.create(model=model, input="Hello world!") + print(len(embedding["data"][0]["embedding"])) + + +def test_chat_completion(): + completion = openai.ChatCompletion.create( + model=model, messages=[{"role": "user", "content": "Hello! What is your name?"}] + ) + print(completion.choices[0].message.content) + + +def test_chat_completion_stream(): + messages = [{"role": "user", "content": "Hello! What is your name?"}] + res = openai.ChatCompletion.create(model=model, messages=messages, stream=True) + for chunk in res: + content = chunk["choices"][0]["delta"].get("content", "") + print(content, end="", flush=True) + print() + + +if __name__ == "__main__": + test_list_models() + test_completion() + test_embedding() + test_chat_completion() + test_chat_completion_stream()