diff --git a/README.md b/README.md index c43596475..83f43532b 100644 --- a/README.md +++ b/README.md @@ -206,17 +206,16 @@ architectures as well as important features implemented in TensorRT-LLM. ### Devices -TensorRT-LLM is rigorously tested on the following GPUs: +TensorRT-LLM supports the following architectures: -* [H100](https://www.nvidia.com/en-us/data-center/h100/) -* [L40S](https://www.nvidia.com/en-us/data-center/l40s/) -* [A100](https://www.nvidia.com/en-us/data-center/a100/) -* [A30](https://www.nvidia.com/en-us/data-center/products/a30-gpu/) -* [V100](https://www.nvidia.com/en-us/data-center/v100/) (experimental) +* [NVIDIA Hopper](https://www.nvidia.com/en-us/data-center/technologies/hopper-architecture/) (SM90), for example, H200, H100, H20 +* [NVIDIA Ada Lovelace](https://www.nvidia.com/en-us/geforce/ada-lovelace-architecture/) (SM89), for example, L40S, L20, L4 +* [NVIDIA Ampere](https://www.nvidia.com/en-us/data-center/ampere-architecture/) (SM80, SM86), for example, A100, A30, A10G +* [NVIDIA Turing](https://www.nvidia.com/en-us/geforce/turing/) (SM75), for example, T4 +* [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) (SM70 - experimental), for example, V100 -If a GPU is not listed above, it is important to note that TensorRT-LLM is -expected to work on GPUs based on the Volta, Turing, Ampere, Hopper and Ada -Lovelace architectures. Certain limitations may, however, apply. + +It is important to note that TensorRT-LLM is expected to work on all GPUs based on the Volta, Turing, Ampere, Hopper, and Ada Lovelace architectures. Certain limitations may apply. ### Precision @@ -273,7 +272,7 @@ The list of supported models is: * [Blip2](examples/blip2) * [BLOOM](examples/bloom) * [ChatGLM](examples/chatglm) -* [FairSeq NMT](examples/nmt) +* [FairSeq NMT](examples/enc_dec/nmt) * [Falcon](examples/falcon) * [Flan-T5](examples/enc_dec) * [GPT](examples/gpt) @@ -406,7 +405,7 @@ As a rule of thumb, if you are running TensorRT-LLM interactively on a Slurm node, prefix your commands with `mpirun -n 1` to run TensorRT-LLM in a dedicated MPI environment, not the one provided by your Slurm allocation. -For example: `mpirun -n 1 python3 examples/gpt/build.py ...` +For example: `mpirun -n 1 python3 examples/run.py ...` ## Release notes diff --git a/benchmarks/cpp/README.md b/benchmarks/cpp/README.md index e3e6a200f..0e89c6daf 100644 --- a/benchmarks/cpp/README.md +++ b/benchmarks/cpp/README.md @@ -22,7 +22,7 @@ instead, and be sure to set DLL paths as specified in Before you launch C++ benchmarking, please make sure that you have already built engine(s) using TensorRT-LLM API, C++ benchmarking code cannot generate engine(s) for you. -You can use the [`build.py`](source:benchmarks/python/build.py) script to build the engine(s). Alternatively, if you have already benchmarked Python Runtime, you can reuse the engine(s) built previously, please see that [`document`](../python/README.md). +Use `trtllm-build` to build the TRT-LLM engine. Alternatively, if you have already benchmarked Python Runtime, you can reuse the engine(s) built previously, please see that [`document`](../python/README.md). #### Launch benchmarking @@ -73,19 +73,39 @@ This tool can be used in 2 different modes of traffic generation. ##### 1 – Dataset -“Prompt”, “Instruction” (optional) and “Answer” specified as sentences in a Json file - The tool will tokenize the words and instruct the model to generate a specified number of output tokens for a request. 
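+The prepared file is the workload JSON that `gptManagerBenchmark` consumes. As a rough sketch based on the `Workload`/`Sample` schema in `utils/utils.py` (token ids and values here are purely illustrative), a prepared dataset looks like:
+```
+{
+  "metadata": {"workload_type": "dataset", "delay_distr": "exponential_dist", "request_rate": 10, ...},
+  "samples": [
+    {"input_ids": [101, 7592, 2088], "output_len": 100, "delay": 0.05, "task_id": -1},
+    ...
+  ]
+}
+```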
``` python3 prepare_dataset.py \ + --tokenizer \ --output preprocessed_dataset.json - --request-rate 10 \ - --time-delay-dist exponential_dist \ + [--request-rate 10] \ + [--time-delay-dist exponential_dist] \ + dataset + --dataset-name \ + --dataset-input-key \ + --dataset-prompt-key \ + --dataset-output-key \ + [--num-requests 100] \ + [--max-input-len 1000] \ + [--output-len-dist 100,10] +``` + +For datasets that don't have prompt key, set --dataset-prompt instead. +Take [cnn_dailymail dataset](https://huggingface.co/datasets/cnn_dailymail) for example: +``` +python3 prepare_dataset.py \ --tokenizer \ + --output cnn_dailymail.json dataset - --dataset \ - --max-input-len 300 + --dataset-name cnn_dailymail \ + --dataset-config-name 3.0.0 \ + --dataset-input-key article \ + --dataset-prompt "Summarize the following article:" \ + --dataset-output-key "highlights" \ + [--num-requests 100] \ + [--max-input-len 1000] \ + [--output-len-dist 100,10] ``` ##### 2 – Normal token length distribution @@ -94,7 +114,7 @@ This mode allows the user to generate normal token length distributions with a m For example, setting mean=100 and std dev=10 would generate requests where 95.4% of values are in <80,120> range following the normal probability distribution. Setting std dev=0 will generate all requests with the same mean number of tokens. ``` - python prepare_dataset.py \ +python prepare_dataset.py \ --output token-norm-dist.json \ --request-rate 10 \ --time-delay-dist constant \ diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index 4de58c95f..5b511ea94 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -23,9 +23,11 @@ #include "tensorrt_llm/common/mpiUtils.h" #include "tensorrt_llm/common/stringUtils.h" #include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/tensor.h" #include "tensorrt_llm/plugins/api/tllmPlugin.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/tllmLogger.h" +#include "tensorrt_llm/runtime/utils/numpyUtils.h" #include "tensorrt_llm/runtime/worldConfig.h" #include @@ -48,6 +50,89 @@ namespace trt = nvinfer1; namespace { +using TensorPtr = ITensor::SharedPtr; + +class LoraLib +{ +public: + LoraLib(std::string const& loraDir) + : mLoraDir(loraDir) + , mBufferManager(std::make_shared()) + , mTaskPaths(parseDirPaths(mLoraDir)) + , mLoras(readLoras(mTaskPaths)) + { + } + + TensorPtr getLoraWeights(uint64_t taskId) const + { + return mLoras.at(taskId).first; + } + + TensorPtr getLoraConfig(uint64_t taskId) const + { + return mLoras.at(taskId).second; + } + + void clear() + { + mLoras.clear(); + } + + std::map> const& getLoras() + { + return mLoras; + } + +private: + std::string const mLoraDir; + BufferManager mBufferManager; + std::map mTaskPaths; + std::map> mLoras; + + std::map> readLoras(std::map taskPaths) + { + std::map> loras; + for (auto const& [id, p] : taskPaths) + { + TensorPtr loraWeights = utils::loadNpy(mBufferManager, p / "model.lora_weights.npy", MemoryType::kCPU); + TensorPtr loraConfig = utils::loadNpy(mBufferManager, p / "model.lora_config.npy", MemoryType::kCPU); + loras.insert_or_assign(id, std::make_pair(loraWeights, loraConfig)); + } + return loras; + } + + std::map parseDirPaths(std::string const& loraDir) + { + std::map taskPaths; + if (loraDir == "") + { + return taskPaths; + } + for (auto const& entry : fs::recursive_directory_iterator(loraDir)) + { + if (entry.is_directory()) + { + auto taskId = parseId(entry.path()); + 
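+                // Each LoRA lives in its own subdirectory; record taskId -> path so that
+                // readLoras() can load model.lora_weights.npy / model.lora_config.npy for that task.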
taskPaths.insert_or_assign(taskId, entry.path()); + } + } + return taskPaths; + } + + uint64_t parseId(fs::path p) + { + auto fn = p.filename().string(); + auto dashPos = fn.find_first_of("-"); + std::string idStr = fn; + if (dashPos != std::string::npos) + { + auto idStr = fn.substr(0, dashPos); + } + uint64_t id = static_cast(std::stoi(idStr)); + return id; + } +}; + struct BenchmarkParams { std::optional maxTokensInPagedKvCache = std::nullopt; @@ -56,6 +141,11 @@ struct BenchmarkParams bool enableBlockReuse = false; bool enableChunkedContext = false; bool streaming = false; + + // lora / peft params + std::optional loraDir = std::nullopt; + SizeType loraDeviceNumModLayers = 0; + size_t loraHostCacheSize = 1024 * 2024 * 1024; }; } // namespace @@ -99,13 +189,13 @@ class WorkItemsQueue } // Note: this function only be called under a lock - bool hasInProgressReqId(const uint64_t reqId) const + bool hasInProgressReqId(uint64_t const reqId) const { return (mInProgressWorkItems.find(reqId) != mInProgressWorkItems.end()); } // Note: this function only be called under a lock - bool hasPendingReqId(const uint64_t reqId) const + bool hasPendingReqId(uint64_t const reqId) const { return (mPendingWorkItemsReqIds.find(reqId) != mPendingWorkItemsReqIds.end()); } @@ -168,7 +258,7 @@ class WorkItemsQueue /// @brief Mark a request as being finished /// @param requestId - void markFinished(const uint64_t requestId) + void markFinished(uint64_t const requestId) { std::lock_guard lock(mMutex); if (hasInProgressReqId(requestId)) @@ -328,15 +418,33 @@ class ExecutorServer , mWaitSleep(waitSleep) , mStaticEmulatedBatchSize(staticEmulatedBatchSize) , mActiveCount(0) + , mShutdown(false) { texec::SchedulerConfig schedulerConfig(batch_scheduler::batchManagerToExecSchedPolicy(schedulerPolicy)); texec::KvCacheConfig kvCacheConfig(benchmarkParams.enableBlockReuse, benchmarkParams.maxTokensInPagedKvCache, - std::nullopt, std::nullopt, benchmarkParams.freeGpuMemoryFraction, false); - texec::ExecutorConfig executorConfig(maxBeamWidth, schedulerConfig, kvCacheConfig, - benchmarkParams.enableChunkedContext, true, benchmarkParams.enableTrtOverlap); + std::nullopt, std::nullopt, benchmarkParams.freeGpuMemoryFraction); + texec::PeftCacheConfig peftCacheConfig(0, benchmarkParams.loraDeviceNumModLayers, 8, 64, 4, 4, 4, 24, 8, + std::nullopt, benchmarkParams.loraHostCacheSize); + texec::ExecutorConfig executorConfig( + maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true); + executorConfig.setPeftCacheConfig(peftCacheConfig); mExecutor = std::make_shared(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig); + + if (logIterationData) + { + mCollectStatsThread = std::thread(&ExecutorServer::collectStats, this); + } + } + + ~ExecutorServer() + { + mShutdown = true; + if (mCollectStatsThread.joinable()) + { + mCollectStatsThread.join(); + } } void enqueue(std::vector requests, bool warmup = false) @@ -367,10 +475,10 @@ class ExecutorServer } } - void waitForResponses(std::optional numRequests, bool warmup = false) + void waitForResponses(SizeType numRequests, bool warmup = false) { SizeType numFinished = 0; - while (mActiveCount || (numRequests && numFinished < numRequests.value())) + while (mActiveCount || (numFinished < numRequests)) { auto responses = mExecutor->awaitResponses(std::nullopt, mWaitSleep); for (auto const& response : responses) @@ -396,17 +504,28 @@ class ExecutorServer } } - void shutdown() + void collectStats() { - mExecutor->shutdown(); + while (!mShutdown) + { + 
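+            // Poll the executor for per-iteration stats and log them as JSON until shutdown is requested.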
auto iterStats = mExecutor->getLatestIterationStats(); + for (auto const& iterStat : iterStats) + { + TLLM_LOG_INFO(texec::JsonSerialization::toJsonStr(iterStat)); + } + auto const waitSleep = std::chrono::milliseconds(50); + std::this_thread::sleep_for(waitSleep); + } } private: std::shared_ptr mExecutor; + std::thread mCollectStatsThread; std::shared_ptr mRecorder; std::chrono::milliseconds mWaitSleep; std::optional mStaticEmulatedBatchSize; std::atomic mActiveCount; + std::atomic mShutdown; }; // class ExecutorServer class GptServer @@ -641,6 +760,7 @@ struct Sample std::vector inputIds; int32_t outputLen; float delay; + int32_t taskId; }; using Samples = std::vector; @@ -659,7 +779,8 @@ Samples parseWorkloadJson(std::filesystem::path const& datasetPath, int maxNumSa { if (samples.size() >= maxNumSamples) break; - samples.emplace_back(Sample{sample["input_ids"], sample["output_len"], sample["delay"]}); + int32_t taskId = sample.count("task_id") ? sample["task_id"].template get() : -1; + samples.emplace_back(Sample{sample["input_ids"], sample["output_len"], sample["delay"], taskId}); } return samples; } @@ -667,7 +788,8 @@ Samples parseWorkloadJson(std::filesystem::path const& datasetPath, int maxNumSa std::shared_ptr makeRequest(std::uint64_t reqId, Sample const& sample, ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId, BufferManager const& bufferManager, ITensor::SharedPtr const& returnContextLogits = nullptr, - ITensor::SharedPtr const& returnGenerationLogits = nullptr) + ITensor::SharedPtr const& returnGenerationLogits = nullptr, ITensor::SharedPtr const& loraWeights = nullptr, + ITensor::SharedPtr const& loraConfig = nullptr) { auto request = std::make_shared(reqId); auto const& inputIds = sample.inputIds; @@ -692,16 +814,36 @@ std::shared_ptr makeRequest(std::uint64_t reqId, Sample const& { request->setReturnGenerationLogits(returnGenerationLogits); } + if (sample.taskId >= 0) + { + uint64_t taskId = static_cast(sample.taskId); + request->setLoraTaskId(bufferManager.copyFrom(&taskId, ITensor::makeShape({1}), MemoryType::kPINNED)); + } + if (loraWeights) + { + request->setLoraWeights(loraWeights); + } + if (loraConfig) + { + request->setLoraConfig(loraConfig); + } return request; } texec::Request makeExecutorRequest(Sample const& sample, SizeType const& beamWidth, std::optional const& eosId, std::optional const& padId, bool streaming = false, - bool const& returnContextLogits = false, bool const& returnGenerationLogits = false) + bool const& returnContextLogits = false, bool const& returnGenerationLogits = false, + std::optional const& loraConfig = std::nullopt) { auto samplingConfig = texec::SamplingConfig{beamWidth}; auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false}; - return {sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId}; + return texec::Request(sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId, + std::nullopt, // badWords + std::nullopt, // stopWords + std::nullopt, // embeddingBias + std::nullopt, // speculativeDecoding + std::nullopt, // pTuning + loraConfig); } void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType, @@ -727,6 +869,11 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType optionalParams.kvCacheConfig.enableBlockReuse = benchmarkParams.enableBlockReuse; optionalParams.enableChunkedContext = 
benchmarkParams.enableChunkedContext; optionalParams.enableTrtOverlap = benchmarkParams.enableTrtOverlap; + optionalParams.peftCacheManagerConfig.hostCacheSize = benchmarkParams.loraHostCacheSize; + optionalParams.peftCacheManagerConfig.numDeviceModuleLayer = benchmarkParams.loraDeviceNumModLayers; + optionalParams.peftCacheManagerConfig.numPutWorkers = 4; + optionalParams.peftCacheManagerConfig.numEnsureWorkers = 4; + optionalParams.peftCacheManagerConfig.numCopyStreams = 4; BufferManager bufferManager{std::make_shared()}; // the stream is not used @@ -758,6 +905,29 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType if (worldConfig.getRank() == 0) { + if (benchmarkParams.loraDir) + { + auto startLoraLoad = std::chrono::steady_clock::now(); + LoraLib loras(benchmarkParams.loraDir.value()); + SizeType reqId = 0; + for (auto const& [taskId, p] : loras.getLoras()) + { + reqId++; + if (reqId == terminateReqId) + { + reqId++; + } + Sample s{std::vector{1, 2, 3, 4, 5}, 1, 0.f, static_cast(taskId)}; + auto r = makeRequest(reqId, s, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager, nullptr, + nullptr, p.first, p.second); + gptServer->enqueue(r); + } + gptServer->waitForEmpty(); + auto endLoraLoad = std::chrono::steady_clock::now(); + printf("[BENCHMARK] time to preload LoRAs(ms) %.2f\n", + std::chrono::duration(endLoraLoad - startLoraLoad).count()); + } + // Warm up gptServer->resetBatchDeadline(); SizeType reqId = 0; @@ -820,6 +990,24 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m if (worldRank == 0) { + if (benchmarkParams.loraDir) + { + auto startLoraLoad = std::chrono::steady_clock::now(); + LoraLib loras(benchmarkParams.loraDir.value()); + std::vector requests; + for (auto& [taskId, p] : loras.getLoras()) + { + texec::LoraConfig loraConfig( + taskId, texec::detail::ofITensor(p.first), texec::detail::ofITensor(p.second)); + Sample s{std::vector{1, 2, 3, 4, 5}, 1, 0.f, static_cast(taskId)}; + requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, loraConfig)); + } + executorServer->enqueue(std::move(requests), true); + executorServer->waitForResponses(loras.getLoras().size(), true); + auto endLoraLoad = std::chrono::steady_clock::now(); + printf("[BENCHMARK] time to preload LoRAs(ms) %.2f\n", + std::chrono::duration(endLoraLoad - startLoraLoad).count()); + } // Warm up { std::vector requests; @@ -840,8 +1028,13 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m std::vector delays; for (std::size_t i = 0; i < numSamples; ++i) { + std::optional loraConfig; + if (samples[i].taskId >= 0) + { + loraConfig = texec::LoraConfig(samples[i].taskId); + } requests.emplace_back(makeExecutorRequest(samples[i], beamWidth, eosId, padId, - benchmarkParams.streaming, returnContextLogits, returnGenerationLogits)); + benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig)); delays.push_back(static_cast(samples[i].delay * 1000)); } @@ -957,6 +1150,9 @@ int main(int argc, char* argv[]) cxxopts::value()->default_value("false")); options.add_options()("wait_sleep", "Specify how many milliseconds to sleep each iteration of waitForEmpty loop.", cxxopts::value()->default_value("25")); + options.add_options()("lora_dir", "Directory containing LoRAs", cxxopts::value()->default_value("")); + options.add_options()("lora_host_cache_bytes", "LoRA host cache memory in bytes", cxxopts::value()); + options.add_options()("lora_num_device_mod_layers", "LoRA 
number 1d cache rows", cxxopts::value()); auto result = options.parse(argc, argv); @@ -1039,6 +1235,19 @@ int main(int argc, char* argv[]) // Argument: Enable return context logits bool returnGenerationLogits = result["return_generation_logits"].as(); + if (result.count("lora_dir")) + { + benchmarkParams.loraDir = result["lora_dir"].as(); + } + if (result.count("lora_host_cache_bytes")) + { + benchmarkParams.loraHostCacheSize = result["lora_host_cache_bytes"].as(); + } + if (result.count("lora_num_device_mod_layers")) + { + benchmarkParams.loraDeviceNumModLayers = result["lora_num_device_mod_layers"].as(); + } + std::optional padId; // Argument: Padding token id if (result.count("pad_id")) diff --git a/benchmarks/cpp/prepare_dataset.py b/benchmarks/cpp/prepare_dataset.py index a47603fa4..10dd1a202 100644 --- a/benchmarks/cpp/prepare_dataset.py +++ b/benchmarks/cpp/prepare_dataset.py @@ -1,4 +1,19 @@ -from typing import Literal +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Literal, Optional, Tuple import click from pydantic import BaseModel, field_validator @@ -15,14 +30,16 @@ class RootArgs(BaseModel): mean_time_bet_reqs: float time_delay_dist: Literal["constant", "exponential_dist"] random_seed: int + task_id: int + rand_task_id: Optional[Tuple[int, int]] @field_validator('tokenizer') def get_tokenizer(cls, v: str): try: tokenizer = AutoTokenizer.from_pretrained(v, padding_side='left') - except EnvironmentError: + except EnvironmentError as e: raise ValueError( - "Cannot find a tokenizer from the given string. Please set tokenizer to the directory that contains the tokenizer, or set to a model name in HuggingFace." + f"Cannot find a tokenizer from the given string because of {e}\nPlease set tokenizer to the directory that contains the tokenizer, or set to a model name in HuggingFace." 
) tokenizer.pad_token = tokenizer.eos_token return tokenizer @@ -49,23 +66,40 @@ def get_tokenizer(cls, v: str): type=click.Choice(["constant", "exponential_dist"]), help="Distribution of the time delay.", default="exponential_dist") -@click.option( - "--random-seed", - required=False, - type=int, - help= - "random seed for exponential delays (dataset/norm-token-dist) and token_ids(norm-token-dist)", - default=420) +@click.option("--random-seed", + required=False, + type=int, + help="random seed for exponential delays and token_ids", + default=420) +@click.option("--task-id", type=int, default=-1, help="LoRA task id") +@click.option("--rand-task-id", + type=int, + default=None, + nargs=2, + help="Random LoRA Tasks") +@click.option("--log-level", + default="info", + type=click.Choice(['info', 'debug']), + help="Logging level.") @click.pass_context def cli(ctx, **kwargs): """This script generates dataset input for gptManagerBenchmark.""" + if kwargs['log_level'] == 'info': + logging.basicConfig(level=logging.INFO) + elif kwargs['log_level'] == 'debug': + logging.basicConfig(level=logging.DEBUG) + else: + raise ValueError(f"Unsupported logging level {kwargs['log_level']}") + ctx.obj = RootArgs(tokenizer=kwargs['tokenizer'], output=kwargs['output'], request_rate=kwargs['request_rate'], mean_time_bet_reqs=get_req_time_interval( kwargs['request_rate']), time_delay_dist=kwargs['time_delay_dist'], - random_seed=kwargs['random_seed']) + random_seed=kwargs['random_seed'], + task_id=kwargs['task_id'], + rand_task_id=kwargs['rand_task_id']) cli.add_command(dataset) diff --git a/examples/gpt/utils/__init__.py b/benchmarks/cpp/utils/convert_nemo_dataset.py similarity index 50% rename from examples/gpt/utils/__init__.py rename to benchmarks/cpp/utils/convert_nemo_dataset.py index 71bf6d298..6f4884347 100644 --- a/examples/gpt/utils/__init__.py +++ b/benchmarks/cpp/utils/convert_nemo_dataset.py @@ -12,3 +12,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +#!/usr/bin/env python3 + +import argparse +import json + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input") + parser.add_argument("output") + + args = parser.parse_args() + + output_o = [] + + with open(args.input, 'r') as infile: + for _l in infile: + l = _l.strip() + if len(l) == 0: + continue + o = json.loads(l) + output_o.append({ + "input": o["prompt"], + "instruction": "", + "output": o["completion"] + }) + + with open(args.output, 'w') as outfile: + json.dump(output_o, outfile) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/cpp/utils/generate_rand_loras.py b/benchmarks/cpp/utils/generate_rand_loras.py new file mode 100644 index 000000000..12eb1fdc3 --- /dev/null +++ b/benchmarks/cpp/utils/generate_rand_loras.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +#!/usr/bin/env python3 + +import argparse +import os +from pathlib import Path + +import numpy as np + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_lora") + parser.add_argument("output") + parser.add_argument("num_loras", type=int) + + args = parser.parse_args() + + lora_path = Path(args.input_lora) + weights_path = lora_path / "model.lora_weights.npy" + config_path = lora_path / "model.lora_config.npy" + + weights = np.load(weights_path) + config = np.load(config_path) + + for i in range(args.num_loras): + out_path = Path(args.output) / str(i) + os.makedirs(out_path, exist_ok=True) + w = np.random.normal(0, 2, weights.shape).astype(weights.dtype) + np.save(out_path / "model.lora_weights.npy", w) + np.save(out_path / "model.lora_config.npy", config) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/cpp/utils/prepare_real_data.py b/benchmarks/cpp/utils/prepare_real_data.py index 052e90506..2ee8bd744 100644 --- a/benchmarks/cpp/utils/prepare_real_data.py +++ b/benchmarks/cpp/utils/prepare_real_data.py @@ -1,76 +1,233 @@ -import json +import logging +import random +import re +from typing import Optional import click -from utils.utils import dataset_dump, get_list_of_delays +from datasets import load_dataset +from pydantic import BaseModel, model_validator +from utils.utils import dataset_dump, get_list_of_delays, get_norm_dist_tokens + + +def validate_output_len_dist(ctx, param, value): + """Validate the --output-len-dist option.""" + if value is None: + return value + m = re.match(r"(\d+),(\d+)", value) + if m: + return int(m.group(1)), int(m.group(2)) + else: + raise AssertionError( + "Incorrect specification for --output-len-dist. Correct format: --output-len-dist ," + ) + + +class DatasetConfig(BaseModel): + """Dataset configurations.""" + """Name of the dataset on HuggingFace.""" + name: str + """Config name of the dataset if existing.""" + config_name: Optional[str] = None + """Split of the dataset. Typical values: train, validation, test. Setting to None will include all splits.""" + split: Optional[str] + """The dataset dictionary used for the input sentence.""" + input_key: str + """The dataset dictionary key used for the prompt of the input sentence. Must not be set when prompt is set.""" + prompt_key: Optional[str] = None + """The prompt sentence to be added to the input sentence. Must not be set when prompt_key is set.""" + prompt: Optional[str] = None + """The dataset dictionary key used to derive the output sequence length. Set to None if the dataset does not have a key for output.""" + output_key: Optional[str] + + @model_validator(mode='after') + def check_prompt(self) -> 'DatasetConfig': + if self.prompt_key and self.prompt: + raise AssertionError( + "--prompt-key and --prompt cannot be set at the same time.") + if (not self.prompt_key) and (not self.prompt): + raise AssertionError("Either --prompt-key or --prompt must be set.") + return self + + @property + def query(self): + """Generate the query for HuggingFace `datasets.load_dataset()`""" + if self.config_name: + return [self.name, self.config_name] + else: + return [self.name] + + def get_prompt(self, req): + """Get the prompt sentence from the given request.""" + if self.prompt_key: + assert self.prompt_key in req, ( + f"Dataset {self.name} does not have key '{self.prompt_key}'. 
" + "Please set --prompt-key to one of the available keys: " + f"{req.keys()}") + return req[self.prompt_key] + else: + return self.prompt + + def get_input(self, req): + """Get the input sentence from the given request.""" + assert self.input_key in req, ( + f"Dataset {self.name} does not have key '{self.input_key}'. " + "Please set --input-key to one of the available keys: " + f"{req.keys()}") + return req[self.input_key] + + def get_output(self, req): + """Get the output sentence from the given request.""" + if self.output_key is None: + raise RuntimeError( + "--output-key is not set. Please either:\n" + "1. Define output length through --output-len-dist.\n" + f"2. If the dataset {self.name} has key for golden output and " + "you wish to set output length to the length of the golden " + "output, set --output-key.") + assert self.output_key in req, ( + f"Dataset {self.name} does not have key '{self.output_key}'. " + "Please set --output-key to one of the available keys: " + f"{req.keys()}") + return req[self.output_key] + + +def load_dataset_from_hf(dataset_config: DatasetConfig): + """Load dataset from HuggingFace. + + Args: + dataset_config: A `DatasetConfig` object that defines the dataset to load. + Returns: + Dataset iterator. + Raises: + ValueError: When dataset loading fails due to incorrect dataset config setting. + """ + try: + dataset = iter( + load_dataset(*dataset_config.query, + split=dataset_config.split, + streaming=True)) + except ValueError as e: + if "Config" in e: + e += "\n Please add the config name to the dataset config yaml." + elif "split" in e: + e += "\n Please specify supported split in the dataset config yaml." + raise ValueError(e) + + return dataset @click.command() -@click.option("--dataset", +@click.option("--dataset-name", + required=True, + type=str, + help=f"Dataset name in HuggingFace.") +@click.option("--dataset-config-name", + type=str, + default=None, + help=f"Dataset config name in HuggingFace (if exists).") +@click.option( + "--dataset-split", + type=str, + default=None, + help=f"Split of the dataset to use. Default will include all splits.") +@click.option("--dataset-input-key", required=True, type=str, - help='Dataset path used for the test.') + help=f"The dataset dictionary key for input.") +@click.option("--dataset-prompt-key", + type=str, + default=None, + help=f"The dataset dictionary key for prompt (if exists).") +@click.option( + "--dataset-prompt", + type=str, + default=None, + help=f"The prompt string when there is no prompt key for the dataset.") +@click.option("--dataset-output-key", + type=str, + default=None, + help=f"The dataset dictionary key for output (if exists).") @click.option( "--num-requests", type=int, default=None, help= - 'Number of requests to be generated. Default is dataset length. Will be capped to min(dataset, num_requests).' + "Number of requests to be generated. Will be capped to min(dataset.num_rows, num_requests)." +) +@click.option( + "--max-input-len", + type=int, + default=None, + help= + "Maximum input sequence length for a given request. This will be used to filter out the requests with long input sequence length. Default will include all the requests." ) @click.option( - "--op-tokens-per-word", - type=float, - default=1.3, + "--output-len-dist", + type=str, + default=None, + callback=validate_output_len_dist, help= - 'Specify op tokens/word ratio. Useful to have model generate exactly as many tokens as needed by the dataset.' + "Output length distribution. 
Default will be the length of the golden output from the dataset. Format: ,. E.g. 100,10 will randomize the output length with mean=100 and variance=10." ) -@click.option("--max-input-len", - type=int, - default=500000, - help='Specify max input length.') @click.pass_obj def dataset(root_args, **kwargs): """Prepare dataset from real dataset.""" - prompt_cnt = 0 + dataset_config = DatasetConfig( + **{k[8:]: v + for k, v in kwargs.items() if k.startswith('dataset_')}) + input_ids = [] output_lens = [] - ratio = [] - - with open(kwargs['dataset'], 'r') as f: - data_dict = json.load(f) - - if kwargs['num_requests'] is None: - kwargs['num_requests'] = len(data_dict) - else: - kwargs['num_requests'] = min(kwargs['num_requests'], len(data_dict)) - - for req in data_dict: - prompt = req['input'] + ' ' + req['instruction'] - output = req['output'] + task_ids = [] + req_cnt = 0 + for req in load_dataset_from_hf(dataset_config): + # input + prompt = dataset_config.get_prompt( + req) + ' ' + dataset_config.get_input(req) + logging.debug(f"Input sequence: {prompt}") line = root_args.tokenizer.encode(prompt) - if len(line) > kwargs['max_input_len']: + if kwargs['max_input_len'] and len(line) > kwargs['max_input_len']: continue + input_ids.append(line) - prompt_cnt += 1 - if prompt_cnt > kwargs['num_requests']: + # output if fetch from golden + if kwargs['output_len_dist'] is None: + output_lens.append( + len(root_args.tokenizer.encode(dataset_config.get_output(req)))) + + # lora task id + task_id = root_args.task_id + if root_args.rand_task_id is not None: + min_id, max_id = root_args.rand_task_id + task_id = random.randint(min_id, max_id) + task_ids.append(task_id) + + req_cnt += 1 + if kwargs['num_requests'] and req_cnt >= kwargs['num_requests']: break - input_ids.append(line) - output_lens.append( - int(len(output.split(' ')) * kwargs['op_tokens_per_word'])) + if kwargs['num_requests'] and len(input_ids) < kwargs['num_requests']: + logging.warning( + "Number of requests is smaller than the num-requests user set.") + + # output if randomized + if kwargs['output_len_dist'] is not None: + osl_mean, osl_stdev = kwargs['output_len_dist'] + output_lens = get_norm_dist_tokens(osl_mean, osl_stdev, len(input_ids), + root_args.random_seed) - prompt_tokens = len(line) - prompt_words = len(prompt.split()) - ratio.append(prompt_tokens / prompt_words) + logging.debug(f"Input lengths: {[len(i) for i in input_ids]}") + logging.debug(f"Output lengths: {output_lens}") delays = get_list_of_delays(root_args.time_delay_dist, root_args.mean_time_bet_reqs, len(input_ids), root_args.random_seed) dataset_dump( - input_ids, output_lens, delays, { + input_ids, output_lens, delays, task_ids, { "workload_type": "dataset", "tokenizer": root_args.tokenizer.__class__.__name__, - "num_requests": kwargs['num_requests'], + "num_requests": len(input_ids), "delay_distr": root_args.time_delay_dist, "request_rate": root_args.request_rate }, root_args.output) diff --git a/benchmarks/cpp/utils/prepare_synthetic_data.py b/benchmarks/cpp/utils/prepare_synthetic_data.py index f671f06ca..38c7794a9 100644 --- a/benchmarks/cpp/utils/prepare_synthetic_data.py +++ b/benchmarks/cpp/utils/prepare_synthetic_data.py @@ -1,3 +1,5 @@ +import random + import click from utils.utils import (dataset_dump, gen_random_tokens, get_list_of_delays, get_norm_dist_tokens) @@ -30,6 +32,7 @@ def token_norm_dist(root_args, **kwargs): input_ids = [] input_lens = [] output_lens = [] + task_ids = [] input_lens = get_norm_dist_tokens(kwargs['input_mean'], 
kwargs['input_stdev'], @@ -47,8 +50,14 @@ def token_norm_dist(root_args, **kwargs): input_ids = gen_random_tokens(input_lens, root_args.tokenizer, root_args.random_seed) + if root_args.rand_task_id is None: + task_ids = [root_args.task_id for _ in range(num_reqs)] + else: + min_id, max_id = root_args.rand_task_id + task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)] + dataset_dump( - input_ids, output_lens, delays, { + input_ids, output_lens, delays, task_ids, { "workload_type": "token-norm-dist", "input_mean": kwargs['input_mean'], "input_stdev": kwargs['input_stdev'], diff --git a/benchmarks/cpp/utils/utils.py b/benchmarks/cpp/utils/utils.py index 7a79cf3cf..b722c63ea 100644 --- a/benchmarks/cpp/utils/utils.py +++ b/benchmarks/cpp/utils/utils.py @@ -11,6 +11,7 @@ class Sample(BaseModel): input_ids: List[int] output_len: int delay: float + task_id: int class Workload(BaseModel): @@ -31,13 +32,15 @@ def setup_workload_name(self): self.metadata.setdefault('workload_name', workload_name) -def dataset_dump(input_ids, output_lens, delays, metadata, output_file): +def dataset_dump(input_ids, output_lens, delays, task_ids, metadata, + output_file): samples = [] for i in range(len(input_ids)): samples.append( Sample(input_ids=input_ids[i], output_len=output_lens[i], - delay=delays[i])) + delay=delays[i], + task_id=task_ids[i])) workload = Workload(metadata=metadata, samples=samples) with open(output_file, 'w') as f: json.dump(workload.dict(), f) @@ -79,10 +82,15 @@ def gen_random_tokens(ip_lens, tokenizer, random_seed): input_ids = [] random.seed(random_seed) for ip_len in ip_lens: - start_ids = random.sample(range(0, tokenizer.vocab_size), ip_len) + start_ids = [ + random.randint(0, tokenizer.vocab_size - 1) for _ in range(ip_len) + ] # Make sure it does not contain EOS token while set(tokenizer.encode(tokenizer.eos_token)).issubset(start_ids): - start_ids = random.sample(range(0, tokenizer.vocab_size), ip_len) + start_ids = [ + random.randint(0, tokenizer.vocab_size - 1) + for _ in range(ip_len) + ] input_ids.append(start_ids) return input_ids diff --git a/benchmarks/python/build.py b/benchmarks/python/build.py index 1aaa319a9..9cbea6045 100644 --- a/benchmarks/python/build.py +++ b/benchmarks/python/build.py @@ -31,7 +31,6 @@ from tensorrt_llm._utils import str_dtype_to_trt from tensorrt_llm.builder import Builder from tensorrt_llm.functional import LayerNormPositionType, LayerNormType -from tensorrt_llm.layers import MoeConfig, PositionEmbeddingType from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models import PretrainedConfig, quantize_model @@ -312,26 +311,44 @@ def build_gpt(args): # Initialize Module family = get_model_family(args.model) if family == "gpt": - tensorrt_llm_model = tensorrt_llm.models.GPTLMHeadModel( - num_layers=build_config['num_layers'], - num_heads=build_config['num_heads'], - hidden_size=build_config['hidden_size'], - vocab_size=build_config['vocab_size'], - hidden_act=build_config['hidden_act'], - max_position_embeddings=build_config['n_positions'], - dtype=kv_dtype, - mapping=tensorrt_llm.Mapping(world_size=world_size, - tp_size=world_size), # TP only - apply_query_key_layer_scaling=builder_config. 
- apply_query_key_layer_scaling, - position_embedding_type=PositionEmbeddingType.learned_absolute - if build_config['position_embedding_type'] is None else - PositionEmbeddingType[build_config['position_embedding_type']], - rotary_embedding_percentage=build_config['rotary_pct'], - quant_mode=quant_mode, - bias=build_config['bias'], - moe_config=MoeConfig(build_config["moe_num_experts"], - build_config["moe_top_k"])) + if build_config['num_kv_heads'] is None: + build_config['num_kv_heads'] = build_config['num_heads'] + if build_config['inter_size'] is None: + build_config['inter_size'] = build_config['hidden_size'] * 4 + if build_config['position_embedding_type'] is None: + build_config['position_embedding_type'] = 'learned_absolute' + quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization) + config = { + 'architecture': 'GPTForCausalLM', + 'dtype': args.dtype, + 'num_hidden_layers': build_config['num_layers'], + 'num_attention_heads': build_config['num_heads'], + 'num_key_value_heads': build_config['num_kv_heads'], + 'hidden_size': build_config['hidden_size'], + 'intermediate_size': build_config['inter_size'], + 'norm_epsilon': 1e-05, + 'vocab_size': build_config['vocab_size'], + 'position_embedding_type': build_config['position_embedding_type'], + 'max_position_embeddings': build_config['n_positions'], + 'hidden_act': build_config['hidden_act'], + 'quantization': { + 'quant_algo': quant_algo, + 'kv_cache_quant_algo': kv_cache_quant_algo, + 'group_size': 128, + }, + 'mapping': { + 'world_size': world_size, + 'tp_size': world_size, + }, + 'bias': build_config['bias'], + 'apply_query_key_layer_scaling': + builder_config.apply_query_key_layer_scaling, + 'rotary_pct': build_config['rotary_pct'], + 'moe_num_experts': build_config["moe_num_experts"], + 'moe_top_k': build_config["moe_top_k"], + } + config = PretrainedConfig.from_dict(config) + tensorrt_llm_model = tensorrt_llm.models.GPTForCausalLM(config) elif family == "opt": config = { 'architecture': 'OPTForCausalLM', @@ -374,6 +391,8 @@ def build_gpt(args): else build_config['num_kv_heads'], 'hidden_size': build_config['hidden_size'], + 'intermediate_size': + build_config['inter_size'], 'vocab_size': build_config['vocab_size'], 'position_embedding_type': @@ -731,10 +750,14 @@ def build_gpt(args): 'num_hidden_layers': build_config['num_layers'], 'num_attention_heads': build_config['num_heads'], 'hidden_act': build_config['hidden_act'], - "ssm_cfg": {}, - "rms_norm": True, - "residual_in_fp32": True, - "pad_vocab_size_multiple": 8, + 'ssm_cfg': { + 'd_state': build_config['mamba_d_state'], + 'd_conv': build_config['mamba_d_conv'], + 'expand': build_config['mamba_expand'] + }, + 'rms_norm': True, + 'residual_in_fp32': True, + 'pad_vocab_size_multiple': 8, } config = PretrainedConfig.from_dict(config) tensorrt_llm_model = tensorrt_llm.models.MambaLMHeadModel(config) @@ -742,7 +765,10 @@ def build_gpt(args): raise Exception(f'Unexpected model: {args.model}') quant_kwargs = {} - if family not in ['opt', 'bloom', 'falcon', 'llama', 'gptj', 'internlm']: + if family not in [ + 'gpt', 'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox', + 'gptj', 'mamba', 'baichuan', 'chatglm', 'chatglm2', 'chatglm3' + ]: tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode, **quant_kwargs) @@ -768,11 +794,7 @@ def build_gpt(args): # Quantization plugins. 
if use_smooth_quant: - network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype) - network.plugin_config.set_layernorm_quantization_plugin( - dtype=args.dtype) - network.plugin_config.set_quantize_tensor_plugin() - network.plugin_config.set_quantize_per_token_plugin() + network.plugin_config.set_smooth_quant_plugins(dtype=args.dtype) elif use_weight_only: network.plugin_config.set_weight_only_quant_matmul_plugin( dtype=args.dtype) @@ -805,7 +827,7 @@ def build_gpt(args): use_cache=True, max_beam_width=max_beam_width) if family in [ - 'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox', + 'gpt', 'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox', 'gptj', 'mamba', 'baichuan', 'chatglm', 'chatglm2', 'chatglm3' ]: tensorrt_llm_model(**inputs) @@ -874,7 +896,8 @@ def build_bert(args): max_position_embeddings=build_config['n_positions'], max_batch_size=max_batch_size, max_input_len=max_input_len, - opt_level=build_config['builder_opt']) + opt_level=build_config['builder_opt'], + strongly_typed=args.strongly_typed) engine_name = get_engine_name(args.model, args.dtype, world_size, runtime_rank) diff --git a/benchmarks/python/mem_monitor.py b/benchmarks/python/mem_monitor.py index 5b654310f..f60ce8f0a 100644 --- a/benchmarks/python/mem_monitor.py +++ b/benchmarks/python/mem_monitor.py @@ -52,14 +52,13 @@ def kill(self): def stop(self): self.signal_event.set() logger.debug("Sent signal to stop memory monitor subprocess.") - - peak_mem_use = self.peak_mem_queue.get(timeout=10) + peak_mem_use = self.peak_mem_queue.get(timeout=20) self._peak_host_memory = max(self._peak_host_memory, peak_mem_use[0]) self._peak_device_memory = max(self._peak_device_memory, peak_mem_use[1]) - self.mem_monitor_process.join(timeout=10) + self.mem_monitor_process.join(timeout=20) self.mem_monitor_process = None logger.debug("Memory monitor subprocess joined.") diff --git a/cpp/include/tensorrt_llm/batch_manager/GptManager.h b/cpp/include/tensorrt_llm/batch_manager/GptManager.h index 292c66b82..f3c413ac5 100644 --- a/cpp/include/tensorrt_llm/batch_manager/GptManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/GptManager.h @@ -21,6 +21,8 @@ #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/schedulerPolicy.h" #include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h" +#include "tensorrt_llm/runtime/gptModelConfig.h" +#include "tensorrt_llm/runtime/worldConfig.h" #include #include @@ -86,7 +88,8 @@ class GptManager [[nodiscard]] SizeType getMaxSequenceLen() const; [[nodiscard]] SizeType getMaxNumSequences() const; - void validateLlmRequest(LlmRequest& newReq) const; + void validateLlmRequest( + LlmRequest& newReq, runtime::GptModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const; static std::shared_ptr fillLlmRequest(std::shared_ptr newReq); static std::shared_ptr> getReqInputTokens(std::shared_ptr newReq); static SizeType getMaxNewTokens(std::shared_ptr newReq); diff --git a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h index d94d8c1f7..e1f56497c 100644 --- a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h @@ -59,6 +59,7 @@ auto constexpr kReturnContextLogitsTensorName = "return_context_logits"; auto constexpr kReturnGenerationLogitsTensorName = "return_generation_logits"; auto constexpr kPromptEmbeddingTableName = "prompt_embedding_table"; auto constexpr kPromptVocabSizeName = 
"prompt_vocab_size"; +auto constexpr kLoraTaskId = "lora_task_id"; // weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ] // where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer // each of the in / out tensors are first flattened and then concatenated together in the format above. @@ -191,6 +192,7 @@ class GenericInferenceRequest inference_request::kPromptVocabSizeName, // obsolete names for backward compatibility inference_request::kInputLengthsTensorName, + inference_request::kLoraTaskId, inference_request::kLoraWeights, inference_request::kLoraConfig, }; @@ -255,6 +257,7 @@ class GenericInferenceRequest TENSOR_GETTER_SETTER(ReturnGenerationLogits, inference_request::kReturnGenerationLogitsTensorName) TENSOR_GETTER_SETTER(PromptEmbeddingTable, inference_request::kPromptEmbeddingTableName) TENSOR_GETTER_SETTER(PromptVocabSize, inference_request::kPromptVocabSizeName) + TENSOR_GETTER_SETTER(LoraTaskId, inference_request::kLoraTaskId) TENSOR_GETTER_SETTER(LoraWeights, inference_request::kLoraWeights) TENSOR_GETTER_SETTER(LoraConfig, inference_request::kLoraConfig) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h index 91be3fed3..2a928faed 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h @@ -46,7 +46,7 @@ class KvCacheConfig explicit KvCacheConfig(executor::KvCacheConfig const& kvCacheConfig) : KvCacheConfig(kvCacheConfig.getMaxTokens(), kvCacheConfig.getMaxAttentionWindow(), kvCacheConfig.getSinkTokenLength(), kvCacheConfig.getFreeGpuMemoryFraction(), - kvCacheConfig.getEnableBlockReuse(), kvCacheConfig.getUseUvm()) + kvCacheConfig.getEnableBlockReuse(), false) { } diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 43ef4bd89..5d4b824de 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -25,12 +25,14 @@ #include #include #include +#include #include #include namespace tensorrt_llm::batch_manager { +// TODO(rkobus): refactor enum LlmRequestState_t { REQUEST_STATE_UNKNOWN = 0, @@ -46,6 +48,7 @@ class GenericLlmRequest using SizeType = runtime::SizeType; using TokenIdType = runtime::TokenIdType; using RequestIdType = std::uint64_t; + using LoraTaskIdType = std::uint64_t; using VecTokens = std::vector; using VecLogProbs = std::vector; using BeamTokens = std::vector; @@ -57,9 +60,9 @@ class GenericLlmRequest std::optional padId = std::nullopt, std::optional embeddingBias = std::nullopt, std::optional badWordsList = std::nullopt, std::optional stopWordsList = std::nullopt, std::optional promptEmbeddingTable = std::nullopt, - std::optional promptVocabSize = std::nullopt, std::optional loraWeights = std::nullopt, - std::optional loraConfig = std::nullopt, bool returnLogProbs = false, - bool returnContextLogits = false, bool returnGenerationLogits = false, + std::optional promptVocabSize = std::nullopt, std::optional loraTaskId = std::nullopt, + std::optional loraWeights = std::nullopt, std::optional loraConfig = std::nullopt, + bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, std::optional> draftTokens = std::nullopt, std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false, std::optional logitsPostProcessor = std::nullopt) @@ -71,7 +74,6 @@ 
class GenericLlmRequest , mIsStreaming(isStreaming) , mEndId(endId) , mPadId(padId) - , mSeqSlot(-1) , mLogitsPostProcessor(logitsPostProcessor) , mOrigPromptLen(mPromptLen) , mMaxSentTokenPos(mPromptLen - 1) @@ -80,6 +82,7 @@ class GenericLlmRequest , mStopWordsList(std::move(stopWordsList)) , mPromptEmbeddingTable(std::move(promptEmbeddingTable)) , mPromptVocabSize(promptVocabSize) + , mLoraTaskId(loraTaskId) , mLoraWeights(std::move(loraWeights)) , mLoraConfig(std::move(loraConfig)) , mReturnLogProbs(returnLogProbs) @@ -105,7 +108,6 @@ class GenericLlmRequest , mIsStreaming(req.getStreaming()) , mEndId(req.getEndId()) , mPadId(req.getPadId()) - , mSeqSlot(-1) , mOrigPromptLen(mPromptLen) , mMaxSentTokenPos(mPromptLen - 1) , mReturnLogProbs(req.getOutputConfig().returnLogProbs) @@ -145,11 +147,19 @@ class GenericLlmRequest auto loraConfig = req.getLoraConfig(); if (loraConfig) { - mLoraWeights = executor::detail::toITensor(loraConfig.value().getWeights()); - mLoraWeights.value()->unsqueeze(0); + mLoraTaskId = loraConfig->getTaskId(); + auto optWeights = loraConfig->getWeights(); + if (loraConfig.value().getWeights()) + { + mLoraWeights = executor::detail::toITensor(loraConfig.value().getWeights().value()); + mLoraWeights.value()->unsqueeze(0); + } - mLoraConfig = executor::detail::toITensor(loraConfig.value().getConfig()); - mLoraConfig.value()->unsqueeze(0); + if (loraConfig.value().getConfig()) + { + mLoraConfig = executor::detail::toITensor(loraConfig.value().getConfig().value()); + mLoraConfig.value()->unsqueeze(0); + } } auto speculativeDecodingConfig = req.getSpeculativeDecodingConfig(); @@ -344,7 +354,7 @@ class GenericLlmRequest mState = REQUEST_STATE_CONTEXT_INIT; mContextCurrentPosition = 0; mContextChunkSize = std::nullopt; - mSeqSlot = -1; + mSeqSlot.reset(); } /// @brief Get the maximum position of the tokens returned to the client. 
Use to ensure we don't return to @@ -373,6 +383,16 @@ class GenericLlmRequest return mPromptVocabSize; } + [[nodiscard]] std::optional getLoraTaskId() const + { + return mLoraTaskId; + } + + void setLoraTaskId(LoraTaskIdType taskId) + { + mLoraTaskId = taskId; + } + [[nodiscard]] std::optional getLoraWeights() const { return mLoraWeights; @@ -713,9 +733,9 @@ class GenericLlmRequest runtime::SamplingConfig mSamplingConfig; LlmRequestState_t mState; bool mIsStreaming; - std::optional mEndId; - std::optional mPadId; - SizeType mSeqSlot; + std::optional mEndId; + std::optional mPadId; + std::optional mSeqSlot; std::optional mLogitsPostProcessor; protected: @@ -730,6 +750,7 @@ class GenericLlmRequest std::optional mPromptEmbeddingTable; std::optional mPromptVocabSize; + std::optional mLoraTaskId; std::optional mLoraWeights; std::optional mLoraConfig; @@ -822,23 +843,25 @@ class LlmRequest : public GenericLlmRequest std::optional padId = std::nullopt, std::optional embeddingBias = std::nullopt, std::optional badWordsList = std::nullopt, std::optional stopWordsList = std::nullopt, std::optional promptEmbeddingTable = std::nullopt, - std::optional promptVocabSize = std::nullopt, std::optional loraWeights = std::nullopt, - std::optional loraConfig = std::nullopt, bool returnLogProbs = false, - bool returnContextLogits = false, bool returnGenerationLogits = false, + std::optional promptVocabSize = std::nullopt, std::optional loraTaskId = std::nullopt, + std::optional loraWeights = std::nullopt, std::optional loraConfig = std::nullopt, + bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, std::optional> draftTokens = std::nullopt, std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false, std::optional logitsPostProcessor = std::nullopt) : Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), - std::move(promptEmbeddingTable), promptVocabSize, std::move(loraWeights), std::move(loraConfig), + std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig), returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor)) { } - LlmRequest(RequestIdType requestId, executor::Request const& Request) + LlmRequest(RequestIdType requestId, executor::Request const& Request, + std::optional logitsPostProcessor = std::nullopt) : Base(requestId, Request) { + mLogitsPostProcessor = std::move(logitsPostProcessor); } void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager) diff --git a/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h new file mode 100644 index 000000000..024bb07b0 --- /dev/null +++ b/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h @@ -0,0 +1,165 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. 
Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "tensorrt_llm/batch_manager/llmRequest.h" +#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h" +#include "tensorrt_llm/runtime/gptModelConfig.h" +#include "tensorrt_llm/runtime/loraCache.h" +#include "tensorrt_llm/runtime/workerPool.h" +#include "tensorrt_llm/runtime/worldConfig.h" +#include +#include +#include +#include +#include + +namespace tensorrt_llm::batch_manager +{ + +using runtime::SizeType; + +/** + * BasePeftCacheManager + * + * Manages caches of PEFT (Parameter Efficient Fine Tuning) weights. + * Does cache updates during execution loop moving weights to device as needed. + */ +class BasePeftCacheManager +{ +public: + using LlmRequestPtr = std::shared_ptr; + using RequestTable = std::map; + using PeftTable = std::map>>; + + /** + * \brief add PEFT weights from llmRequest if any. This will kickoff background copy tasks. + * \param[in] llmRequest: the request + * \param[in] tryGpuCache: if true try to load weights into gpu cache + */ + virtual void addRequestPeft(LlmRequestPtr llmRequest, bool tryGpuCache = true) = 0; + + /** + * \brief ensures device cache has all the weights needed to execute batch as specified by requestTable. + * This acts as sync for the copy tasks started by addRequestPeft + * \param[in] requestTable: current request table + * \param[in] resetGpuCache: reset (make all tasks evictable) + * \returns -- a PeftTable + */ + virtual PeftTable ensureBatch(RequestTable const& requestTable, bool resetGpuCache = false) = 0; + + /** + * \brief mark all the tasks in device cache as done + */ + virtual void resetDeviceCache() = 0; + + virtual void markRequestDone(LlmRequestPtr const& llmReq, bool pause = false) = 0; + + [[nodiscard]] virtual SizeType getMaxDevicePages() const = 0; + + [[nodiscard]] virtual SizeType getMaxHostPages() const = 0; + + [[nodiscard]] virtual SizeType determineNumPages(std::shared_ptr llmRequest) const = 0; + + [[nodiscard]] virtual bool enabled() const = 0; +}; + +class PeftCacheManager : public BasePeftCacheManager +{ +public: + PeftCacheManager(PeftCacheManagerConfig const& config, runtime::GptModelConfig const& modelConfig, + runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager); + + void addRequestPeft(std::shared_ptr llmRequest, bool tryGpuCache = true) override; + + PeftTable ensureBatch(RequestTable const& requestTable, bool resetGpuCache = false) override; + + [[nodiscard]] bool isTaskCached(uint64_t taskId) const; + + [[nodiscard]] bool isTaskDone(uint64_t taskId) const; + + [[nodiscard]] bool isTaskDoneDevice(uint64_t taskId) const; + + void resetDeviceCache() override; + + void markRequestDone(std::shared_ptr const& llmReq, bool pause = false) override; + + [[nodiscard]] SizeType getMaxDevicePages() const override; + + [[nodiscard]] SizeType getMaxHostPages() const override; + + [[nodiscard]] SizeType determineNumPages(std::shared_ptr llmRequest) const override; + + inline bool enabled() const override + { + return true; + } + + std::unordered_map> const& getActiveTasks() const; + + std::unordered_map> const& getPausedTasks() const; + + void updateTaskState(uint64_t taskId, uint64_t reqId, bool terminate = false, bool pause = false); + + static std::pair getMaxNumSlots(PeftCacheManagerConfig const& config, + nvinfer1::DataType dataType, uint64_t 
pageWidth, uint64_t max1dModSize, + runtime::BufferManager const& bufferManager); + + static std::pair getPageManagerConfig( + PeftCacheManagerConfig const& config, runtime::GptModelConfig const& modelConfig, + runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager); + +private: + std::unique_ptr mHostLoraCache; + std::unique_ptr mDeviceLoraCache; + + std::shared_ptr mPutWorkerPool; + std::unique_ptr mEnsureWorkerPool; + + mutable std::mutex mPutFuturesMutex; + std::unordered_map> mPutFutures; + + std::unordered_map> mTaskIdToReqIds; + std::unordered_map> mTaskIdToPausedReqIds; + + std::tuple>, std::map>> getTaskMaps( + RequestTable const& requestTable); + + runtime::GptModelConfig mModelConfig; + runtime::WorldConfig mWorldConfig; + + int mDevice{-1}; +}; + +class NoOpPeftCacheManager : public BasePeftCacheManager +{ + void addRequestPeft(std::shared_ptr llmRequest, bool tryGpuCache = true) override; + + PeftTable ensureBatch(RequestTable const& requestTable, bool resetGpuCache = false) override; + + void resetDeviceCache() override; + + void markRequestDone(std::shared_ptr const& llmReq, bool pause = false) override; + + [[nodiscard]] SizeType getMaxDevicePages() const override; + + [[nodiscard]] SizeType getMaxHostPages() const override; + + [[nodiscard]] SizeType determineNumPages(std::shared_ptr llmRequest) const override; + + inline bool enabled() const override + { + return false; + } +}; +} // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/peftCacheManagerConfig.h b/cpp/include/tensorrt_llm/batch_manager/peftCacheManagerConfig.h new file mode 100644 index 000000000..fbaec751b --- /dev/null +++ b/cpp/include/tensorrt_llm/batch_manager/peftCacheManagerConfig.h @@ -0,0 +1,94 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/runtime/common.h" +#include +#include + +namespace fs = std::filesystem; + +namespace tensorrt_llm::batch_manager +{ + +using runtime::SizeType; + +struct PeftCacheManagerConfig +{ + + static float constexpr kDefaultDeviceCachePercent = 0.05; + static size_t constexpr kDefaultHostCacheSize = 1024 * 1024 * 1024; + + explicit PeftCacheManagerConfig(SizeType numHostModuleLayer = 0, SizeType numDeviceModuleLayer = 0, + SizeType optimalAdapterSize = 8, SizeType maxAdapterSize = 64, SizeType numPutWorkers = 1, + SizeType numEnsureWorkers = 1, SizeType numCopyStreams = 1, SizeType maxPagesPerBlockHost = 24, + SizeType maxPagesPerBlockDevice = 8, std::optional deviceCachePercent = std::nullopt, + std::optional hostCacheSize = std::nullopt) + : numHostModuleLayer(numHostModuleLayer) + , numDeviceModuleLayer(numDeviceModuleLayer) + , optimalAdapterSize(optimalAdapterSize) + , maxAdapterSize(maxAdapterSize) + , numPutWorkers(numPutWorkers) + , numEnsureWorkers(numEnsureWorkers) + , numCopyStreams(numCopyStreams) + , maxPagesPerBlockHost(maxPagesPerBlockHost) + , maxPagesPerBlockDevice(maxPagesPerBlockDevice) + , deviceCachePercent(deviceCachePercent) + , hostCacheSize(hostCacheSize) + { + } + + PeftCacheManagerConfig(executor::PeftCacheConfig cfg) + : numHostModuleLayer(cfg.getNumHostModuleLayer()) + , numDeviceModuleLayer(cfg.getNumDeviceModuleLayer()) + , optimalAdapterSize(cfg.getOptimalAdapterSize()) + , maxAdapterSize(cfg.getMaxAdapterSize()) + , numPutWorkers(cfg.getNumPutWorkers()) + , numCopyStreams(cfg.getNumCopyStreams()) + , maxPagesPerBlockHost(cfg.getMaxPagesPerBlockHost()) + , maxPagesPerBlockDevice(cfg.getMaxPagesPerBlockDevice()) + , deviceCachePercent(cfg.getDeviceCachePercent()) + , hostCacheSize(cfg.getHostCacheSize()) + { + } + + // number of max sized 1-layer 1-module adapterSize=1 sets of weights that can be stored in host cache + SizeType numHostModuleLayer; + // number of max sized 1-layer 1-module sets of weights that can be stored in host cache + SizeType numDeviceModuleLayer; + // optimal adapter size used to set page width + SizeType optimalAdapterSize; + // max supported adapter size. 
Used to compute minimum + SizeType maxAdapterSize; + // number of worker threads used to put weights into host cache + SizeType numPutWorkers; + // number of worker threads used to copy weights from host to device + SizeType numEnsureWorkers; + // number of streams used to copy weights from host to device + SizeType numCopyStreams; + // Number of cache pages per allocation block (host) + SizeType maxPagesPerBlockHost; + // Number of cache pages per allocation block (device) + SizeType maxPagesPerBlockDevice; + // percent of memory after engine load to use for cache + std::optional deviceCachePercent; + // size in bytes to use for host cache + std::optional hostCacheSize; +}; +} // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h index e42a342f9..917298bf3 100644 --- a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h +++ b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h @@ -18,6 +18,7 @@ #pragma once #include "tensorrt_llm/batch_manager/kvCacheConfig.h" +#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/decodingMode.h" @@ -38,21 +39,23 @@ class TrtGptModelOptionalParams explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{}, bool enableTrtOverlap = false, std::optional> const& deviceIds = std::nullopt, bool normalizeLogProbs = true, bool enableChunkedContext = false, - std::optional const& decodingMode = std::nullopt) + std::optional const& decodingMode = std::nullopt, + PeftCacheManagerConfig const& peftCacheManagerConfig = PeftCacheManagerConfig{}) : kvCacheConfig{kvCacheConfig} , enableTrtOverlap{enableTrtOverlap} , deviceIds(deviceIds) , normalizeLogProbs{normalizeLogProbs} , enableChunkedContext{enableChunkedContext} , decodingMode{decodingMode} + , peftCacheManagerConfig(peftCacheManagerConfig) { } explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig) - : TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), - executorConfig.getEnableTrtOverlap(), + : TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), false, executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(), - executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext()) + executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(), std::nullopt, + PeftCacheManagerConfig(executorConfig.getPeftCacheConfig())) { } @@ -70,6 +73,7 @@ class TrtGptModelOptionalParams bool normalizeLogProbs; bool enableChunkedContext; std::optional decodingMode; + PeftCacheManagerConfig peftCacheManagerConfig; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/common/assert.h b/cpp/include/tensorrt_llm/common/assert.h similarity index 100% rename from cpp/tensorrt_llm/common/assert.h rename to cpp/include/tensorrt_llm/common/assert.h diff --git a/cpp/tensorrt_llm/common/stringUtils.h b/cpp/include/tensorrt_llm/common/stringUtils.h similarity index 100% rename from cpp/tensorrt_llm/common/stringUtils.h rename to cpp/include/tensorrt_llm/common/stringUtils.h diff --git a/cpp/tensorrt_llm/common/tllmException.h b/cpp/include/tensorrt_llm/common/tllmException.h similarity index 100% rename from cpp/tensorrt_llm/common/tllmException.h rename to 
cpp/include/tensorrt_llm/common/tllmException.h diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index f2a77b6d1..f3448ca62 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -41,6 +41,8 @@ class Serialization; class SamplingConfig { public: + /// @brief Constructor for SamplingConfig + /// See description of parameters below SamplingConfig(SizeType beamWidth = 1, std::optional topK = std::nullopt, std::optional topP = std::nullopt, std::optional topPMin = std::nullopt, std::optional topPResetIds = std::nullopt, std::optional topPDecay = std::nullopt, @@ -74,33 +76,56 @@ class SamplingConfig private: friend class Serialization; + + /// @brief The beam width. Default is 1 which disables beam search. SizeType mBeamWidth; + /// @brief Controls number of logits to sample from. Default is 0 (all logits). std::optional mTopK; + /// @brief Controls the top-P probability to sample from. Default is 0.f std::optional mTopP; + /// @brief Controls decay in the top-P algorithm. topPMin is lower-bound. Default is 1.e-6. std::optional mTopPMin; + /// @brief Controls decay in the top-P algorithm. Indicates where to reset the decay. Default is 1. std::optional mTopPResetIds; + /// @brief Controls decay in the top-P algorithm. The decay value. Default is 1.f std::optional mTopPDecay; + /// @brief Controls the random seed used by the random number generator in sampling std::optional mRandomSeed; + /// @brief Controls the modulation of logits when sampling new tokens. Default is 1.0f std::optional mTemperature; + /// @brief Lower bound on the number of tokens to generate std::optional mMinLength; + /// @brief Controls the diversity in beam search. std::optional mBeamSearchDiversityRate; + /// @brief Used to penalize tokens based on how often they appear in the sequence. Default is 0.f std::optional mRepetitionPenalty; + /// @brief Used to penalize tokens already present in the sequence (irrespective of the number of appearances). + /// Default is 0.f std::optional mPresencePenalty; + /// @brief Used to penalize tokens already present in the sequence (dependent on the number of appearances). Default + /// is 0.f std::optional mFrequencyPenalty; + /// @brief Controls how to penalize longer sequences in beam search. Default is 0.f std::optional mLengthPenalty; + /// @brief Controls whether the generation process finishes once beamWidth sentences are generated (end with + /// end_token) std::optional mEarlyStopping; }; -/// @brief Configuration that controls the outputs of a Result +/// @brief Configuration that controls the outputs of a Result class OutputConfig { public: OutputConfig(bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, bool excludeInputFromOutput = false); + /// @brief Controls if Result should contain log probabilities. Default is false bool returnLogProbs; + /// @brief Controls if Result should contain the context logits. Default is false bool returnContextLogits; + /// @brief Controls if Result should contain the generation logits. Default is false. bool returnGenerationLogits; + /// @brief Controls if output tokens in Result should include the input tokens. Default is false. 
bool excludeInputFromOutput; }; @@ -120,8 +145,11 @@ class SpeculativeDecodingConfig private: friend class Serialization; + /// @brief The draft tokens VecTokens mTokens; + /// @brief The draft logits std::optional mLogits; + /// @brief The acceptance threshold std::optional mAcceptanceThreshold; }; @@ -129,10 +157,6 @@ class SpeculativeDecodingConfig class PromptTuningConfig { public: - /// @brief - /// @param embeddingTable The prompt embedding table. Data type must match model weights. Shape [vocabSize, - /// hiddenSize] - /// @param vocabSize PromptTuningConfig(Tensor embeddingTable); ~PromptTuningConfig(); @@ -140,6 +164,7 @@ class PromptTuningConfig private: friend class Serialization; + /// @brief The prompt embedding table Tensor mEmbeddingTable; }; @@ -147,37 +172,46 @@ class PromptTuningConfig class LoraConfig { public: - LoraConfig(Tensor weights, Tensor config); + LoraConfig( + IdType taskId, std::optional weights = std::nullopt, std::optional config = std::nullopt); ~LoraConfig(); - [[nodiscard]] Tensor getWeights() const; - [[nodiscard]] Tensor getConfig() const; + [[nodiscard]] IdType getTaskId() const; + [[nodiscard]] std::optional getWeights() const; + [[nodiscard]] std::optional getConfig() const; private: friend class Serialization; - Tensor mWeights; - Tensor mConfig; + /// @brief The Lora task id + IdType mTaskId; + /// @brief The Lora weights + std::optional mWeights; + /// @brief The Lora configuration + std::optional mConfig; }; /// @brief A class that holds information about the request class Request { public: - /// @brief + /// @brief The Request constructor + /// @param inputTokenIds The input token ids /// @param maxNewTokens The maximum number of tokens to generate - /// @param streaming // Indicates if the responses should be streamed or not - /// @param samplingConfig // The sampling configuration - /// @param outputConfig // The output configuration - /// @param endId // The end token id - /// @param padId // The pad token id - /// @param badWords // A list of bad words tokens. Each "word" can be composed of multiple tokens - /// @param stopWords // A list of stop words tokens. Each "word" can be composed of multiple tokens - /// @param embeddingBias // The embedding bias tensor. Expected type is kFP32 and shape is [vocab_size] - /// @param speculativeDecodingConfig // The speculative decoding configuration - /// @param pTuningConfig // The prompt tuning configuration - /// @param loraConfig // The LoRA configuration + /// @param streaming Indicates if the responses should be streamed or not + /// @param samplingConfig The sampling configuration + /// @param outputConfig The output configuration + /// @param endId The end token id + /// @param padId The pad token id + /// @param badWords A list of bad words tokens. Each "word" can be composed of multiple tokens + /// @param stopWords A list of stop words tokens. Each "word" can be composed of multiple tokens + /// @param embeddingBias The embedding bias tensor. Expected type is kFP32 and shape is [vocab_size] + /// @param speculativeDecodingConfig The speculative decoding configuration + /// @param pTuningConfig The prompt tuning configuration + /// @param loraConfig The LoRA configuration + /// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor + /// name provided to the ExecutorConfig. 
Request(VecTokens inputTokenIds, SizeType maxNewTokens, bool streaming = false,
        SamplingConfig samplingConfig = SamplingConfig(), OutputConfig outputConfig = OutputConfig(),
        std::optional endId = std::nullopt, std::optional padId = std::nullopt,
@@ -186,7 +220,8 @@ class Request
        std::optional embeddingBias = std::nullopt,
        std::optional speculativeDecodingConfig = std::nullopt,
        std::optional pTuningConfig = std::nullopt,
-        std::optional loraConfig = std::nullopt);
+        std::optional loraConfig = std::nullopt,
+        std::optional logitsPostProcessorName = std::nullopt);

    Request(Request const& other);
    Request(Request&& other) noexcept;
@@ -207,6 +242,7 @@ class Request
    [[nodiscard]] std::optional getSpeculativeDecodingConfig() const;
    [[nodiscard]] std::optional getPromptTuningConfig() const;
    [[nodiscard]] std::optional getLoraConfig() const;
+    [[nodiscard]] std::optional getLogitsPostProcessorName() const;

    void setStreaming(bool streaming);
    void setSamplingConfig(SamplingConfig config);
@@ -219,6 +255,7 @@ class Request
    void setSpeculativeDecodingConfig(SpeculativeDecodingConfig specDecodingConfig);
    void setPromptTuningConfig(PromptTuningConfig pTuningConfig);
    void setLoraConfig(LoraConfig loraConfig);
+    void setLogitsPostProcessorName(std::string const& logitsPostProcessorName);

private:
    friend class Serialization;
@@ -229,16 +266,23 @@ class Request
/// @brief Struct that holds the generation result
struct Result
{
-    // Indicates if this is the final result for the request
+    /// @brief Indicates if this is the final result for the request
    bool isFinal;

    /// @brief The output tokens for each beam
    BeamTokens outputTokenIds;

-    std::optional cumLogProbs;       // [beamSize]
-    std::optional> logProbs; // [beamSize, seqLen]
-    std::optional contextLogits;     // [promptLen, vocab_size_padded]
-    std::optional generationLogits;  // [beam_size, mMaxNewTokens, vocab_size_padded]
+    /// @brief The cumulative log probabilities. Size beamSize.
+    std::optional cumLogProbs;
+
+    /// @brief The log probabilities for each generated token. Size [beamSize, seqLen]
+    std::optional> logProbs;
+
+    /// @brief The context logits. Size [promptLen, vocabSizePadded]
+    std::optional contextLogits;
+
+    /// @brief The generation logits. Size [beamSize, maxNewTokens, vocabSizePadded]
+    std::optional generationLogits;
};

/// @brief Class that holds either an error or a result
class Response
{
@@ -254,18 +298,18 @@ class Response
    Response& operator=(Response const& other);
    Response& operator=(Response&& other) noexcept;

-    // Get the id of the request for which this response was generated
+    /// @brief Get the id of the request for which this response was generated
    IdType getRequestId() const;

-    // Indicates if this response has an error or not
+    /// @brief Indicates if this response has an error or not
    bool hasError() const;

-    // Get the error msg for this response
-    // Will throw an exception if hasError is false
+    /// @brief Get the error msg for this response
+    /// Will throw an exception if hasError is false
    std::string getErrorMsg() const;

-    // Get the result for this response
-    // Will throw an exception if hasResult is true
+    /// @brief Get the result for this response
+    /// Will throw an exception if hasError is true
    Result getResult() const;

private:
@@ -283,6 +327,7 @@ class SchedulerConfig
    [[nodiscard]] SchedulerPolicy getPolicy() const;

private:
+    /// @brief The scheduler policy. See SchedulerPolicy.
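As a usage illustration of the request/response surface documented above, the sketch below builds a `Request` and then inspects `Response` objects, checking `hasError()` before calling `getResult()`. The token ids, sampling values, and the post-processor name are placeholders; only the constructors and accessors shown in the header are relied on.

```cpp
// Illustrative only: builds a Request and consumes Responses using the API declared
// above. Token ids and the post-processor name are made up for the example.
#include "tensorrt_llm/executor/executor.h"

#include <iostream>
#include <vector>

namespace tle = tensorrt_llm::executor;

tle::Request makeRequest()
{
    tle::VecTokens inputTokens{1, 2, 3, 4};            // hypothetical tokenized prompt
    tle::SamplingConfig sampling(1, 50, 0.9f);          // beamWidth=1, topK=50, topP=0.9
    tle::OutputConfig output(/*returnLogProbs=*/true);  // also return log probabilities

    tle::Request request(inputTokens, /*maxNewTokens=*/32, /*streaming=*/false, sampling, output);
    request.setLogitsPostProcessorName("my_post_processor"); // must match a name given to ExecutorConfig
    return request;
}

void consumeResponses(std::vector<tle::Response> const& responses)
{
    for (auto const& response : responses)
    {
        if (response.hasError())
        {
            std::cerr << "request " << response.getRequestId() << ": " << response.getErrorMsg() << "\n";
            continue;
        }
        auto const result = response.getResult(); // safe: hasError() is false
        if (result.isFinal)
        {
            std::cout << "request " << response.getRequestId() << " produced "
                      << result.outputTokenIds.front().size() << " tokens\n";
        }
    }
}
```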
SchedulerPolicy mPolicy;
};

@@ -293,25 +338,39 @@ class KvCacheConfig
    KvCacheConfig(bool enableBlockReuse = false, std::optional maxTokens = std::nullopt,
        std::optional maxAttentionWindow = std::nullopt, std::optional sinkTokenLength = std::nullopt,
-        std::optional freeGpuMemoryFraction = std::nullopt, bool useUvm = false);
+        std::optional freeGpuMemoryFraction = std::nullopt);

    [[nodiscard]] bool getEnableBlockReuse() const;
    [[nodiscard]] std::optional getMaxTokens() const;
    [[nodiscard]] std::optional getMaxAttentionWindow() const;
    [[nodiscard]] std::optional getSinkTokenLength() const;
    [[nodiscard]] std::optional getFreeGpuMemoryFraction() const;
-    [[nodiscard]] bool getUseUvm() const;

private:
+    /// @brief Controls if KV cache blocks can be reused for different requests
    bool mEnableBlockReuse;
+
+    /// @brief The maximum number of tokens that should be stored in the KV cache
+    /// If both mMaxTokens and mFreeGpuMemoryFraction are specified, memory corresponding to the minimum will be
+    /// allocated.
    std::optional mMaxTokens;
+
+    /// @brief Size of the attention window for each sequence. Only the last mMaxAttentionWindow tokens of each sequence
+    /// will be stored in the KV cache.
    std::optional mMaxAttentionWindow;
+
+    /// @brief Number of sink tokens (tokens to always keep in attention window)
    std::optional mSinkTokenLength;
+
+    /// @brief The fraction of GPU memory that should be allocated for the KV cache. Default is 90%.
+    /// If both mMaxTokens and mFreeGpuMemoryFraction are specified, memory corresponding to the minimum will be
+    /// allocated.
    std::optional mFreeGpuMemoryFraction;
-    bool mUseUvm;
};

SizeType const kDefaultIterStatsMaxIterations = 1000;
+// Per request stats may have additional overhead due to going through all requests. Turned off by default.
+SizeType const kDefaultRequestStatsMaxIterations = 0;

/// @brief A configuration class for the parallel execution parameters
/// Currently only supports commType = CommunicationType::kMPI
@@ -341,52 +400,132 @@ class ParallelConfig
    void setParticipantIds(std::vector participantIds);

private:
+    /// @brief The type of communication protocol used. Default is MPI.
    CommunicationType mCommType;
+
+    /// @brief The mode of communication. See CommunicationMode.
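A brief sketch of how the KV cache knobs above might be set; the values are arbitrary examples, and only the `KvCacheConfig` constructor order shown in the header is assumed.

```cpp
// Arbitrary example values: cap the KV cache by whichever is smaller, an explicit
// token budget or 80% of the free GPU memory, and keep a 4-token attention sink.
#include "tensorrt_llm/executor/executor.h"

#include <optional>

namespace tle = tensorrt_llm::executor;

tle::KvCacheConfig makeKvCacheConfig()
{
    return tle::KvCacheConfig(
        /*enableBlockReuse=*/true,
        /*maxTokens=*/16384,
        /*maxAttentionWindow=*/std::nullopt, // keep the full attention window
        /*sinkTokenLength=*/4,
        /*freeGpuMemoryFraction=*/0.8f);
}
```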
CommunicationMode mCommMode; + + /// @brief The GPU device ids to use for executing this model std::optional> mDeviceIds; + + /// @brief The participant ids (MPI ranks for example) used for executing this model std::optional> mParticipantIds; }; +/// @brief config for PeftCacheManager +class PeftCacheConfig +{ +public: + PeftCacheConfig(SizeType numHostModuleLayer = 0, SizeType numDeviceModuleLayer = 0, SizeType optimalAdapterSize = 8, + SizeType maxAdapterSize = 64, SizeType numPutWorkers = 1, SizeType numEnsureWorkers = 1, + SizeType numCopyStreams = 1, SizeType maxPagesPerBlockHost = 24, SizeType maxPagesPerBlockDevice = 8, + std::optional deviceCachePercent = std::nullopt, std::optional hostCacheSize = std::nullopt); + + [[nodiscard]] SizeType getNumHostModuleLayer() const; + [[nodiscard]] SizeType getNumDeviceModuleLayer() const; + [[nodiscard]] SizeType getOptimalAdapterSize() const; + [[nodiscard]] SizeType getMaxAdapterSize() const; + [[nodiscard]] SizeType getNumPutWorkers() const; + [[nodiscard]] SizeType getNumEnsureWorkers() const; + [[nodiscard]] SizeType getNumCopyStreams() const; + [[nodiscard]] SizeType getMaxPagesPerBlockHost() const; + [[nodiscard]] SizeType getMaxPagesPerBlockDevice() const; + [[nodiscard]] std::optional getDeviceCachePercent() const; + [[nodiscard]] std::optional getHostCacheSize() const; + +private: + // number of max sized 1-layer 1-module adapterSize=1 sets of weights that can be stored in host cache + SizeType mNumHostModuleLayer; + // number of max sized 1-layer 1-module sets of weights that can be stored in host cache + SizeType mNumDeviceModuleLayer; + // optimal adapter size used to set page width + SizeType mOptimalAdapterSize; + // max supported adapter size. Used to compute minimum + SizeType mMaxAdapterSize; + // number of worker threads used to put weights into host cache + SizeType mNumPutWorkers; + // number of worker threads used to copy weights from host to device + SizeType mNumEnsureWorkers; + // number of streams used to copy weights from host to device + SizeType mNumCopyStreams; + // Number of cache pages per allocation block (host) + SizeType mMaxPagesPerBlockHost; + // Number of cache pages per allocation block (device) + SizeType mMaxPagesPerBlockDevice; + // percent of memory after engine load to use for cache + std::optional mDeviceCachePercent; + // size in bytes to use for host cache + std::optional mHostCacheSize; +}; + /// @brief Configuration class for the model executor class ExecutorConfig { + using LogitsPostProcessorMap = std::unordered_map; + public: ExecutorConfig(SizeType maxBeamWidth = 1, SchedulerConfig schedulerConfig = SchedulerConfig(), KvCacheConfig kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false, bool normalizeLogProbs = true, - bool enableTrtOverlap = false, SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations, + SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations, + SizeType requestStatsMaxIterations = kDefaultRequestStatsMaxIterations, BatchingType batchingType = BatchingType::kINFLIGHT, - std::optional parallelConfig = std::nullopt); + std::optional parallelConfig = std::nullopt, + PeftCacheConfig peftCacheConfig = PeftCacheConfig(), LogitsPostProcessorMap = {}); [[nodiscard]] SizeType getMaxBeamWidth() const; [[nodiscard]] SchedulerConfig getSchedulerConfig() const; [[nodiscard]] KvCacheConfig getKvCacheConfig() const; [[nodiscard]] bool getEnableChunkedContext() const; [[nodiscard]] bool getNormalizeLogProbs() const; - [[nodiscard]] bool getEnableTrtOverlap() 
const;
    [[nodiscard]] SizeType getIterStatsMaxIterations() const;
+    [[nodiscard]] SizeType getRequestStatsMaxIterations() const;
    [[nodiscard]] BatchingType getBatchingType() const;
    [[nodiscard]] std::optional getParallelConfig() const;
+    [[nodiscard]] PeftCacheConfig getPeftCacheConfig() const;
+    [[nodiscard]] LogitsPostProcessorMap getLogitsPostProcessorMap() const;

    void setMaxBeamWidth(SizeType maxBeamWidth);
    void setSchedulerConfig(SchedulerConfig schedulerConfig);
    void setKvCacheConfig(KvCacheConfig kvCacheConfig);
    void setEnableChunkedContext(bool enableChunkedContext);
    void setNormalizeLogProbs(bool normalizeLogProbs);
-    void setEnableTrtOverlap(bool enableTrtOverlap);
    void setIterStatsMaxIterations(SizeType iterStatsMaxIterations);
+    void setRequestStatsMaxIterations(SizeType requestStatsMaxIterations);
    void setBatchingType(BatchingType batchingType);
    void setParallelConfig(ParallelConfig parallelConfig);
+    void setPeftCacheConfig(PeftCacheConfig peftCacheConfig);
+    void setLogitsPostProcessorMap(LogitsPostProcessorMap logitsPostProcessorMap);

private:
+    /// @brief The beam width value of requests that will be sent to the executor
    SizeType mMaxBeamWidth;
+
+    /// @brief The scheduler configuration.
    SchedulerConfig mSchedulerConfig;
+
+    /// @brief The KV cache configuration.
    KvCacheConfig mKvCacheConfig;
+
+    /// @brief Controls whether chunked context is enabled.
    bool mEnableChunkedContext;
+
+    /// @brief Controls if log probabilities should be normalized or not.
    bool mNormalizeLogProbs;
-    bool mEnableTrtOverlap;
+
+    /// @brief Controls the maximum number of iterations for which to keep statistics.
    SizeType mIterStatsMaxIterations;
+
+    /// @brief Controls the maximum number of iterations for which to keep per-request statistics.
+    SizeType mRequestStatsMaxIterations;
+
+    /// @brief The type of batching strategy to use. See BatchingType.
    BatchingType mBatchingType;
+
+    /// @brief The parallel execution configuration.
    std::optional mParallelConfig;
+    PeftCacheConfig mPeftCacheConfig;
+    LogitsPostProcessorMap mLogitsPostProcessorMap;
};

/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
@@ -439,14 +578,31 @@ class Executor
    /// @brief Returns the per-iterations statistics computed since last call to getLatestIterationStats
    /// Contains at most iterStatsMaxIterations iterations
-    /// Will block until stats for at least one iteration are available
-    /// TODO: Should we use a class for iterationStats, i.e.
std::deque - /// @return - std::deque getLatestIterationStats(); + /// @return Iteration stats + std::deque getLatestIterationStats(); + + /// @brief Returns the request stats of each iteration computed since last call to getLatestRequestStats + /// Contains at most requestStatsMaxIterations iterations + /// @return Request stats grouped by iterations + std::deque getLatestRequestStats(); private: class Impl; std::unique_ptr mImpl; }; +/// @brief Class with utility functions to serialize statistics to json string +class JsonSerialization +{ +public: + /// @brief Utility function to convert an iterationStats struct to a json serialized string + [[nodiscard]] static std::string toJsonStr(IterationStats const& iterationStats); + + /// @brief Utility function to convert a requestStatsPerIteration struct to a json serialized string + [[nodiscard]] static std::string toJsonStr(RequestStatsPerIteration const& requestStatsPerIter); + + /// @brief Utility function to convert a requestStats struct to a json serialized string + [[nodiscard]] static std::string toJsonStr(RequestStats const& requestStats); +}; + } // namespace tensorrt_llm::executor diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h index 36872629b..7632c8376 100644 --- a/cpp/include/tensorrt_llm/executor/types.h +++ b/cpp/include/tensorrt_llm/executor/types.h @@ -17,7 +17,10 @@ #pragma once #include +#include #include +#include +#include #include #ifdef ENABLE_FP8 @@ -27,6 +30,11 @@ #include #endif +namespace tensorrt_llm::runtime +{ +class CudaStream; +} // namespace tensorrt_llm::runtime + namespace tensorrt_llm::executor { @@ -40,8 +48,11 @@ using TokenIdType = std::int32_t; using VecTokens = std::vector; using BeamTokens = std::vector; using IdType = std::uint64_t; +using IterationType = std::uint64_t; using RandomSeedType = std::uint64_t; using VecLogProbs = std::vector; +using StreamPtr = std::shared_ptr; +using LogitsPostProcessor = std::function; enum class DataType { @@ -146,7 +157,6 @@ enum class BatchingType { kSTATIC = 0, kINFLIGHT = 1, - kINFLIGHT_UNFUSED = 2, }; enum class SchedulerPolicy @@ -167,4 +177,116 @@ enum class CommunicationMode // first participant in the provided participant IDS, or 0 if participant ID is not provided }; +/// @brief Struct that holds the stats of a KV cache manager +struct KvCacheStats +{ + /// @brief Max number of blocks + SizeType maxNumBlocks; + /// @brief Number of free blocks + SizeType freeNumBlocks; + /// @brief Number of used blocks + SizeType usedNumBlocks; + /// @brief Number of tokens per block + SizeType tokensPerBlock; +}; + +/// @brief Struct that holds the stats of static batching models for a single iteration +struct StaticBatchingStats +{ + /// @brief Number of scheduled requests + SizeType numScheduledRequests; + /// @brief Number of requests in context stage + SizeType numContextRequests; + /// @brief Total number of context tokens in the iteration + SizeType numCtxTokens; + /// @brief Total number of tokens to generate in the iteration + SizeType numGenTokens; + /// @brief Total number of unused generation token slots + SizeType emptyGenSlots; +}; + +/// @brief Struct that holds the stats of inflight batching models for a single iteration +struct InflightBatchingStats +{ + /// @brief Number of scheduled requests + SizeType numScheduledRequests; + /// @brief Number of requests in context stage + SizeType numContextRequests; + /// @brief Number of requests in generation stage + SizeType numGenRequests; + /// @brief Number of 
paused requests
+    SizeType numPausedRequests;
+    /// @brief Total number of context tokens in the iteration
+    SizeType numCtxTokens;
+    /// @brief Index of micro batch
+    SizeType microBatchId;
+};
+
+/// @brief Struct that holds the stats of a single iteration
+struct IterationStats
+{
+    /// @brief Ending time of this iteration
+    std::string timestamp;
+    /// @brief Iteration id
+    SizeType iter;
+    /// @brief Number of active requests
+    SizeType numActiveRequests;
+    /// @brief Number of max active requests
+    SizeType maxNumActiveRequests;
+    /// @brief GPU memory usage in bytes
+    size_t gpuMemUsage;
+    /// @brief CPU memory usage in bytes
+    size_t cpuMemUsage;
+    /// @brief Pinned memory usage in bytes
+    size_t pinnedMemUsage;
+    /// @brief Stats specific to KV caches
+    std::optional kvCacheStats;
+    /// @brief Stats specific to static batching
+    std::optional staticBatchingStats;
+    /// @brief Stats specific to inflight batching
+    std::optional inflightBatchingStats;
+};
+
+/// @brief Enum class that represents the state of a request
+enum class RequestStage
+{
+    /// @brief Requests that have been received but not yet included in the active requests (due to constraints such as
+    /// maximum batch size for example).
+    kQUEUED,
+    /// @brief Active request in context phase
+    kCONTEXT_IN_PROGRESS,
+    /// @brief Active request in generation phase
+    kGENERATION_IN_PROGRESS,
+    /// @brief Active request for which generation has completed
+    kGENERATION_COMPLETE,
+
+};
+
+/// @brief Struct that holds the stats of a single request
+struct RequestStats
+{
+    /// @brief The request id
+    IdType id;
+    /// @brief The current stage the request is in
+    RequestStage stage;
+    /// @brief If using chunked context, the current context prefill position
+    SizeType contextPrefillPosition;
+    /// @brief The number of generated tokens so far
+    SizeType numGeneratedTokens;
+    /// @brief Whether the request is scheduled for the current iteration
+    bool scheduled;
+    /// @brief Whether the request is being paused at the current iteration due to lack of resources (KV cache blocks
+    /// exhaustion for example)
+    bool paused;
+};
+
+/// @brief Struct that holds the stats of all requests in an iteration
+struct RequestStatsPerIteration
+{
+    /// @brief The iteration id for these stats
+    IterationType iter;
+    /// @brief The stats of all active requests for this iteration
+    std::vector requestStats;
+};
+
} // namespace tensorrt_llm::executor
diff --git a/cpp/tensorrt_llm/plugins/api/tllmPlugin.h b/cpp/include/tensorrt_llm/plugins/api/tllmPlugin.h
similarity index 100%
rename from cpp/tensorrt_llm/plugins/api/tllmPlugin.h
rename to cpp/include/tensorrt_llm/plugins/api/tllmPlugin.h
diff --git a/cpp/include/tensorrt_llm/runtime/bufferManager.h b/cpp/include/tensorrt_llm/runtime/bufferManager.h
index 202ab6802..d5be3332b 100644
--- a/cpp/include/tensorrt_llm/runtime/bufferManager.h
+++ b/cpp/include/tensorrt_llm/runtime/bufferManager.h
@@ -26,6 +26,8 @@
#include
#include

+class BufferManagerTest;
+
namespace tensorrt_llm::runtime
{
//! \brief A helper class for managing memory on host and device.
@@ -42,7 +44,16 @@ class BufferManager
    //!
    //! \param[in] cudaStream The cuda stream to use for all operations on GPU (allocation, de-allocation, copying,
    //! etc.).
-    explicit BufferManager(CudaStreamPtr stream);
+    explicit BufferManager(CudaStreamPtr stream, bool trimPool = false);
+
+    //! \brief Destructor.
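To show how the statistics structs and polling methods above are meant to be consumed, here is a sketch that turns on per-request statistics in `ExecutorConfig` and periodically drains both stat queues to JSON. The `Executor` instance itself is assumed to be constructed elsewhere; only the setters, `getLatestIterationStats`, `getLatestRequestStats`, and `JsonSerialization::toJsonStr` declared above are used.

```cpp
// Sketch only: Executor construction happens elsewhere; this shows the statistics
// knobs on ExecutorConfig and the polling/serialization calls declared above.
#include "tensorrt_llm/executor/executor.h"

#include <iostream>

namespace tle = tensorrt_llm::executor;

tle::ExecutorConfig makeConfigWithStats()
{
    tle::ExecutorConfig config;
    config.setIterStatsMaxIterations(1000);   // keep up to 1000 iterations of engine stats
    config.setRequestStatsMaxIterations(100); // per-request stats are off (0) by default
    return config;
}

void drainStats(tle::Executor& executor)
{
    for (auto const& iterStats : executor.getLatestIterationStats())
    {
        std::cout << tle::JsonSerialization::toJsonStr(iterStats) << "\n";
    }
    for (auto const& requestStats : executor.getLatestRequestStats())
    {
        std::cout << tle::JsonSerialization::toJsonStr(requestStats) << "\n";
    }
}
```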
+ ~BufferManager() + { + if (mTrimPool) + { + memoryPoolTrimTo(0); + } + } static auto constexpr kBYTE_TYPE = nvinfer1::DataType::kUINT8; @@ -171,6 +182,8 @@ class BufferManager void memoryPoolTrimTo(std::size_t size); private: + friend class ::BufferManagerTest; + void static initMemoryPool(int device); std::size_t static memoryPoolReserved(int device); @@ -185,6 +198,7 @@ class BufferManager void static memoryPoolTrimTo(int device, std::size_t size); CudaStreamPtr mStream; + bool const mTrimPool; }; } // namespace tensorrt_llm::runtime diff --git a/cpp/include/tensorrt_llm/runtime/decodingMode.h b/cpp/include/tensorrt_llm/runtime/decodingMode.h index c3d75f542..9c400668f 100644 --- a/cpp/include/tensorrt_llm/runtime/decodingMode.h +++ b/cpp/include/tensorrt_llm/runtime/decodingMode.h @@ -49,6 +49,11 @@ class DecodingMode return DecodingMode{kBeamSearch}; } + static auto constexpr Medusa() + { + return DecodingMode{kMedusa}; + } + bool constexpr isNone() { return mState == 0; @@ -79,6 +84,11 @@ class DecodingMode return anyBitSet(kBeamSearch); } + bool constexpr isMedusa() + { + return anyBitSet(kMedusa); + } + using UnderlyingType = uint8_t; bool operator==(DecodingMode const& other) const @@ -98,6 +108,7 @@ class DecodingMode static UnderlyingType constexpr kTopK{1u << 0}; static UnderlyingType constexpr kTopP{1u << 1}; static UnderlyingType constexpr kBeamSearch{1u << 2}; + static UnderlyingType constexpr kMedusa{1u << 3}; static UnderlyingType constexpr kTopKTopP{kTopK | kTopP}; bool constexpr anyBitSet(UnderlyingType bits) const @@ -117,27 +128,39 @@ static_assert(DecodingMode::None().isNone()); static_assert(!DecodingMode::None().isTopK()); static_assert(!DecodingMode::None().isTopP()); static_assert(!DecodingMode::None().isBeamSearch()); +static_assert(!DecodingMode::None().isMedusa()); static_assert(DecodingMode::TopK().isTopK()); static_assert(DecodingMode::TopK().isTopKorTopP()); static_assert(!DecodingMode::TopK().isTopKandTopP()); static_assert(!DecodingMode::TopK().isTopP()); static_assert(!DecodingMode::TopK().isBeamSearch()); +static_assert(!DecodingMode::TopK().isMedusa()); static_assert(DecodingMode::TopP().isTopP()); static_assert(DecodingMode::TopP().isTopKorTopP()); static_assert(!DecodingMode::TopP().isTopKandTopP()); static_assert(!DecodingMode::TopP().isTopK()); static_assert(!DecodingMode::TopP().isBeamSearch()); +static_assert(!DecodingMode::TopP().isMedusa()); static_assert(DecodingMode::TopKTopP().isTopK()); static_assert(DecodingMode::TopKTopP().isTopP()); static_assert(DecodingMode::TopKTopP().isTopKorTopP()); static_assert(DecodingMode::TopKTopP().isTopKandTopP()); static_assert(!DecodingMode::TopKTopP().isBeamSearch()); +static_assert(!DecodingMode::TopKTopP().isMedusa()); static_assert(DecodingMode::BeamSearch().isBeamSearch()); static_assert(!DecodingMode::BeamSearch().isTopKorTopP()); +static_assert(!DecodingMode::BeamSearch().isMedusa()); + +static_assert(!DecodingMode::Medusa().isTopK()); +static_assert(!DecodingMode::Medusa().isTopKorTopP()); +static_assert(!DecodingMode::Medusa().isTopKandTopP()); +static_assert(!DecodingMode::Medusa().isTopP()); +static_assert(!DecodingMode::Medusa().isBeamSearch()); +static_assert(DecodingMode::Medusa().isMedusa()); } // namespace runtime } // namespace tensorrt_llm diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoder.h b/cpp/include/tensorrt_llm/runtime/gptDecoder.h index 1e369d13b..a201e5518 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoder.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoder.h @@ -117,6 
+117,8 @@ class GptDecoder : public virtual IGptDecoder TensorPtr mLogProbsTiled; // Buffer used to store the transpose of the logProbs. Needed because the kernels have // been written to use that shape. SamplingConfig mSamplingConfig; + + cudaDeviceProp mProp; // Avoid dangling pointers in mDynamicDecodeLayer }; inline std::unique_ptr IGptDecoder::create(DecodingMode const& mode, nvinfer1::DataType dtype, diff --git a/cpp/include/tensorrt_llm/runtime/loraCache.h b/cpp/include/tensorrt_llm/runtime/loraCache.h new file mode 100644 index 000000000..d9bf51ef9 --- /dev/null +++ b/cpp/include/tensorrt_llm/runtime/loraCache.h @@ -0,0 +1,436 @@ +/*loraCac + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/common.h" +#include "tensorrt_llm/runtime/gptModelConfig.h" +#include "tensorrt_llm/runtime/iTensor.h" +#include "tensorrt_llm/runtime/loraCachePageManagerConfig.h" +#include "tensorrt_llm/runtime/loraModule.h" +#include "tensorrt_llm/runtime/worldConfig.h" +#include +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::runtime +{ + +/** + * Holds memory of lora cache pages, and manages allocation and freeing of whole pages. + * Memory is pre-allocated either on the host or device + * + * Note that this class is not thread safe + */ +class LoraCachePageManager +{ +public: + using TensorPtr = ITensor::SharedPtr; + + /** + * \param[in] config: a LoraCachePageManagerConfig + * \param[in] bufferManager: a Buffermanager used to allocate page blocks + */ + LoraCachePageManager(LoraCachePageManagerConfig const& config, BufferManager const& bufferManager); + + /** + * \brief claim pages + * + * \param[in] numPages number of pages to claim + * \returns a tuple, where the first values is a boolean indicating whether pages were claimed. 
If the first value
+     * is true the second value will have a list of pageIds
+     */
+    [[nodiscard]] std::optional> claimPages(SizeType numPages);
+
+    /**
+     * \brief get number of available (free) pages in manager
+     *
+     * \returns number of free pages in manager
+     */
+    [[nodiscard]] SizeType numAvailablePages() const;
+
+    /**
+     * \brief release given pages
+     *
+     * \param[in] pages: list of pages to release (free)
+     */
+    void releasePages(std::vector const& pages);
+
+    /**
+     * \brief return pointer to given page block
+     *
+     * \param[in] blockIdx:
+     * \returns -- pointer to page block
+     */
+    [[nodiscard]] ITensor::SharedConstPtr blockPtr(SizeType blockIdx) const;
+
+    /**
+     * \brief return pointer to given page
+     *
+     * \param[in] pageIdx:
+     * \returns -- const pointer to page
+     */
+    [[nodiscard]] ITensor::SharedConstPtr pagePtr(std::size_t pageIdx) const;
+
+    /**
+     * \brief return pointer to given page
+     *
+     * \param[in] pageIdx:
+     * \returns -- mutable pointer to page
+     */
+    [[nodiscard]] ITensor::SharedPtr mutablePagePtr(std::size_t pageIdx);
+
+private:
+    std::vector mPageBlocks;
+    std::deque mFreePageIds;
+    std::vector mIsPageFree;
+    LoraCachePageManagerConfig const mConfig;
+
+    void initialize(BufferManager const& bufferManager);
+};
+
+/**
+ * LoraCache
+ *
+ * Caches LoRA weights with LRU eviction policy.
+ *
+ * Tasks put in the cache are marked in progress and cannot be evicted until they are marked done.
+ *
+ * A cache page holds an optimally sized LoRA. A page is of size [numSlots x pageWidth]
+ * An optimally sized LoRA is one that has the configured optimalAdapterSize.
+ *
+ * Conceptually a slot corresponds to an r=1, 1-layer, 1-module set of in/out weights.
+ * Page width is set to the number of weights in the smallest module.
+ *
+ * The number of slots per page is then ceilDiv(num weights in optimally sized LoRA, num weights in smallest module)
+ *
+ * Cache pages are allocated on one or more blocks
+ */
+class LoraCache
+{
+public:
+    using TensorPtr = ITensor::SharedPtr;
+    using TaskIdType = std::uint64_t;
+
+    /**
+     * Contains information on a single layer / module.
+     * A list of these configs is associated with each task and can be used to populate runtime tensors.
+     */
+    struct TaskLayerModuleConfig
+    {
+        std::size_t pageId;
+        SizeType slotIdx;
+        SizeType inSize;  // adapterSize * inDim
+        SizeType outSize; // outDim * adapterSize
+        SizeType moduleId;
+        SizeType layerId;
+        SizeType adapterSize;
+        SizeType numSlots; // number of slots used by this layer / module. Used to avoid copying extra data from page.
+
+        // pointer to inWeights cast to an int64_t
+        std::int64_t weightsInPointer;
+        // pointer to out weights cast to an int64_t
+        std::int64_t weightsOutPointer;
+
+        std::string toString() const;
+
+        bool operator==(LoraCache::TaskLayerModuleConfig const& o) const;
+    };
+
+    using TaskLayerModuleConfigListPtr = std::shared_ptr>;
+
+    /**
+     * \param[in] pageManagerConfig: a LoraCachePageManagerConfig
+     * \param[in] modelConfig: a GptModelConfig
+     * \param[in] worldConfig: a WorldConfig
+     * \param[in] bufferManager: a BufferManager only used to allocate page blocks
+     */
+    LoraCache(LoraCachePageManagerConfig const& pageManagerConfig, GptModelConfig const& modelConfig,
+        WorldConfig const& worldConfig, BufferManager const& bufferManager);
+
+    /**
+     * \brief put a task in the cache, claim pages for it, and optionally load task weights.
+ * + * \param[in] taskId: the task id + * \param[in] weights: lora weights tensor + * \param[in] config: lora config tensor + * \param[in] load: if true load weights before returning, otherwise do not + */ + void put(TaskIdType taskId, TensorPtr weights, TensorPtr config, bool load = true); + + /** + * \brief load task weights. This method must be called after put. It is designed to be called asynchronously + * after put returns with load = false + * + * \param[in] taslId: the task id + * \param[in] weights: lora weights tensor + * \param[in] config: lora config tensor + */ + void loadWeights(TaskIdType taskId, TensorPtr weights, TensorPtr config); + + /** + * \param[in] taskId: the task id + * \returns -- true if task is loaded (weights are in place) and false otherwise + */ + [[nodiscard]] inline bool isLoaded(TaskIdType taskId) const + { + std::lock_guard lk(mCacheMutex); + return kVALUE_STATUS_LOADED == getStatus(taskId); + } + + /** + * \param[in] taskId: the task id + * \returns -- true if task is marked done and can be evicted + */ + [[nodiscard]] bool isDone(TaskIdType taskId) const; + + /** + * \param[in] taskId: the task id + * \returns -- true if task is in the cache (not necessarily loaded) and false otherwise + */ + [[nodiscard]] inline bool has(TaskIdType taskId) const + { + std::lock_guard lk(mCacheMutex); + return kVALUE_STATUS_MISSING != getStatus(taskId); + } + + /** + * \param[in] taskId: the task id + * \returns -- list of Value objects with pointers to task weights + */ + [[nodiscard]] std::shared_ptr> get(TaskIdType taskId); + + /** + * \brief bump task and make it the most recently used + * + * \param[in] taskId: the task id + */ + void bump(TaskIdType taskId); + + /** + * \brief mark task done meaning it can be evicted + * \param[in] taskId: the task id + */ + void markTaskDone(TaskIdType taskId); + + /** + * \brief mark all tasks in cache done + */ + void markAllDone(); + + /** + * \param[in] taskId: the taskid + * \returns -- number of pages needed to store the given task + */ + [[nodiscard]] SizeType determineNumPages(TaskIdType taskId) const; + + /** + * \param[in] config: lora config tensor + * \returns -- number of pages needed to store the task configured with config tensor + */ + [[nodiscard]] SizeType determineNumPages(TensorPtr config) const; + + /** + * \param[in] config: a lora config tensor + * \returns -- true in task fits in cache false otherwise + */ + [[nodiscard]] bool fits(TensorPtr config) const; + + /** + * \brief copy task to another cache. Caches must have the same page size. + * \param[in] taskId: the task id to copy + * \param[in] otherCache: the LoraCache to move the task to + * \param[in] markDone: mark the copied task done as it's copied + */ + void copyTask(TaskIdType taskId, LoraCache& deviceCache, bool markDone = false); + + /** + * \returns -- total number of pages allocated to cache (used or not) + */ + [[nodiscard]] SizeType getNumPages() const; + + /** + * \param[in] pageId: the page id + * \returns -- const pointer to page + */ + [[nodiscard]] ITensor::SharedConstPtr getPagePtr(size_t pageId) const; + + /** + * \brief Copy task weights to cache pages. 
+ * \param[in] weights: task weights + * \param[in] config: task config tensor + * \param[in] modelConfig: a GptModelConfig + * \param[in] worldConfig: a WorldConfig + * \param[in] modelIdToModel: map from lora module id to LoraModule + * \param[in] manager: a BufferManager the manager to use to perform the copies + * \param[out] pages: list of page tensors to copy weights to + * \param[in] pageIds: page ids for the pages + * \returns -- list of cache Values objects + */ + static std::vector copyToPages(TensorPtr weights, TensorPtr config, + GptModelConfig const& modelConfig, WorldConfig const& worldConfig, + std::unordered_map moduleIdToModel, BufferManager const& manager, + std::vector const& pages, std::vector const& pageIds); + + /** + * \brief splits second dim of input into tpSize parts and writes the tpRank split to output + * \param[out] output: output tensor + * \param[in] input: input tensor + * \param[in] tpSize: number of splits + * \param[in] tpRank: the split to write to output + */ + static void splitTransposeCpu(ITensor& output, ITensor const& input, SizeType tpSize, SizeType tpRank); + +private: + /** + * \brief Holds configuration and state for a single task + */ + struct TaskValue + { + // pageIds holding this tasks weights + std::vector pageIds; + // locations of weights in pages + TaskLayerModuleConfigListPtr configs; + // ordered location of this value in either mDoneTasks or mInProgressTasks + std::list::iterator it; + + /* indicates if the task is inProgress (in mInProgress list, not evictable) + * if inProgress=false the task is in mDoneTasks list. + */ + bool inProgress; + /* + * indicates the weights have been copied into the cache. + * If inProgress=true and loaded=false we are in the middle of adding the task to the cache. + * We cannot evict or copyTask tasks in this state. + */ + bool loaded; + /** + * Marks a task a done. This is used to mark a task as done during loading. + * if done=true at the end of loading (end of put, loadweights, or copyTask) the task will be marked as done + */ + bool done; + /** + * Indicates weights are loading either in put or loadWeights + * This is used to block concurrent loadWeights calls for the same task. 
+ */ + bool loadInProgress; + + TaskValue() = delete; + ~TaskValue() = default; + + TaskValue(std::vector const& pageIds, TaskLayerModuleConfigListPtr const& configs, + std::list::iterator it, bool inProgress, bool loaded, bool done, bool loadInProgress = false) + : pageIds(pageIds) + , configs(configs) + , it(it) + , inProgress(inProgress) + , loaded(loaded) + , done(done) + , loadInProgress(loadInProgress) + { + } + + TaskValue(TaskValue&& o) noexcept + { + std::swap(pageIds, o.pageIds); + std::swap(configs, o.configs); + std::swap(it, o.it); + std::swap(inProgress, o.inProgress); + std::swap(loaded, o.loaded); + std::swap(done, o.done); + std::swap(loadInProgress, o.loadInProgress); + } + + TaskValue& operator=(TaskValue&& o) + { + std::swap(pageIds, o.pageIds); + std::swap(configs, o.configs); + std::swap(it, o.it); + std::swap(inProgress, o.inProgress); + std::swap(loaded, o.loaded); + std::swap(done, o.done); + std::swap(loadInProgress, o.loadInProgress); + return *this; + } + }; + + using TaskValuePtr = std::shared_ptr; + + enum ValueStatus + { + // task is not in the cache (inProgress or Done) + kVALUE_STATUS_MISSING = 0, + // task is in cache, but weights are not + kVALUE_STATUS_PROCESSING = 1, + // task and weights are in the cache + kVALUE_STATUS_LOADED = 2, + }; + + LoraCachePageManagerConfig mPageManagerConfig; + GptModelConfig mModelConfig; + WorldConfig mWorldConfig; + + // Protects mCachePageManager + mutable std::mutex mPagesMutex; + std::unique_ptr mCachePageManager; + + /* + * Protects mutations of mCacheMap, mInProgressTasks and mDoneTasks + * And the state booleans in TaskValue (ie inProgress, loaded, done, loadInProgress) + * mCacheMutex does not protect other values within a TaskValue (ie weights, pageIds, etc) + */ + mutable std::mutex mCacheMutex; + std::unordered_map mCacheMap; + std::list mInProgressTasks; + std::list mDoneTasks; + + std::vector> mDeviceBufferManagers; + std::unique_ptr mBufferManager; + + std::unordered_map mModuleIdToModule; + + template + static void splitTransposeCpuInner(ITensor& output, ITensor const& input, SizeType tpSize, SizeType tpRank); + + void loadWeights(TaskValue& cacheValue, TensorPtr weights, TensorPtr config); + void bumpTaskInProgress(TaskIdType taskId); + [[nodiscard]] ValueStatus getStatus(TaskIdType taskId) const; + + /** + * \brief claim numPages, evicting tasks if needed + * \param[in] numPages: number of pages to claim + * \returns -- list of page ids + * \throws std::runtime_error if all pages cannot be claimed + */ + [[nodiscard]] std::vector claimPagesWithEvict(SizeType numPages); + + /** + * Internal helper method used inside copyTask. Not thread safe on its own + */ + std::map> copyTaskMapPages(TaskValue& targetTaskValue, + TaskValue const& sourceTaskValue, std::vector const& targetPageIds, LoraCache const& targetCache); +}; + +std::string to_string(LoraCache::TaskLayerModuleConfig const& v); + +std::ostream& operator<<(std::ostream& os, LoraCache::TaskLayerModuleConfig const& v); + +} // namespace tensorrt_llm::runtime diff --git a/cpp/include/tensorrt_llm/runtime/loraCachePageManagerConfig.h b/cpp/include/tensorrt_llm/runtime/loraCachePageManagerConfig.h new file mode 100644 index 000000000..83b19505a --- /dev/null +++ b/cpp/include/tensorrt_llm/runtime/loraCachePageManagerConfig.h @@ -0,0 +1,168 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
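A sketch of the put/load/copy lifecycle of the `LoraCache` described above. The page manager config, model and world configs, buffer manager, tensors, and task id are assumed to be created elsewhere; only the `has`, `put`, `isLoaded`, `copyTask`, `get`, and `markTaskDone` signatures from this header are relied on.

```cpp
// Illustrative lifecycle of a LoraCache entry using the API declared above.
// The caches, tensors, and task id are assumed to come from the caller.
#include "tensorrt_llm/runtime/loraCache.h"

namespace tr = tensorrt_llm::runtime;

void cacheLoraTask(tr::LoraCache& hostCache, tr::LoraCache& deviceCache,
    tr::LoraCache::TaskIdType taskId, tr::LoraCache::TensorPtr weights, tr::LoraCache::TensorPtr config)
{
    if (!hostCache.has(taskId))
    {
        // Claim pages and copy the weights into the host cache in one call.
        hostCache.put(taskId, weights, config, /*load=*/true);
    }

    if (hostCache.isLoaded(taskId) && !deviceCache.has(taskId))
    {
        // Move the task onto the device cache (the caches must share a page size).
        hostCache.copyTask(taskId, deviceCache, /*markDone=*/false);
    }

    // Per-layer / per-module slot locations, usable to fill runtime tensors.
    auto const taskConfigs = deviceCache.get(taskId);
    (void) taskConfigs;

    // Once the request using this task has finished, allow eviction.
    deviceCache.markTaskDone(taskId);
    hostCache.markTaskDone(taskId);
}
```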
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/runtime/common.h" +#include "tensorrt_llm/runtime/iBuffer.h" +#include +#include +#include +#include + +namespace tensorrt_llm::runtime +{ +/** + * Configuration for LoraCachePageManager + * + * See LoraCache docs for description of pages, slots, and page blocks. + */ +class LoraCachePageManagerConfig +{ +public: + explicit constexpr LoraCachePageManagerConfig(runtime::MemoryType memType, nvinfer1::DataType dType, + SizeType totalNumPages, SizeType maxPagesPerBlock, SizeType slotsPerPage, SizeType pageWidth, + SizeType numCopyStreams) + : mMemoryType(memType) + , mDataType(dType) + , mTotalNumPages(totalNumPages) + , mMaxPagesPerBlock(maxPagesPerBlock) + , mSlotsPerPage(slotsPerPage) + , mPageWidth(pageWidth) + , mInitToZero(false) + { + } + + [[nodiscard]] runtime::MemoryType constexpr getMemoryType() const noexcept + { + return mMemoryType; + } + + void constexpr setMemoryType(runtime::MemoryType const& memoryType) noexcept + { + mMemoryType = memoryType; + } + + [[nodiscard]] nvinfer1::DataType constexpr getDataType() const noexcept + { + return mDataType; + } + + void constexpr setDataType(nvinfer1::DataType const& dtype) noexcept + { + mDataType = dtype; + } + + [[nodiscard]] SizeType constexpr getTotalNumPages() const noexcept + { + return mTotalNumPages; + } + + void constexpr setTotalNumPage(SizeType const& totalNumPages) noexcept + { + mTotalNumPages = totalNumPages; + } + + [[nodiscard]] SizeType constexpr getMaxPagesPerBlock() const noexcept + { + return mMaxPagesPerBlock; + } + + void constexpr setMaxPagesPerBlock(SizeType const& maxPagesPerBlock) noexcept + { + mMaxPagesPerBlock = maxPagesPerBlock; + } + + [[nodiscard]] SizeType constexpr getSlotsPerPage() const noexcept + { + return mSlotsPerPage; + } + + void constexpr setSlotsPerPage(SizeType const& slotsPerPage) noexcept + { + mSlotsPerPage = slotsPerPage; + } + + [[nodiscard]] SizeType constexpr getPageWidth() const noexcept + { + return mPageWidth; + } + + void constexpr setPageWidth(SizeType const& pageWidth) noexcept + { + mPageWidth = pageWidth; + } + + [[nodiscard]] bool constexpr getInitToZero() const noexcept + { + return mInitToZero; + } + + void constexpr setInitToZero(bool initToZero) noexcept + { + mInitToZero = initToZero; + } + + [[nodiscard]] SizeType constexpr getNumCopyStreams() const noexcept + { + return mNumCopyStreams; + } + + void constexpr setNumCopyStreams(SizeType numCopyStreams) noexcept + { + mNumCopyStreams = numCopyStreams; + } + +private: + runtime::MemoryType mMemoryType; + nvinfer1::DataType mDataType; + + /* + * Number cache pages in the cache. 
+ * Generally corresponds to the number of opt sized LoRAs that can be stored in the cache + */ + SizeType mTotalNumPages; + // number of pages to allocate in one block + SizeType mMaxPagesPerBlock; + // number of slots per page, where a slot corresponds to a adapterSize=1, 1-layer, 1-module set or weights + SizeType mSlotsPerPage; + SizeType mPageWidth; + + // number of streams used to copy pages to device cache + SizeType mNumCopyStreams = 1; + + bool mInitToZero; // for testing +}; + +inline std::ostream& operator<<(std::ostream& os, LoraCachePageManagerConfig const& c) +{ + os << "{" + << "memoryType=" << static_cast::type>(c.getMemoryType()) + << " dataType=" << static_cast::type>(c.getDataType()) + << " totalNumPages=" << c.getTotalNumPages() << " maxPagesPerBlock=" << c.getMaxPagesPerBlock() + << " slotsPerPage=" << c.getSlotsPerPage() << " pageWidth=" << c.getPageWidth() + << " initToZero=" << c.getInitToZero() << "}"; + return os; +} + +inline std::string to_string(LoraCachePageManagerConfig const& c) +{ + std::stringstream sstream; + sstream << c; + return sstream.str(); +} +} // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/loraModule.h b/cpp/include/tensorrt_llm/runtime/loraModule.h similarity index 77% rename from cpp/tensorrt_llm/runtime/loraModule.h rename to cpp/include/tensorrt_llm/runtime/loraModule.h index c2026542c..81a33f514 100644 --- a/cpp/tensorrt_llm/runtime/loraModule.h +++ b/cpp/include/tensorrt_llm/runtime/loraModule.h @@ -71,6 +71,67 @@ class LoraModule return adapterSize * (mInDim + mOutDim); } + [[nodiscard]] SizeType constexpr inSize(SizeType adapterSize) const noexcept + { + return adapterSize * mInDim; + } + + [[nodiscard]] SizeType constexpr outSize(SizeType adapterSize) const noexcept + { + return adapterSize * mOutDim; + } + + [[nodiscard]] SizeType constexpr localInSize(SizeType adapterSize, SizeType tpSize) const noexcept + { + return localInAdapterSize(adapterSize, tpSize) * localInDim(tpSize); + } + + [[nodiscard]] SizeType constexpr localOutSize(SizeType adapterSize, SizeType tpSize) const noexcept + { + return localOutAdapterSize(adapterSize, tpSize) * localOutDim(tpSize); + } + + [[nodiscard]] SizeType constexpr localInDim(SizeType tpSize) const noexcept + { + if (inTpSplitDim() == 1) + { + return inDim() / tpSize; + } + return inDim(); + } + + [[nodiscard]] SizeType constexpr localOutDim(SizeType tpSize) const noexcept + { + if (outTpSplitDim() == 0) + { + return outDim() / tpSize; + } + return outDim(); + } + + [[nodiscard]] SizeType constexpr localInAdapterSize(SizeType adapterSize, SizeType tpSize) const noexcept + { + if (inTpSplitDim() == 0) + { + return adapterSize / tpSize; + } + return adapterSize; + } + + [[nodiscard]] SizeType constexpr localOutAdapterSize(SizeType adapterSize, SizeType tpSize) const noexcept + { + if (outTpSplitDim() == 1) + { + return adapterSize / tpSize; + } + return adapterSize; + } + + [[nodiscard]] SizeType constexpr localInOutSize(SizeType adapterSize, SizeType tpSize) const noexcept + { + return localInSize(adapterSize, tpSize) + localOutSize(adapterSize, tpSize); + } + [[nodiscard]] SizeType constexpr value() const noexcept { return static_cast(mType); diff --git a/cpp/include/tensorrt_llm/runtime/worldConfig.h b/cpp/include/tensorrt_llm/runtime/worldConfig.h index 22e31d92f..2f33036c4 100644 --- a/cpp/include/tensorrt_llm/runtime/worldConfig.h +++ b/cpp/include/tensorrt_llm/runtime/worldConfig.h @@ -106,13 +106,13 @@ class WorldConfig [[nodiscard]] std::vector 
getPipelineParallelGroup() const; - static bool validConfig(SizeType tensorParallelism, SizeType pipelineParallelism); - static WorldConfig mpi(SizeType gpusPerNode = kDefaultGpusPerNode, std::optional tensorParallelism = std::nullopt, std::optional pipelineParallelism = std::nullopt, std::optional> const& deviceIds = std::nullopt); + [[nodiscard]] bool validMpiConfig() const; + private: SizeType mTensorParallelism; SizeType mPipelineParallelism; diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a index e9a6edcb6..977e4c72d 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c56ee13bb109917ab10df168ca15e6057436df1cd8b64a4268c6e7aae78a5ad8 -size 2126310 +oid sha256:fd8e608359009dffbcc5817cd96531254c3ad13df7030b3b7cdf2d609fea99e1 +size 2408892 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index 8caf13772..7084a8a4a 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:339532215fa4c16e68ca28ee23d0a0e09c9caefa7bd19b563d2f7b83cad6822e -size 2142070 +oid sha256:e59449c78d8682be1f0671fa6d8073c71eb37ae452417b70f70bb7db4a68f48b +size 2434826 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt index 35e4cafb8..4a1c133db 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -c9c505e2cb6e95b7cfc124c04ab1fcb3 libtensorrt_llm_batch_manager_static.a -2f5cec5a5b42e0031bc2edc688c1e74b libtensorrt_llm_batch_manager_static.pre_cxx11.a -741fb083cc42933439ae54557b177b6d7064da4f commit +ae7c209c38b4c343b0fc49decff6fed5 libtensorrt_llm_batch_manager_static.a +f2fdaabe328c0eb1e46e8ded7bec4d87 libtensorrt_llm_batch_manager_static.pre_cxx11.a +d2cce02a8 commit diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a index a2cced8e4..bb6977900 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a4060f2d60472850344e5b5799f9ad88390f4ad9c056e3843f3bdbcc046ca68b -size 2106440 +oid sha256:88e519a38b4172b960083acf12db2ce17c880ce355cc1c9361f1ae85d839551d +size 2377646 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index 0290f5c4e..705661b9a 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ 
b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:829f1ed5af0b0d2577e57fd13979706fe0b3636bd6338aac3c34a615f64afedc -size 2064310 +oid sha256:54199fac4bbe94dc314bed8c889753cbb00d2bad1e672384a350dc2b97e4a0b1 +size 2343620 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt deleted file mode 100644 index 06938d29a..000000000 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt +++ /dev/null @@ -1,2 +0,0 @@ -2db5c985786dad3dd16c22ec54af0803 libtensorrt_llm_batch_manager_static.a -96940249ff7b3ff09754b89ad25fcf9f libtensorrt_llm_batch_manager_static.pre_cxx11.a diff --git a/cpp/tensorrt_llm/common/reduceKernelUtils.cuh b/cpp/tensorrt_llm/common/reduceKernelUtils.cuh index 16f048a7e..979d8dd6f 100644 --- a/cpp/tensorrt_llm/common/reduceKernelUtils.cuh +++ b/cpp/tensorrt_llm/common/reduceKernelUtils.cuh @@ -291,43 +291,46 @@ __inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* template struct TopK { - int p[MAX_K]; - T u[MAX_K]; + int p[MAX_K]; // index, being -1 at the tail if the array is not full + T u[MAX_K]; // value in descend order, being -MAX_T_VAL if the element is invalid - __device__ __forceinline__ void insert(T elem, int elem_id) + __device__ __forceinline__ void insert(T const elem, int const elem_id) { if (elem_id < 0) { return; } - - if (elem > u[MAX_K - 1] || (p[MAX_K - 1] == -1) || ((elem == u[MAX_K - 1]) && (elem_id < p[MAX_K - 1]))) - // if (elem > u[MAX_K-1] || ((elem == u[MAX_K-1]) && (elem_id < p[MAX_K-1]))) + // Condition of updating the array + // 1. array is not full + // 2. elem is greater than the smallest (last) element in the array + // 3. elem is equal to the smallest (last) element in the array but its elem_id is smaller + bool const need_update + = (p[MAX_K - 1] == -1 || elem > u[MAX_K - 1] || elem == u[MAX_K - 1] && elem_id < p[MAX_K - 1]); + if (!need_update) { - u[MAX_K - 1] = elem; - p[MAX_K - 1] = elem_id; + return; } - - for (int k = MAX_K - 2; k >= 0; --k) + // Find suitable index for the new element + int i; + for (i = MAX_K - 2; i >= 0; --i) { - if ((u[k + 1] > u[k]) || (p[k] == -1) || ((u[k + 1] == u[k]) && (p[k + 1] < p[k]))) - // if ((u[k+1] > u[k]) || ((u[k+1] == u[k])&&(p[k+1] < p[k]))) - { - T u2 = u[k]; - int p2 = p[k]; - u[k] = u[k + 1]; - p[k] = p[k + 1]; - u[k + 1] = u2; - p[k + 1] = p2; - } + bool const need_decrease = (p[i] == -1 || elem > u[i] || elem == u[i] && elem_id < p[i]); + if (!need_decrease) + break; + } + // Move elements to correct positions + for (int k = MAX_K - 2; k >= i; --k) + { + p[k + 1] = p[k]; + u[k + 1] = u[k]; } + p[i] = elem_id; + u[i] = elem; } __device__ __forceinline__ void init() { - bool const IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - + T const MAX_T_VAL = (std::is_same::value) ? 
HALF_FLT_MAX : FLT_MAX; for (int i = 0; i < MAX_K; i++) { p[i] = -1; diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a index 69e9c4741..4402bae54 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e1cdcabfbc5115c0d3228c567800d2706f1bc9e3752aaaa8148bcfe83be2c08c -size 716756 +oid sha256:57a1c54097341e561ae44f5ae69fa6a7e33061e2d0451d2f42a37f22993a22bb +size 818584 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index 473be92c1..038d2d773 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ea48a79b211bc9857e7a881d6b9bc22580280e1d7cf3b30d6613466f4f440f8f -size 721934 +oid sha256:3d443d55b92501991a6102c523d46ddfdf620fa5ab37abcee3e2d6ee4c4d9e90 +size 833262 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt index 25ff8fe51..ca136b29c 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -56853a19cf213aa5330ea087c9d86a60 libtensorrt_llm_executor_static.a -213487d55c816a1987aa79547091068f libtensorrt_llm_executor_static.pre_cxx11.a -741fb083cc42933439ae54557b177b6d7064da4f commit +b92b19f8d7eff851dadb8a8e3010a565 libtensorrt_llm_executor_static.a +a546902e11b24c1b890fd913c3e844c5 libtensorrt_llm_executor_static.pre_cxx11.a +d2cce02a8 commit diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a index 6dbc45f08..5d33f7e09 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:499f3aac1b98c5b411f1dacdddf8521b2b1f600388b44e6f7aab5b3f0cdf1280 -size 721366 +oid sha256:9233382570d3c9c5417ed1f279c234d323b4dd465bbdca86612e137fabfb9962 +size 866182 diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index 236314ea7..044de3cae 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9c2c7e84be6b0e8baf296196ee9d7e84509bda2630ce3ada8a39dc498713ff48 -size 700000 +oid sha256:03ee314aa8ca65abf013c6e5106b701defb5c1435d5fe8879829952c1d2cab1f +size 812078 diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt deleted file mode 100644 index 61e1ebf3f..000000000 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt +++ /dev/null @@ -1,2 +0,0 @@ -dcca3b095dad76dac36611be6104f011 
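The rewritten `TopK::insert` above keeps `u[]` sorted in descending order with `p[]` holding the matching indices (`-1` marks empty slots) and breaks ties toward the smaller index. The following is a host-side sketch of the same bookkeeping, simplified to float/int with a hypothetical struct name so the ordering rules can be checked outside CUDA; it is an illustration, not the kernel code.

```
// CPU sketch of fixed-size top-K bookkeeping: u[] sorted descending, p[] holds
// indices (-1 for empty slots), ties broken in favour of the smaller index.
#include <cstdio>

constexpr int kMaxK = 4;
constexpr float kNegInf = -3.402823466e38f; // stand-in for -FLT_MAX

struct TopKHost
{
    int p[kMaxK];
    float u[kMaxK];

    void init()
    {
        for (int i = 0; i < kMaxK; ++i)
        {
            p[i] = -1;
            u[i] = kNegInf;
        }
    }

    // Returns true if (elem, elemId) should be ranked above slot i.
    bool ranksAbove(float elem, int elemId, int i) const
    {
        return p[i] == -1 || elem > u[i] || (elem == u[i] && elemId < p[i]);
    }

    void insert(float elem, int elemId)
    {
        if (elemId < 0 || !ranksAbove(elem, elemId, kMaxK - 1))
        {
            return; // invalid id, or not better than the current tail
        }
        // Walk up from the tail, shifting entries down until the right slot is found.
        int i = kMaxK - 1;
        while (i > 0 && ranksAbove(elem, elemId, i - 1))
        {
            p[i] = p[i - 1];
            u[i] = u[i - 1];
            --i;
        }
        p[i] = elemId;
        u[i] = elem;
    }
};

int main()
{
    TopKHost topk;
    topk.init();
    float const scores[] = {0.1f, 0.7f, 0.3f, 0.7f, 0.9f, 0.2f};
    for (int id = 0; id < 6; ++id)
    {
        topk.insert(scores[id], id);
    }
    for (int i = 0; i < kMaxK; ++i)
    {
        std::printf("rank %d: id=%d score=%.1f\n", i, topk.p[i], topk.u[i]);
    }
    // Expected order: id=4 (0.9), id=1 (0.7), id=3 (0.7), id=2 (0.3)
    return 0;
}
```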
libtensorrt_llm_executor_static.a -6cae7ce493704f7ad8d724cf8a538e2c libtensorrt_llm_executor_static.pre_cxx11.a diff --git a/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.cu b/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.cu index cf4dfa05a..e1b9e0860 100644 --- a/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.cu +++ b/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.cu @@ -36,726 +36,6 @@ namespace tensorrt_llm namespace kernels { -template -__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty) -{ - // score = log(prob) / (length ^ length_penalty) - if (length_penalty == 0.0f || length == 1) - { - return log_prob; - } - return log_prob / static_cast(powf((float) length, length_penalty)); -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ - void beam_topK_kernel(T const* log_probs, int* topk_tmp_id_buf, T* topk_tmp_val_buf, bool const* finished, - int const* sequence_lengths, int const vocab_size, T diversity_rate, float length_penalty) -{ - typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - int thread_id = threadIdx.x; - int block_id = blockIdx.x; // batch beam index. - TopK partial; - - bool const IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - -#pragma unroll - for (int i = 0; i < MAX_K; ++i) - { - partial.p[i] = -1; - partial.u[i] = -MAX_T_VAL; - } - -#pragma unroll - for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) - { - int index = elem_id + block_id * vocab_size; - T score = length_penalty == 0.0f - ? log_probs[index] - : apply_length_penalty(log_probs[index], - finished[block_id] ? sequence_lengths[block_id] : sequence_lengths[block_id] + 1, length_penalty); - partial.insert(score, index); - } - - TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); - - if (thread_id == 0) - { - int index = block_id * MAX_K; - -#pragma unroll - for (int i = 0; i < MAX_K; ++i) - { - topk_tmp_id_buf[index + i] = total.p[i]; - topk_tmp_val_buf[index + i] = total.u[i] + diversity_rate * (T) i; - } - } -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ - void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) -{ - int thread_id = threadIdx.x; - int block_id = blockIdx.x; - bool const IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - TopK partial; - if (thread_id == 0) - { - for (int i = 0; i < MAX_K; ++i) - { - partial.p[i] = -1; - partial.u[i] = -MAX_T_VAL; - } - - int index = block_id * MAX_K * MAX_K; - for (int i = 0; i < MAX_K * MAX_K; i++) - { - partial.insert((T) topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]); - } - - index = block_id * MAX_K; - for (int i = 0; i < MAX_K; i++) - { - id_buf[index + i] = partial.p[i]; - } - } -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ - void batch_topK_kernel_v2(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) -{ - typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - int tid = threadIdx.x; - int bid = blockIdx.x; - TopK partial; - bool const IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; - -#pragma unroll - for (int i = 0; i < MAX_K; ++i) - { - partial.p[i] = -1; - partial.u[i] = -MAX_T_VAL; - } - - int ite = MAX_K * MAX_K / THREADBLOCK_SIZE; -#pragma unroll - for (int i = 0; i < ite; i++) - { - int index = bid * MAX_K * MAX_K + i * THREADBLOCK_SIZE + tid; - partial.insert((T) topk_tmp_val_buf[index], topk_tmp_id_buf[index]); - } - - TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); - - if (tid == 0) - { -#pragma unroll - for (int i = 0; i < MAX_K; i++) - { - id_buf[bid * MAX_K + i] = total.p[i]; - } - } -} - -template -__global__ void topk_stage_1_opt3(T const* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf, - T* topk_tmp_val_buf, bool const* finished, int const* sequence_lengths, int const k, int const vocab_size, - float const length_penalty, int const* end_ids) -{ - typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - int const tid = threadIdx.x; - int const bid = blockIdx.x; - - int const row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index) - int const block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam - int const tmp_log_buf_index = row_id * vocab_size; - int const tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k; - TopK_2 partial; - bool const IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - - if (finished != nullptr && finished[row_id] == true) - { - if (tid < k) - { - int const index = tmp_topk_buf_index + tid; - if (block_lane == 0 && tid == 0) - { - int const end_id = end_ids[row_id / k]; - topk_tmp_id_buf[index] = tmp_log_buf_index + end_id; - topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id]; - } - else - { - topk_tmp_id_buf[index] = -1; - topk_tmp_val_buf[index] = -MAX_T_VAL; - } - } - return; - } - - for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) - { - int index = elem_id + tmp_log_buf_index; - tmp_log_probs[index] = log_probs[index]; - } - - for (int ite = 0; ite < k; ite++) - { - partial.init(); -#pragma unroll - for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; - elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) - { - int index = elem_id + tmp_log_buf_index; - partial.insert(tmp_log_probs[index], index); - } - - TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); - - if (tid == 0) - { - int const index = tmp_topk_buf_index + ite; - topk_tmp_id_buf[index] = total.p; - topk_tmp_val_buf[index] = total.u; - tmp_log_probs[total.p] = -MAX_T_VAL; - } - __syncthreads(); - } -} - -template -__global__ void topk_stage_2_opt3(int const* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids, - BeamHypotheses beam_hyps, int const* end_ids, int const vocab_size, int const k) -{ - int const size = k * k * BLOCKS_PER_BEAM_; - int const tid = threadIdx.x; - int const batch_id = blockIdx.x; - bool const IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - float const length_penalty{beam_hyps.length_penalties == nullptr ? 
1.0f : beam_hyps.length_penalties[batch_id]}; - - typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - extern __shared__ char array[]; - T* s_val = topk_tmp_val_buf + batch_id * size; - int* s_id = (int*) (array); - - __shared__ int selected_beams; - __shared__ bool is_stop; - - if (tid == 0) - { - selected_beams = 0; - is_stop = false; - } - __syncthreads(); - if (beam_hyps.num_beams != nullptr) - { - int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; - if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) - { - // initialize the buffer - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - } - else if (beam_hyps.num_beams[global_batch_idx] == k) - { - return; - } - } - - TopK_2 partial; - - // In some cases, we may encounter k finished sentences, but scores are bad. - // So, the max iteration is 2*k here - for (int ite = 0; ite < 2 * k; ite++) - { - partial.init(); -#pragma unroll - for (int i = tid; i < size; i += BLOCK_SIZE_) - { - partial.insert(s_val[i], i); - } - - TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); - - if (tid == 0) - { - if (beam_hyps.num_beams != nullptr - && topk_tmp_id_buf[batch_id * size + total.p] % vocab_size == end_ids[batch_id]) - { - // if beam_token does not belong to top num_beams tokens, it should not - // be added. Refer from - // https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257 - if (ite >= k) - { - s_val[total.p] = -MAX_T_VAL; - } - else - { - int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; - float const normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty); - int const num_beam = beam_hyps.num_beams[global_batch_idx]; - int beam_idx = num_beam; - // If there are beam_width finished sentences, check that the score of - // selected candidatet is higher than min_normed_score or not. If - // current score is better, replace worst one and update the - // min_normed_score. - if (num_beam == k) - { - if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) - { - // end the tracing and exist this for loop - selected_beams = k; - is_stop = true; - break; - } - else - { - // find the beam index which's score = min_normed_score, erase it. 
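The hypothesis management in the kernels above (and in the updated `batch_topk_kernel` later in this diff) boils down to: keep at most `beam_width` finished hypotheses per request, track the worst normed score, and let a better late hypothesis evict the current worst. A rough CPU analogue of that bookkeeping, with hypothetical names and none of the path-reconstruction details:

```
// Keep at most beamWidth finished hypotheses, tracking the worst ("min") normed
// score so a better hypothesis can replace the current worst. Illustration only.
#include <algorithm>
#include <cstdio>
#include <initializer_list>
#include <vector>

struct FinishedBeams
{
    int beamWidth;
    std::vector<float> normedScores;        // one score per stored hypothesis
    float minNormedScore = 3.402823466e38f; // FLT_MAX until the first hypothesis arrives

    explicit FinishedBeams(int k) : beamWidth(k) {}

    // Returns the slot the new hypothesis was stored in, or -1 if it was rejected.
    int tryAdd(float normedScore)
    {
        if (static_cast<int>(normedScores.size()) < beamWidth)
        {
            normedScores.push_back(normedScore);
            minNormedScore = std::min(minNormedScore, normedScore);
            return static_cast<int>(normedScores.size()) - 1;
        }
        if (normedScore < minNormedScore)
        {
            return -1; // worse than every stored hypothesis: reject
        }
        // Evict the current worst hypothesis and recompute the minimum.
        auto worst = std::min_element(normedScores.begin(), normedScores.end());
        *worst = normedScore;
        minNormedScore = *std::min_element(normedScores.begin(), normedScores.end());
        return static_cast<int>(worst - normedScores.begin());
    }
};

int main()
{
    FinishedBeams beams(2);
    for (float score : {-1.5f, -0.8f, -2.0f, -0.5f})
    {
        int const slot = beams.tryAdd(score);
        std::printf("score %.1f -> slot %d (min is now %.1f)\n", score, slot, beams.minNormedScore);
    }
    return 0;
}
```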
- for (int j = 0; j < k; j++) - { - if (beam_hyps.normed_scores[global_batch_idx * k + j] - == beam_hyps.min_normed_scores[global_batch_idx]) - { - beam_idx = j; - beam_hyps.num_beams[global_batch_idx]--; - - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - beam_hyps.normed_scores[global_batch_idx * k + j] = normed_score; - for (int l = 0; l < k; l++) - { - beam_hyps.min_normed_scores[global_batch_idx] - = min(beam_hyps.min_normed_scores[global_batch_idx], - beam_hyps.normed_scores[global_batch_idx * k + l]); - } - break; - } - } - } - } - int const tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) - * beam_hyps.max_seq_len; - beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id]; - - int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k; - for (int j = beam_hyps.step - 1; j >= 0; j--) - { - int const src_idx = j * beam_hyps.batch_size * k - + beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id; - - beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; - prev_id = beam_hyps.parent_ids_src[src_idx]; - } - int const tgt_beam_idx = global_batch_idx * k + beam_idx; - beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; - beam_hyps.normed_scores[tgt_beam_idx] = normed_score; - beam_hyps.min_normed_scores[global_batch_idx] - = min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]); - - s_val[total.p] = -MAX_T_VAL; - - beam_hyps.num_beams[global_batch_idx]++; - } - } - else - { - s_id[selected_beams] = total.p; - s_val[total.p] = -MAX_T_VAL; - selected_beams++; - } - } - __syncthreads(); - if (selected_beams >= k) - { - break; - } - } - if (tid < k && is_stop == false) - { - ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]]; - } -} - -template -__global__ void topk_stage_1_opt2_general(T const* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf, - T* topk_tmp_val_buf, bool const* finished, int const* sequence_lengths, int const k, int const vocab_size, - float const length_penalty) -{ - bool const IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; - typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - int const tid = threadIdx.x; - int const bid = blockIdx.x; - int const row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs - int const block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam - int const tmp_log_buf_index = row_id * vocab_size; - int const tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k; - TopK_2 partial; - - for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) - { - int index = elem_id + tmp_log_buf_index; - tmp_log_probs[index] = log_probs[index]; - } - - for (int ite = 0; ite < k; ite++) - { - partial.init(); -#pragma unroll - for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) - { - int index = elem_id + tmp_log_buf_index; - partial.insert(tmp_log_probs[index], index); - } - - TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); - - if (tid == 0) - { - int const index = tmp_topk_buf_index + ite; - topk_tmp_id_buf[index] = total.p; - topk_tmp_val_buf[index] = total.u; - tmp_log_probs[total.p] = -MAX_T_VAL; - } - __syncthreads(); - } -} - -template -__global__ void topk_stage_2_opt2_general(int const* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids, - BeamHypotheses beam_hyps, int const* end_ids, int const k, int const vocab_size) -{ - int const size = k * k * BLOCKS_PER_BEAM; - int const tid = threadIdx.x; - int const batch_id = blockIdx.x; - bool const IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]}; - - typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - extern __shared__ char array[]; - T* s_val = topk_tmp_val_buf + batch_id * size; - int* s_id = (int*) (array); - - __shared__ int selected_beams; - __shared__ bool is_stop; - - if (tid == 0) - { - selected_beams = 0; - is_stop = false; - } - __syncthreads(); - if (beam_hyps.num_beams != nullptr) - { - int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; - if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) - { - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - } - else if (beam_hyps.num_beams[global_batch_idx] == k) - { - return; - } - } - - TopK_2 partial; - - // In some cases, we may encounter k finished sentences, but scores are bad. - // So, the max iteration is 2*k here - for (int ite = 0; ite < 2 * k; ite++) - { - partial.init(); -#pragma unroll - for (int i = tid; i < size; i += BLOCK_SIZE) - { - partial.insert(s_val[i], i); - } - - TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); - - if (tid == 0) - { - if (beam_hyps.num_beams != nullptr - && topk_tmp_id_buf[batch_id * size + total.p] % vocab_size == end_ids[batch_id]) - { - // if beam_token does not belong to top num_beams tokens, it should not - // be added. 
Refer from - // https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257 - if (ite >= k) - { - s_val[total.p] = -MAX_T_VAL; - } - else - { - int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; - float const normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty); - int const num_beam = beam_hyps.num_beams[global_batch_idx]; - int beam_idx = num_beam; - // If there are beam_width finished sentences, check that the score of - // selected candidatet is higher than min_normed_score or not. If - // current score is better, replace worst one and update the - // min_normed_score. - if (num_beam == k) - { - if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) - { - // end the tracing and exist this for loop - selected_beams = k; - is_stop = true; - break; - } - else - { - // find the beam index which's score = min_normed_score, erase it. - for (int j = 0; j < k; j++) - { - if (beam_hyps.normed_scores[global_batch_idx * k + j] - == beam_hyps.min_normed_scores[global_batch_idx]) - { - beam_idx = j; - beam_hyps.num_beams[global_batch_idx]--; - - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - beam_hyps.normed_scores[global_batch_idx * k + j] = normed_score; - for (int l = 0; l < k; l++) - { - beam_hyps.min_normed_scores[global_batch_idx] - = min(beam_hyps.min_normed_scores[global_batch_idx], - beam_hyps.normed_scores[global_batch_idx * k + l]); - } - break; - } - } - } - } - int const tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) - * beam_hyps.max_seq_len; - beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id]; - - int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k; - for (int j = beam_hyps.step - 1; j >= 0; j--) - { - int const src_idx = j * beam_hyps.batch_size * k - + beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id; - - beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; - prev_id = beam_hyps.parent_ids_src[src_idx]; - } - int const tgt_beam_idx = global_batch_idx * k + beam_idx; - beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; - beam_hyps.normed_scores[tgt_beam_idx] = normed_score; - beam_hyps.min_normed_scores[global_batch_idx] - = min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]); - - s_val[total.p] = -MAX_T_VAL; - - beam_hyps.num_beams[global_batch_idx]++; - } - } - else - { - s_id[selected_beams] = total.p; - s_val[total.p] = -MAX_T_VAL; - selected_beams++; - } - } - __syncthreads(); - if (selected_beams >= k) - { - break; - } - } - if (tid < k && is_stop == false) - { - ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]]; - } -} - -#define CASE_K_DIV(K, BLOCK_SIZE_1, BLOCK_SIZE_2) \ - case K: \ - beam_topK_kernel<<>>(log_probs, \ - topk_tmp_id_buf, topk_tmp_val_buf, finished, sequence_lengths, vocab_size, diversity_rate, \ - length_penalty); \ - if (K < 10) \ - batch_topK_kernel \ - <<>>(topk_tmp_id_buf, topk_tmp_val_buf, ids); \ - else \ - batch_topK_kernel_v2<<>>(topk_tmp_id_buf, topk_tmp_val_buf, ids); \ - break; - -#define CASE_K(K, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \ - case K: \ - topk_stage_1_opt3 \ - <<>>(log_probs, temp_log_probs, \ - topk_tmp_id_buf, topk_tmp_val_buf, finished, sequence_lengths, beam_width, vocab_size, length_penalty, \ - end_ids); \ - topk_stage_2_opt3 \ - <<>>( \ - topk_tmp_id_buf, topk_tmp_val_buf, ids, 
*beam_hyps, end_ids, vocab_size, beam_width); \ - sync_check_cuda_error(); \ - break; - -template -void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps, - bool const* finished, int const* sequence_lengths, int const batch_size, int const beam_width, - int const vocab_size_padded_, const T diversity_rate, float const length_penalty, int const* end_ids, - cudaStream_t stream) -{ - // log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a - // token. - int const vocab_size = vocab_size_padded_; - // Beam size should be less than or equal to vocab size. - assert(beam_width <= vocab_size); - // Beam search needs the sequence lengths of beams to apply length penalty. - assert(length_penalty == 0.0f || sequence_lengths != nullptr); - int const max_block_per_beam = 8; - int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float - int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int - int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float - - // prevent memory misaligned address - temp_log_probs_buf_size = (int) (ceil(temp_log_probs_buf_size / 4.)) * 4; - topk_tmp_ids_buf_size = (int) (ceil(topk_tmp_ids_buf_size / 4.)) * 4; - topk_tmp_val_buf_size = (int) (ceil(topk_tmp_val_buf_size / 4.)) * 4; - - if (workspace == nullptr) - { - workspace_size = sizeof(float) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size - + sizeof(float) * topk_tmp_val_buf_size; - return; - } - else - { - T* temp_log_probs = (T*) workspace; - int* topk_tmp_id_buf = (int*) (temp_log_probs + temp_log_probs_buf_size); - T* topk_tmp_val_buf = (T*) (topk_tmp_id_buf + topk_tmp_ids_buf_size); - if (diversity_rate == 0.0f) - { - switch (beam_width) - { - CASE_K(1, 128, 128, 8); - CASE_K(4, 128, 128, 8); - CASE_K(10, 128, 128, 8); - CASE_K(16, 128, 128, 5); - CASE_K(32, 256, 128, 1); - CASE_K(64, 256, 256, 1); - default: - topk_stage_1_opt2_general<<>>(log_probs, - temp_log_probs, topk_tmp_id_buf, topk_tmp_val_buf, finished, sequence_lengths, beam_width, - vocab_size, length_penalty); - topk_stage_2_opt2_general - <<>>(topk_tmp_id_buf, topk_tmp_val_buf, ids, *beam_hyps, end_ids, beam_width, vocab_size); - break; - } - } - else - { - switch (beam_width) - { - CASE_K_DIV(1, 256, 256); - CASE_K_DIV(4, 256, 256); - CASE_K_DIV(16, 256, 64); - CASE_K_DIV(32, 256, 64); - CASE_K_DIV(64, 256, 64); - default: TLLM_THROW("Topk kernel does not support beamwidth = %d \n", beam_width); - } - } - return; - } -} - -#undef CASE_K -#undef CASE_K_DIV - -template void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, float* log_probs, int* ids, - BeamHypotheses* beam_hyps, bool const* finished, int const* sequence_lengths, int const batch_size, - int const beam_width, int const vocab_size_padded_, float const diversity_rate, float const length_penalty, - int const* end_ids, cudaStream_t stream); - -template -__global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length, T const* output, - int const* sequence_length, const uint32_t batch_size, const uint32_t beam_width, const uint32_t d_model) -{ - if (blockIdx.x == 0) - { - for (uint32_t i = threadIdx.x; i < batch_size * beam_width; i += blockDim.x) - { - tiled_sequence_length[i] = sequence_length[i / beam_width]; - } - } - - int tgt_offset - = blockIdx.x * gridDim.y * gridDim.z * d_model + blockIdx.y * gridDim.z * d_model + blockIdx.z * d_model; - int src_offset = 
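The `invokeTopkBeamSearch` implementation removed above (its declaration survives in the header) follows the common two-call CUDA workspace convention: a first call with `workspace == nullptr` only reports the required byte count. A generic sketch of that calling pattern with a hypothetical launcher, not the TensorRT-LLM API:

```
// Two-call workspace pattern: query the size, allocate, then run for real.
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical launcher: when workspace == nullptr it only reports the size it needs.
void launchSomething(void* workspace, std::size_t& workspaceSize, int batchSize, int beamWidth)
{
    std::size_t const required = static_cast<std::size_t>(batchSize) * beamWidth * sizeof(float);
    if (workspace == nullptr)
    {
        workspaceSize = required;
        return;
    }
    // ... real work would use the workspace here ...
    std::printf("running with a %zu-byte workspace\n", workspaceSize);
}

int main()
{
    std::size_t workspaceSize = 0;
    launchSomething(nullptr, workspaceSize, /*batchSize=*/8, /*beamWidth=*/4); // size query
    std::vector<unsigned char> workspace(workspaceSize);                       // allocation
    launchSomething(workspace.data(), workspaceSize, 8, 4);                    // real call
    return 0;
}
```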
blockIdx.x * gridDim.z * d_model + blockIdx.z * d_model; - for (uint32_t i = threadIdx.x; i < d_model; i += blockDim.x) - { - tiled_output[i + tgt_offset] = output[i + src_offset]; - } -} - -template -void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, T const* output, int const* sequence_length, - const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model, - cudaStream_t stream) -{ - // tiled_output: [batch_size, beam_width, mem_max_seq_len, d_model] - // tiled_sequence_length: [batch_size, beam_width] - - // output: [batch_size, mem_max_seq_len, d_model] - // sequence_length [batch_size] - - dim3 grid(batch_size, beam_width, mem_max_seq_len); - bool is_half2 = (std::is_same::value) && (d_model % 2 == 0); - - if (is_half2) - { - using T2 = typename TypeConverter::Type; // fp16 to half2, bf16 to bf162 - dim3 block(min(512, (int) (d_model / 2))); - tileEncoderResults<<>>((T2*) tiled_output, tiled_sequence_length, - (const T2*) output, sequence_length, batch_size, beam_width, d_model / 2); - } - else - { - dim3 block(min(512, (int) d_model)); - tileEncoderResults<<>>( - tiled_output, tiled_sequence_length, output, sequence_length, batch_size, beam_width, d_model); - } -} - -template void invokeTileEncoderResults(float* tiled_output, int* tiled_sequence_length, float const* output, - int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, - const size_t d_model, cudaStream_t stream); - -template void invokeTileEncoderResults(half* tiled_output, int* tiled_sequence_length, half const* output, - int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, - const size_t d_model, cudaStream_t stream); - -template void invokeTileEncoderResults(half2* tiled_output, int* tiled_sequence_length, half2 const* output, - int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, - const size_t d_model, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, int* tiled_sequence_length, - __nv_bfloat16 const* output, int const* sequence_length, const size_t batch_size, const size_t beam_width, - const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream); -#endif - __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, float const* cum_log_probs, int const batch_size, int const beam_width) { @@ -815,75 +95,5 @@ void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* f { insertUnfinishedPath<<>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width); } - -__global__ void copyBatchMajorToGeneralPtr( - void* output_ids_ptr, int* output_ids, int batch_size, int beam_width, int max_seq_len) -{ - // output_ids_ptr: batch_size int*, each int* has [beam_width, max_seq_len] - // output_ids: [max_seq_len, batch, beam] - int** output_ids_int_ptr = (int**) output_ids_ptr; - for (int idx = threadIdx.x; idx < beam_width * max_seq_len; idx += blockDim.x) - { - auto const src_step = idx % max_seq_len; - auto const src_beam_idx = idx / max_seq_len; - output_ids_int_ptr[blockIdx.x][idx] - = output_ids[src_step * batch_size * beam_width + blockIdx.x * beam_width + src_beam_idx]; - } -} - -void invokeCopyBatchMajorToGeneralPtr( - void* output_ids_ptr, int* output_ids, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream) -{ - copyBatchMajorToGeneralPtr<<>>( - 
output_ids_ptr, output_ids, batch_size, beam_width, max_seq_len); -} - -__global__ void copyGeneralPtrToBatchMajor( - int* output_ids, void* output_ids_ptr, int batch_size, int beam_width, int max_seq_len) -{ - // output_ids_ptr: batch_size int*, each int* has [beam_width, max_seq_len] - // output_ids: [max_seq_len, batch, beam] - int** output_ids_int_ptr = (int**) output_ids_ptr; - for (int idx = threadIdx.x; idx < beam_width * max_seq_len; idx += blockDim.x) - { - auto const tgt_step = idx % max_seq_len; - auto const tgt_beam_idx = idx / max_seq_len; - output_ids[tgt_step * batch_size * beam_width + blockIdx.x * beam_width + tgt_beam_idx] - = output_ids_int_ptr[blockIdx.x][idx]; - } -} - -void invokeCopyGeneralPtrToBatchMajor( - int* output_ids, void* output_ids_ptr, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream) -{ - copyGeneralPtrToBatchMajor<<>>( - output_ids, output_ids_ptr, batch_size, beam_width, max_seq_len); -} - -__global__ void SeqlenMajorToBatchMajor( - int* batchMajoredIds, int* seqlenMajorIds, int batch_size, int beam_width, int max_seq_len) -{ - for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch_size * beam_width * max_seq_len; - idx += gridDim.x * blockDim.x) - { - auto tmp_idx{idx}; - auto const beam_idx{tmp_idx % beam_width}; - tmp_idx = (tmp_idx - beam_idx) / beam_width; - auto const batch_idx{tmp_idx % batch_size}; - tmp_idx = (tmp_idx - batch_idx) / batch_size; - auto const seqlen_idx{tmp_idx % max_seq_len}; - - batchMajoredIds[batch_idx * beam_width * max_seq_len + beam_idx * max_seq_len + seqlen_idx] - = seqlenMajorIds[idx]; - } -} - -void invokeSeqlenMajorToBatchMajor( - int* batchMajoredIds, int* seqlenMajorIds, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream) -{ - SeqlenMajorToBatchMajor<<>>( - batchMajoredIds, seqlenMajorIds, batch_size, beam_width, max_seq_len); -} - } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.h b/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.h index 8627808de..95e74092b 100644 --- a/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.h +++ b/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.h @@ -23,24 +23,16 @@ namespace tensorrt_llm { namespace kernels { - -// In original beam search implementation, if a beam is finished, we set it as -// finished and only continue to do beam search on remain beams (namely, -// beam_width - 1 beams in next step) -// -// In this implementation, when a beam is finished, we trace the path and record -// it in output_ids_tgt, and also record the normalized scores. And the beam -// search continue to use `beam_width` beams in next step. -// -// After we collect `beam_width` beams, we will sort them by their norm_scores. 
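The normed scores used in this bookkeeping come from the length-penalty helper that this patch moves into `beamSearchTopkKernels.h`: `score = log_prob / length^length_penalty`, with the penalty skipped when `length_penalty == 0` or `length == 1`. A small host-side restatement of that formula, for illustration:

```
// score = log(prob) / (length ^ length_penalty), as in apply_length_penalty.
#include <cmath>
#include <cstdio>

float applyLengthPenaltyHost(float logProb, int length, float lengthPenalty)
{
    if (lengthPenalty == 0.0f || length == 1)
    {
        return logProb;
    }
    return logProb / std::pow(static_cast<float>(length), lengthPenalty);
}

int main()
{
    // A cumulative log-prob of -6.0 over 9 generated tokens with penalty 1.0
    // normalizes to -6.0 / 9 = -0.667 (rounded), favouring longer finished beams.
    std::printf("%.3f\n", applyLengthPenaltyHost(-6.0f, 9, 1.0f));
    std::printf("%.3f\n", applyLengthPenaltyHost(-6.0f, 9, 0.0f)); // penalty disabled
    return 0;
}
```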
+// We keep tracing `beam_width` beams during iterations, once a beam is finished, +// we record the ids and its normed score in output_ids_tgt and normed_scores struct BeamHypotheses { // BS: batch_size // BM: beam_width // mSL: max_seq_length - // %%: parameter name when we call [generation.py] dynamic_decoder.forward + // %%: parameter name when we call [generation.py] dynamic_decoder.forward (python workflow) - // Pointers initialized in these two functions: + // Pointers initialized in these two functions below: // [gptDecoder.cpp] GptDecoder::forward or [dynamicDecodeOp.cpp] FtDynamicDecode::forward bool* is_done{nullptr}; // [BS] %% self.beam_hyps_is_done float* cum_log_probs{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs @@ -48,7 +40,7 @@ struct BeamHypotheses float* min_normed_scores{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores float* normed_scores{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores int* num_beams{nullptr}; // [BS] %% self.beam_hyps_num_beams - int* output_ids_tgt{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_is_done + int* output_ids_tgt{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_output_ids_tgt int* sequence_lengths_tgt{nullptr}; // [BS, BM*2] %% self.beam_hyps_sequence_lengths_tgt int const* input_lengths{nullptr}; // [BS*BM] %% context_length @@ -58,12 +50,13 @@ struct BeamHypotheses float* cum_log_probs_src{nullptr}; // [BS, BM] %% self.cum_log_probs float* log_probs_src{nullptr}; // [mSL, BS, BM] %% self.log_probs_tiled int* sequence_lengths_src{nullptr}; // [BS*BM] %% self.sequence_length_buffer - int** output_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] from [dynamicDecodeLayer.cpp] - int** parent_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] from [dynamicDecodeLayer.cpp] + // These two pointers are relocated in [dynamicDecodeLayer.cpp] DynamicDecodeLayer::prepareIdsPtrs + int** output_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] %% self.output_ids + int** parent_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] %% self.parent_ids - float* diversity_rates{nullptr}; // [BS] from SamplingConfig - float* length_penalties{nullptr}; // [BS] from SamplingConfig - int* early_stoppings{nullptr}; // [BS] from SamplingConfig + float* diversity_rates{nullptr}; // [BS] from SamplingConfig + float* length_penalties{nullptr}; // [BS] from SamplingConfig + int* early_stoppings{nullptr}; // [BS] from SamplingConfig // Pointers for function gatherTree int const* output_ids_src{nullptr}; // @@ -80,28 +73,24 @@ struct BeamHypotheses int vocab_size{0}; // vocab_size_padded }; +template +__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty) +{ + // score = log(prob) / (length ^ length_penalty) + if (length_penalty == 0.0f || length == 1) + { + return log_prob; + } + return log_prob / static_cast(powf(static_cast(length), length_penalty)); +} + template void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps, bool const* finished, int const* sequence_lengths, int const batch_size, int const beam_width, int const vocab_size_padded_, const T diversity_rate, float const length_penalty, int const* end_ids, cudaStream_t stream); -template -void invokeTileEncoderResults(T* tiled_encoder_output, int* tiled_encoder_sequence_length, T const* encoder_output, - int const* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, - const size_t d_model, cudaStream_t stream); - void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, 
FinishedState const* finished, float const* cum_log_probs, int const batch_size, int const beam_width, cudaStream_t stream); - -void invokeCopyBatchMajorToGeneralPtr( - void* output_ids_ptr, int* output_ids, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream); - -void invokeCopyGeneralPtrToBatchMajor( - int* output_ids, void* output_ids_ptr, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream); - -void invokeSeqlenMajorToBatchMajor( - int* batchMajoredIds, int* seqlenMajorIds, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream); - } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/decodingKernels.cu b/cpp/tensorrt_llm/kernels/decodingKernels.cu index 8f640c73c..11c84dcf3 100644 --- a/cpp/tensorrt_llm/kernels/decodingKernels.cu +++ b/cpp/tensorrt_llm/kernels/decodingKernels.cu @@ -457,27 +457,39 @@ void invokeInitializeOutput(int* outputIds, int const* endIds, int batchBeam, in initializeOutput<<>>(outputIds, endIds, maxSeqLen); } -__global__ void copyNextStepIds(int* nextStepIds, int** outputIdsPtr, int const* sequenceLengths, int const* batchSlots, - int batchSize, int beamWidth, int maxSeqLen) +__global__ void copyNextStepIds(TokenIdType* nextStepIds, TokenIdType const* const* outputIdsPtr, + SizeType const* sequenceLengths, SizeType const* numNewTokens, SizeType const* batchSlots, SizeType batchSize, + SizeType maxBatchSize, SizeType beamWidth, SizeType maxSeqLen, SizeType maxTokensPerStep) { - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batchSize * beamWidth; - index += blockDim.x * gridDim.x) + for (auto index = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + index < batchSize * beamWidth * maxTokensPerStep; index += static_cast(blockDim.x * gridDim.x)) { - int const batchIdx{index / beamWidth}; + auto const batchIdx{index / (beamWidth * maxTokensPerStep)}; auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx; - int const beamIdx{index % beamWidth}; + auto const remainder{index % (beamWidth * maxTokensPerStep)}; + auto const beamIdx{remainder / maxTokensPerStep}; + auto const tokenIdx{remainder % maxTokensPerStep}; + auto const newTokens = numNewTokens == nullptr ? 
1 : numNewTokens[batchSlot]; auto const batchBeamIdx = batchSlot * beamWidth + beamIdx; - nextStepIds[batchBeamIdx] = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + sequenceLengths[batchBeamIdx] - 1]; + auto const tokenBatchBeamIdx = tokenIdx * maxBatchSize * beamWidth + batchSlot * beamWidth + beamIdx; + if (tokenIdx >= newTokens) + { + continue; + } + nextStepIds[tokenBatchBeamIdx] + = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + sequenceLengths[batchBeamIdx] - newTokens + tokenIdx]; } } -void invokeCopyNextStepIds(int* nextStepIds, int** outputIdsPtr, int const* sequenceLengths, int const* batchSlots, - int batchSize, int beamWidth, int maxSeqLen, cudaStream_t stream) +void invokeCopyNextStepIds(TokenIdType* nextStepIds, TokenIdType const* const* outputIdsPtr, + SizeType const* sequenceLengths, SizeType const* numNewTokens, SizeType const* batchSlots, SizeType batchSize, + SizeType maxBatchSize, SizeType beamWidth, SizeType maxSeqLen, SizeType maxTokensPerStep, cudaStream_t stream) { - dim3 block(min(256, batchSize * beamWidth)); - dim3 grid(divUp(batchSize * beamWidth, block.x)); - copyNextStepIds<<>>( - nextStepIds, outputIdsPtr, sequenceLengths, batchSlots, batchSize, beamWidth, maxSeqLen); + auto const numElems = batchSize * beamWidth * maxTokensPerStep; + dim3 block(min(256, numElems)); + dim3 grid(divUp(numElems, block.x)); + copyNextStepIds<<>>(nextStepIds, outputIdsPtr, sequenceLengths, numNewTokens, batchSlots, + batchSize, maxBatchSize, beamWidth, maxSeqLen, maxTokensPerStep); } __global__ void transposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, int const* sequenceLengths, @@ -723,14 +735,14 @@ __device__ __forceinline__ int4 reduceMaxInt4(int4 const& a, int4 const& b) template __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds, - SizeType* sequenceLengths, FinishedState* finishedFinal, SizeType const* batchSlots, SizeType const* paths, - TokenIdType const* endIds, T const* medusaLogits, T const** logitsPtrs, SizeType batchSize, SizeType vocabSize, - SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads, + SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots, + SizeType const* paths, TokenIdType const* endIds, T const* medusaLogits, T const** logitsPtrs, SizeType batchSize, + SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads, SizeType maxTokensPerStep) { auto const batchIdx = static_cast(blockIdx.x); auto const batchSlot = batchSlots == nullptr ? 
batchIdx : batchSlots[batchIdx]; - auto& inputLength = sequenceLengths[batchSlot]; + auto const inputLength = sequenceLengths[batchSlot]; auto const endId = endIds[batchSlot]; auto const maxNumDraftTokens = maxNumHeads + 1; @@ -742,8 +754,19 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT auto acceptedLength = maxNumDraftTokens; auto const pathOffset = flat_index3(batchSlot, pathIdx, 0, maxTokensPerStep, maxNumDraftTokens); bool hasEnd = false; + + auto const tokenId = paths[pathOffset]; + // Continue if path does not exist + if (tokenId == -1) + { + continue; + } + auto const targetTokenIdx = batchSlot * maxTargetSeqLen + tokenId; + auto targetToken = targetIds[targetTokenIdx]; + auto nextIdx = tokenId; + // Go along the path - for (SizeType ti = 0; ti < maxNumDraftTokens; ++ti) + for (SizeType ti = 1; ti < maxNumDraftTokens; ++ti) { auto const tokenId = paths[pathOffset + ti]; // Break if path terminates @@ -755,16 +778,18 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT auto const targetTokenIdx = batchSlot * maxTargetSeqLen + tokenId; auto const draftTokenIdx = batchSlot * maxDraftSeqLen + inputLength + tokenId; auto const draftToken = outputIds[draftTokenIdx]; - auto const targetToken = targetIds[targetTokenIdx]; // Check if draft tokens are the same as target tokens bool const accepted = draftToken == targetToken; hasEnd = targetToken == endId; if (!accepted || hasEnd) { - acceptedLength = hasEnd ? ti : ti + 1; + acceptedLength = hasEnd ? ti - 1 : ti; + nextIdx = tokenId; break; } + targetToken = targetIds[targetTokenIdx]; + nextIdx = tokenId; } // Get longest path of the thread if (partialMax.x < acceptedLength) @@ -772,6 +797,7 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT partialMax.x = acceptedLength; partialMax.y = pathIdx; partialMax.z = hasEnd; + partialMax.w = nextIdx; } } @@ -790,6 +816,7 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT auto const acceptedLength = totalShared.x; auto const bestPathIdx = totalShared.y; + auto const bestNextIdx = totalShared.w; auto const pathOffset = flat_index3(batchSlot, bestPathIdx, 0, maxTokensPerStep, maxNumDraftTokens); for (auto ti = static_cast(threadIdx.x); ti < acceptedLength; ti += static_cast(blockDim.x)) { @@ -801,8 +828,6 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT outputIds[draftDstTokenIdx] = targetToken; } - __syncthreads(); - // Leading thread reconstructs winning path and sets new data if (threadIdx.x == 0) { @@ -813,42 +838,43 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT finishedFinal[batchSlot].setFinishedEOS(); } // Make correction to the sequence length - inputLength += acceptedLength; + sequenceLengths[batchSlot] += acceptedLength; + acceptedLengths[batchSlot] = acceptedLength; } // Prepare logits pointers to respective logits from Medusa Heads for the all-top-K sampling kernel for (auto hi = static_cast(threadIdx.x); hi < maxNumHeads; hi += static_cast(blockDim.x)) { logitsPtrs[batchIdx * maxNumHeads + hi] - = medusaLogits + flat_index4(hi, batchIdx, acceptedLength, 0, maxBatchSize, maxTokensPerStep, vocabSize); + = medusaLogits + flat_index4(hi, batchIdx, bestNextIdx, 0, maxBatchSize, maxTokensPerStep, vocabSize); } } template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds, SizeType* sequenceLengths, - FinishedState* finishedFinal, SizeType const* 
batchSlots, SizeType const* paths, TokenIdType const* endIds, - T const* medusaLogits, T const** logitsPtrs, SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, - SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads, SizeType maxTokensPerStep, - cudaStream_t stream) + SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots, SizeType const* paths, + TokenIdType const* endIds, T const* medusaLogits, T const** logitsPtrs, SizeType batchSize, SizeType vocabSize, + SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads, + SizeType maxTokensPerStep, cudaStream_t stream) { constexpr SizeType BLOCK_SIZE = 256; dim3 block(BLOCK_SIZE); dim3 grid(batchSize); acceptDraftTokensByIdsWithPaths<<>>(outputIds, targetIds, sequenceLengths, - finishedFinal, batchSlots, paths, endIds, medusaLogits, logitsPtrs, batchSize, vocabSize, maxBatchSize, - maxDraftSeqLen, maxTargetSeqLen, maxNumHeads, maxTokensPerStep); + acceptedLengths, finishedFinal, batchSlots, paths, endIds, medusaLogits, logitsPtrs, batchSize, vocabSize, + maxBatchSize, maxDraftSeqLen, maxTargetSeqLen, maxNumHeads, maxTokensPerStep); } template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds, - SizeType* sequenceLengths, FinishedState* finishedFinal, SizeType const* batchSlots, SizeType const* paths, - TokenIdType const* endIds, float const* medusaLogits, float const** logitsPtrs, SizeType batchSize, - SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads, - SizeType maxTokensPerStep, cudaStream_t stream); + SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots, + SizeType const* paths, TokenIdType const* endIds, float const* medusaLogits, float const** logitsPtrs, + SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, + SizeType maxNumHeads, SizeType maxTokensPerStep, cudaStream_t stream); template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds, - SizeType* sequenceLengths, FinishedState* finishedFinal, SizeType const* batchSlots, SizeType const* paths, - TokenIdType const* endIds, half const* medusaLogits, half const** logitsPtrs, SizeType batchSize, - SizeType vocabSize, SizeType maxBatchSize, int32_t maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads, - SizeType maxTokensPerStep, cudaStream_t stream); + SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots, + SizeType const* paths, TokenIdType const* endIds, half const* medusaLogits, half const** logitsPtrs, + SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, int32_t maxDraftSeqLen, SizeType maxTargetSeqLen, + SizeType maxNumHeads, SizeType maxTokensPerStep, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/decodingKernels.h b/cpp/tensorrt_llm/kernels/decodingKernels.h index ba2a891ca..554a2d70d 100644 --- a/cpp/tensorrt_llm/kernels/decodingKernels.h +++ b/cpp/tensorrt_llm/kernels/decodingKernels.h @@ -64,8 +64,29 @@ void invokeFinalize(int32_t* outputIds, int32_t* sequenceLengths, float* cumLogP void invokeInitializeOutput( int32_t* outputIds, int32_t const* endIds, int batchBeam, int maxSeqLen, cudaStream_t stream); -void invokeCopyNextStepIds(int32_t* nextStepIds, int32_t** outputIdsPtr, 
int32_t const* sequenceLengths, - int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth, int32_t maxSeqLen, cudaStream_t stream); +//! \brief Copies last numNewTokens (or 1 if numNewTokens == nullptr) tokens from outputIdsPtr +//! to nextStepIds according to sequenceLengths. +//! +//! \param nextStepIds output buffer [maxTokensPerStep, maxBatchSize, maxBeamWidth], +//! destination of the new tokens. +//! \param outputIdsPtr input buffer [maxBatchSize][maxBeamWidth, maxSeqLen], +//! array of pointers to the source of the copy. +//! \param sequenceLengths input buffer [maxBatchSize], sequence length of the request +//! in outputIdsPtr that includes all new tokens. It must be guaranteed that sequenceLengths <= maxSeqLen. +//! \param numNewTokens input buffer [maxBatchSize], optional, number of tokens to be copied. +//! If nullptr, only 1 token is copied. It must be guaranteed that numNewTokens <= sequenceLengths. +//! \param batchSlots input buffer [batchSize], address map from local index +//! to global index [0, batchSize] -> [0, maxBatchSize] +//! \param batchSize current batch size +//! \param maxBatchSize maximum batch size +//! \param beamWidth current beam width +//! \param maxSeqLen maximum sequence length +//! \param maxTokensPerStep maximum tokens per step +//! \param stream stream +void invokeCopyNextStepIds(runtime::TokenIdType* nextStepIds, runtime::TokenIdType const* const* outputIdsPtr, + runtime::SizeType const* sequenceLengths, runtime::SizeType const* numNewTokens, + runtime::SizeType const* batchSlots, runtime::SizeType batchSize, runtime::SizeType maxBatchSize, + runtime::SizeType beamWidth, runtime::SizeType maxSeqLen, runtime::SizeType maxTokensPerStep, cudaStream_t stream); //! \brief Accepts or rejects draft tokens based on the equality of draft and target tokens //! for speculative decoding. Target token is accepted if targetToken == draftToken. @@ -82,7 +103,8 @@ void invokeCopyNextStepIds(int32_t* nextStepIds, int32_t** outputIdsPtr, int32_t //! \param finished input buffer [maxDraftTokens + 1, batchSize] finished states at each decoding iteration //! \param finishedFinal output buffer [batchSize] finished states after accepting/rejecting tokens //! \param finishedSum output buffer [1] total number of requests in batch that finished the execution -//! \param batchSlots +//! \param batchSlots input buffer [batchSize], address map from local index +//! to global index [0, batchSize] -> [0, maxBatchSize] //! \param batchSize current batch size //! \param maxBatchSize maximum batch size //! \param beamWidth beam width @@ -114,7 +136,8 @@ void invokeAcceptDraftTokensByIds(int32_t const* draftIds, int32_t const* target //! At each step sets to NOT_FINISHED if token is accepted or SKIP_DECODING if token is not accepted //! \param curandState input buffer [batchSize]. Curand states properly //! initialized using invokeCurandInitialize per request. -//! \param batchSlots +//! \param batchSlots input buffer [batchSize], address map from local index +//! to global index [0, batchSize] -> [0, maxBatchSize] //! \param batchSize current batch size //! \param maxBatchSize maximum batch size //! \param beamWidth beam width @@ -145,6 +168,7 @@ void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_ti //! \param targetIds input buffer [maxBatchSize, maxTargetSeqLen], tokens predicted from the target medusa head //! \param sequenceLengths input/output buffer [maxBatchSize], length of the data in outputIds without draft tokens //! 
Incremented according to the accepted length +//! \param acceptedLengths output buffer [maxBatchSize], number of accepted tokens per request //! \param finishedFinal input buffer [maxBatchSize], finished states per request //! \param batchSlots input buffer [batchSize], address map from local index //! to global index [0, batchSize] -> [0, maxBatchSize] @@ -165,10 +189,10 @@ void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_ti //! \param stream stream template void acceptDraftTokensByIdsWithPaths(runtime::TokenIdType* outputIds, runtime::TokenIdType const* targetIds, - runtime::SizeType* sequenceLengths, FinishedState* finishedFinal, runtime::SizeType const* batchSlots, - runtime::SizeType const* paths, runtime::TokenIdType const* endIds, T const* medusaLogits, T const** logitsPtrs, - runtime::SizeType batchSize, runtime::SizeType maxBatchSize, runtime::SizeType vocabSize, - runtime::SizeType maxDraftSeqLen, runtime::SizeType maxTargetSeqLen, runtime::SizeType maxNumHeads, - runtime::SizeType maxTokensPerStep, cudaStream_t stream); + runtime::SizeType* sequenceLengths, runtime::SizeType* acceptedLengths, FinishedState* finishedFinal, + runtime::SizeType const* batchSlots, runtime::SizeType const* paths, runtime::TokenIdType const* endIds, + T const* medusaLogits, T const** logitsPtrs, runtime::SizeType batchSize, runtime::SizeType maxBatchSize, + runtime::SizeType vocabSize, runtime::SizeType maxDraftSeqLen, runtime::SizeType maxTargetSeqLen, + runtime::SizeType maxNumHeads, runtime::SizeType maxTokensPerStep, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h b/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h index aee2e59e5..3bd4b195c 100644 --- a/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h +++ b/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h @@ -41,134 +41,60 @@ static int const SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256; #define TOPK_FP16_STORAGE 0 -template -__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty) -{ - // score = log(prob) / (length ^ length_penalty). - if (length_penalty == 0.0f || length == 1) - { - return log_prob; - } - return log_prob / static_cast(powf(length, length_penalty)); -} - -/* -// Useless kernels, remove them?
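The updated `acceptDraftTokensByIdsWithPaths` declared above scores each candidate Medusa path by how many draft tokens agree with the target model's tokens, stopping at the first mismatch or end token and keeping the longest accepted prefix. Below is a much-simplified CPU sketch of that acceptance rule for a single linear path; the helper name is hypothetical and the batching, path-tree, and end-token bookkeeping of the real kernel are deliberately left out.

```
// Count how many draft tokens are accepted: the longest prefix on which draft and
// target agree, stopping at a target end-of-sequence token. Illustration only.
#include <cstddef>
#include <cstdio>
#include <vector>

int countAcceptedTokens(std::vector<int> const& draft, std::vector<int> const& target, int endId)
{
    int accepted = 0;
    for (std::size_t i = 0; i < draft.size() && i < target.size(); ++i)
    {
        if (target[i] == endId || draft[i] != target[i])
        {
            break;
        }
        ++accepted;
    }
    return accepted;
}

int main()
{
    std::vector<int> const draft{5, 17, 23, 42};
    std::vector<int> const target{5, 17, 99, 42};
    // The first two draft tokens match the target, the third does not,
    // so two tokens are accepted and decoding resumes from the mismatch.
    std::printf("accepted=%d\n", countAcceptedTokens(draft, target, /*endId=*/0));
    return 0;
}
```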
-template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(int* topk_id, T* topk_val, int* id_buf) -{ - int const thread_id = threadIdx.x; - int const block_id = blockIdx.x; - TopK partial; - - if (thread_id == 0) - { - for (int i = 0; i < MAX_K; ++i) - { - partial.p[i] = -1; - partial.u[i] = -FLT_MAX; - } - - int index = block_id * MAX_K * MAX_K; - for (int i = 0; i < MAX_K * MAX_K; i++) - { - partial.insert(topk_val[index + i], topk_id[index + i]); - } - - index = block_id * MAX_K; - for (int i = 0; i < MAX_K; i++) - { - id_buf[index + i] = partial.p[i]; - } - } -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel( - int const* __restrict topk_id, T const* __restrict topk_val, int* __restrict id_buf, T* __restrict val_buf) -{ - int const thread_id = threadIdx.x; - int const block_id = blockIdx.x; - TopK partial; - - if (thread_id == 0) - { - for (int i = 0; i < MAX_K; ++i) - { - partial.p[i] = -1; - partial.u[i] = -FLT_MAX; - } - - int index = block_id * MAX_K * MAX_K; - for (int i = 0; i < MAX_K * MAX_K; i++) - { - partial.insert(topk_val[index + i], topk_id[index + i]); - } - - index = block_id * MAX_K; - for (int i = 0; i < MAX_K; i++) - { - id_buf[index + i] = partial.p[i]; - val_buf[index + i] = partial.u[i]; - } - } -} -*/ template __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel( - int const* __restrict topk_id, T const* __restrict topk_val, BeamHypotheses beam_hyps, int const candidate_size) + int const* __restrict topk_id, T const* __restrict topk_val, BeamHypotheses bh, int const candidate_size) { - int const thread_id = threadIdx.x; - int const vector_id = blockIdx.x; - int const global_batch_idx{beam_hyps.ite * beam_hyps.local_batch_size + vector_id}; - int const K{beam_hyps.beam_width}; - int const vocab_size{beam_hyps.vocab_size}; + int const tid = threadIdx.x; + int const bid = blockIdx.x; + int const global_batch_idx{bh.ite * bh.local_batch_size + bid}; + int const K{bh.beam_width}; + int const vocab_size{bh.vocab_size}; T const MAX_T_VAL = (std::is_same::value) ? 
HALF_FLT_MAX : FLT_MAX; - float const diversity_rate{beam_hyps.diversity_rates[global_batch_idx]}; - float const length_penalty{beam_hyps.length_penalties[global_batch_idx]}; - int const early_stopping{beam_hyps.early_stoppings[global_batch_idx]}; - int const* input_lengths{beam_hyps.input_lengths}; - int const* sequence_lengths{beam_hyps.sequence_lengths_src}; - - float* __restrict cum_log_probs_src{beam_hyps.cum_log_probs_src}; // copy since it will be modified + float const diversity_rate{bh.diversity_rates[global_batch_idx]}; + float const length_penalty{bh.length_penalties[global_batch_idx]}; + int const early_stopping{bh.early_stoppings[global_batch_idx]}; + int const* input_lengths{bh.input_lengths}; + int* sequence_lengths{bh.sequence_lengths_src}; using cub_kvp = cub::KeyValuePair; using BlockReduce = cub::BlockReduce; - extern __shared__ char buf_s_[]; // intermediate result + extern __shared__ char buf_s_[]; T* buf_s = reinterpret_cast(buf_s_); __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ float old_cum_log_probs[MAX_K2]; + __shared__ float old_cum_log_probs[MAX_K2 / 2]; __shared__ cub_kvp cta_topk[MAX_K2]; __shared__ int selected_beams; __shared__ int thread_requiring_update; // reposition topk_id, topk_val to data for the current vector - topk_id += vector_id * candidate_size; - topk_val += vector_id * candidate_size; + topk_id += bid * candidate_size; + topk_val += bid * candidate_size; - if (thread_id == 0) + if (tid == 0) { selected_beams = 0; } - if (thread_id < K) + if (tid < K) { - old_cum_log_probs[thread_id] = cum_log_probs_src[vector_id * K + thread_id]; + old_cum_log_probs[tid] = bh.cum_log_probs_src[bid * K + tid]; } __syncthreads(); - if (beam_hyps.num_beams != nullptr) + if (bh.num_beams != nullptr) { // Beam search is enabled - if (beam_hyps.num_beams[global_batch_idx] == 0 && thread_id == 0) + if (bh.num_beams[global_batch_idx] == 0 && tid == 0) { - // Initialize worst_score if this batch has no finished beam - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; + // Initialize worst_score in the first time + bh.min_normed_scores[global_batch_idx] = FLT_MAX; } - else if (beam_hyps.num_beams[global_batch_idx] == K) + else if (early_stopping && bh.num_beams[global_batch_idx] == K + || !early_stopping && bh.finished[bid * K].isFinished()) { - // Return if this batch has enough finished beams + // We have got enough beams return; } } @@ -177,20 +103,20 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel( cub::ArgMax arg_max; cub_kvp partial_topk{candidate_size - 1, -MAX_T_VAL}; - for (int id = thread_id; id < candidate_size; id += THREADBLOCK_SIZE) + for (int id = tid; id < candidate_size; id += THREADBLOCK_SIZE) { - int i = beam_hyps.num_beams == nullptr ? id % K : id / 2 / K; - T elem = topk_val[id] + static_cast(diversity_rate * i); // use token score for TopK - cub_kvp new_elem{id, elem}; + int const index = bh.num_beams == nullptr ? 
id % K : id / 2 / K; + T val = topk_val[id] + static_cast(diversity_rate * index); // use token score for TopK + cub_kvp new_elem{id, val}; partial_topk = arg_max(partial_topk, new_elem); - buf_s[id] = elem; + buf_s[id] = val; } __syncthreads(); for (int i = 0; i < 2 * K; ++i) { cub_kvp total_topk = BlockReduce(temp_storage).Reduce(partial_topk, arg_max); - if (threadIdx.x == 0) + if (tid == 0) { cta_topk[i] = total_topk; buf_s[total_topk.key] = -MAX_T_VAL; @@ -200,19 +126,19 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel( // Only one thread needs to update the old partial before the next block reduce. // No need to do this in the last iteration. - if (thread_id == thread_requiring_update && i < (2 * K - 1)) + if (tid == thread_requiring_update && i < (2 * K - 1)) { partial_topk.key = candidate_size - 1; partial_topk.value = -MAX_T_VAL; - for (int tid = thread_id; tid < candidate_size; tid += THREADBLOCK_SIZE) + for (int index = tid; index < candidate_size; index += THREADBLOCK_SIZE) { - cub_kvp new_elem{tid, buf_s[tid]}; + cub_kvp new_elem{index, buf_s[index]}; partial_topk = arg_max(partial_topk, new_elem); } } } - if (thread_id == 0) + if (tid == 0) { // Adjust beams or select completed beams sequentially // Reference (might be changed along HF in the future): @@ -221,104 +147,110 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel( { int const current_key = cta_topk[i].key; T const current_value = cta_topk[i].value; - bool const is_end_token = topk_id[current_key] % vocab_size == beam_hyps.end_ids[vector_id]; - - if (i < K && beam_hyps.num_beams != nullptr && is_end_token) + bool const is_end_token = topk_id[current_key] % vocab_size == bh.end_ids[bid]; + if (i < K && bh.num_beams != nullptr && is_end_token) { - // Consider to add beam only if this token is end_token and belongs to top K range - int const seq_len = sequence_lengths[vector_id * K + i] - input_lengths[global_batch_idx]; - int const pad = static_cast(!beam_hyps.finished[vector_id * K + i].isFinished()); + // Condition of this branch + // In Beam search mode, this token is end_token and belongs to top K range in Beam search mode + int const seq_len = sequence_lengths[bid * K + i] - input_lengths[global_batch_idx]; + int const pad = static_cast(!bh.finished[bid * K + i].isFinished()); float const normed_score = apply_length_penalty(current_value, seq_len + pad, length_penalty); - int beam_idx = beam_hyps.num_beams[global_batch_idx]; + int beam_idx = bh.num_beams[global_batch_idx]; if (beam_idx == K) { // There are already K beams - if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) + if (normed_score < bh.min_normed_scores[global_batch_idx]) { // Current score is worse than the worst one in candidate beams - // Stop considering new beams - selected_beams = K; - break; + if (early_stopping) + { + // Stop since we have got enough beams + break; + } + else + { + // Continue since there might be longer but better beams + continue; + } } else { // Current score is better than the worst one in candidate beams - // Find the beam index which score == min_normed_score and erase it + // Find the candidate beam index with the worst score and erase it for (int j = 0; j < K; j++) { - if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] - == beam_hyps.min_normed_scores[global_batch_idx]) + if (bh.normed_scores[global_batch_idx * (K * 2) + j] + == bh.min_normed_scores[global_batch_idx]) { beam_idx = j; - beam_hyps.num_beams[global_batch_idx]--; - 
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] = normed_score; + bh.num_beams[global_batch_idx]--; + bh.min_normed_scores[global_batch_idx] = FLT_MAX; + bh.normed_scores[global_batch_idx * (K * 2) + j] = normed_score; for (int l = 0; l < K; l++) { - beam_hyps.min_normed_scores[global_batch_idx] - = min(beam_hyps.min_normed_scores[global_batch_idx], - beam_hyps.normed_scores[global_batch_idx * (K * 2) + l]); + bh.min_normed_scores[global_batch_idx] = min(bh.min_normed_scores[global_batch_idx], + bh.normed_scores[global_batch_idx * (K * 2) + l]); } break; } } } } - int const tgt_id_offset - = ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx) - * (beam_hyps.max_seq_len); int prev_id = (topk_id[current_key] / vocab_size) % K; - int const current_step{sequence_lengths[vector_id * K + prev_id]}; - beam_hyps.output_ids_tgt[tgt_id_offset + current_step] = beam_hyps.end_ids[vector_id]; - if (beam_hyps.log_probs != nullptr) + int const current_step = sequence_lengths[bid * K + prev_id]; + int const tgt_id_offset = ((bid + bh.ite * bh.local_batch_size) * (K * 2) + beam_idx) * bh.max_seq_len; + bh.output_ids_tgt[tgt_id_offset + current_step] = bh.end_ids[bid]; + if (bh.log_probs != nullptr) { - beam_hyps.log_probs[tgt_id_offset + current_step] + bh.log_probs[tgt_id_offset + current_step] = (float) topk_val[current_key] - old_cum_log_probs[(topk_id[current_key] / vocab_size) % K]; } - + // Copy finished beam from "%% self.output_ids" to "%% self.beam_hyps_output_ids_tgt" for (int j = current_step - 1; j >= 0; j--) { - int const src_idx = j * beam_hyps.batch_size * K + beam_hyps.ite * beam_hyps.local_batch_size * K - + vector_id * K + prev_id; - beam_hyps.output_ids_tgt[tgt_id_offset + j] - = beam_hyps.output_ids_tgt_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j]; - if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) + bh.output_ids_tgt[tgt_id_offset + j] = bh.output_ids_tgt_ptr[bid][prev_id * bh.max_seq_len + j]; + prev_id = bh.parent_ids_tgt_ptr[bid][prev_id * bh.max_seq_len + j]; + } + if (bh.log_probs != nullptr && bh.log_probs_src != nullptr) + { + prev_id = (topk_id[current_key] / vocab_size) % K; + for (int j = current_step - 1; j >= 0; j--) { - beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx]; + int const index = j * bh.batch_size * K + bh.ite * bh.local_batch_size * K + bid * K + prev_id; + bh.log_probs[tgt_id_offset + j] = bh.log_probs_src[index]; + prev_id = bh.parent_ids_tgt_ptr[bid][prev_id * bh.max_seq_len + j]; } - prev_id = beam_hyps.parent_ids_tgt_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j]; } int const tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx; - beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = current_step; - beam_hyps.normed_scores[tgt_beam_idx] = normed_score; - beam_hyps.min_normed_scores[global_batch_idx] - = min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]); - beam_hyps.num_beams[global_batch_idx]++; - beam_hyps.cum_log_probs[tgt_beam_idx] = (float) topk_val[current_key]; + bh.sequence_lengths_tgt[tgt_beam_idx] = current_step; + bh.normed_scores[tgt_beam_idx] = normed_score; + bh.min_normed_scores[global_batch_idx] + = min(bh.min_normed_scores[global_batch_idx], bh.normed_scores[tgt_beam_idx]); + bh.num_beams[global_batch_idx]++; + bh.cum_log_probs[tgt_beam_idx] = (float) topk_val[current_key]; } - else if (i < K || beam_hyps.num_beams != nullptr && !is_end_token) + else if (i 
< K || bh.num_beams != nullptr && !is_end_token) { // Condition of this branch - // 1. beam_hyps.num_beams == nullptr && i < K, i.e., beam search is disable - // 2. beam_hyps.num_beams != nullptr && i < K && is_end_token == false, i.e., add token at the end - // 3. beam_hyps.num_beams != nullptr && i >= K && is_end_token == false, i.e., add token at the end - int const current_step = sequence_lengths[vector_id * K + selected_beams]; - beam_hyps.output_ids_tgt_ptr[vector_id][selected_beams * beam_hyps.max_seq_len + current_step] - = topk_id[current_key]; - if (beam_hyps.log_probs_src != nullptr) + // 1. bh.num_beams == nullptr && i < K, i.e., beam search is disable + // 2. bh.num_beams != nullptr && i < K && is_end_token == false, i.e., add token at the end + // 3. bh.num_beams != nullptr && i >= K && is_end_token == false, i.e., add token at the end + int const current_step = sequence_lengths[bid * K + selected_beams]; + // Write the selected token to output.output_ids + bh.output_ids_tgt_ptr[bid][selected_beams * bh.max_seq_len + current_step] = topk_id[current_key]; + if (bh.log_probs_src != nullptr) { - beam_hyps.log_probs_src[current_step * beam_hyps.batch_size * K + vector_id * K + selected_beams] + bh.log_probs_src[current_step * bh.batch_size * K + bid * K + selected_beams] = (float) topk_val[current_key] - old_cum_log_probs[(topk_id[current_key] / vocab_size) % K]; } - cum_log_probs_src[vector_id * K + selected_beams] = (float) topk_val[current_key]; + bh.cum_log_probs_src[bid * K + selected_beams] = (float) topk_val[current_key]; selected_beams++; } else { - ; // Condition of this branch, which we do nothing for it - // 1. beam_hyps.num_beams == nullptr && i >= K, i.e., beam search is disable - // 2. beam_hyps.num_beams != nullptr && i >= K && is_end_token == true, i.e., ignore the worse beams + // 1. bh.num_beams == nullptr && i >= K, i.e., beam search is disable + // 2. bh.num_beams != nullptr && i >= K && is_end_token == true, i.e., ignore the worse beams } if (selected_beams >= K) @@ -328,43 +260,72 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel( } } - // update beam_hyps.is_done - if (thread_id == 0 && beam_hyps.num_beams != nullptr) + // Update bh.is_done + if (tid == 0 && bh.num_beams != nullptr) { - if (beam_hyps.num_beams[vector_id] < K) + if (bh.num_beams[bid] < K) { // no enough beams - beam_hyps.is_done[vector_id] = false; - return; + bh.is_done[bid] = false; } - int seq_len = 0; - float highest_attainable_score = 0.0f; - switch (early_stopping) + else if (early_stopping == 1) { - case 1: // enough beams with early_stopping - beam_hyps.is_done[vector_id] = true; - return; - case 0: - // enough beams with non_early_stopping - seq_len = sequence_lengths[vector_id * K] - input_lengths[global_batch_idx]; - highest_attainable_score = apply_length_penalty(cum_log_probs_src[vector_id * K], seq_len, length_penalty); - beam_hyps.is_done[vector_id] = beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score; - return; - default: - // early_stopping == "never" in HF, i.e., compute the best possible score depending on length_penalty + bh.is_done[bid] = true; + } + else + { + // Condition of this branch + // 1. enough beams with early_stopping == 0, i.e. non_early_stopping + // 2. enough beams with early_stopping being other values, i.e. 
early_stopping == "never" in HF // https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L990 - if (length_penalty > 0.0f) - { - seq_len = beam_hyps.max_seq_len - input_lengths[global_batch_idx]; - } - else + int seq_len = sequence_lengths[bid * K] + 1 - input_lengths[global_batch_idx]; + float const best_sum_logprobs = cta_topk[0].value; + // According to semantics of HF, cta_topk[0].value is used as best_sum_logprobs + // https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L307 + // But maybe bh.cum_log_probs_src[bid * K + i] is more suitable? + if (early_stopping != 0 && length_penalty > 0.0f) { - seq_len = sequence_lengths[vector_id * K] - input_lengths[global_batch_idx]; + // Specialize for early_stopping == "never" and length_penalty > 0 + seq_len = bh.max_seq_len - input_lengths[global_batch_idx]; } - highest_attainable_score = apply_length_penalty(cum_log_probs_src[vector_id * K], seq_len, length_penalty); - beam_hyps.is_done[vector_id] = beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score; - return; + float const highest_attainable_score = apply_length_penalty(best_sum_logprobs, seq_len, length_penalty); + bh.is_done[bid] = bh.min_normed_scores[global_batch_idx] >= highest_attainable_score; + } + } + __syncthreads(); + + // Update sequence_lengths, parent_ids, output_ids and finished + __shared__ int s_sequence_lengths[MAX_K2 / 2]; + if (tid < K) + { + s_sequence_lengths[tid] = sequence_lengths[bid * K + tid]; + } + __syncthreads(); + + if (tid < K) + { + int const bb_index = bid * K + tid; + int const current_step = s_sequence_lengths[tid]; + if (!bh.finished[bb_index].isFinished()) + { + s_sequence_lengths[tid]++; + } + int const new_id = bh.output_ids_tgt_ptr[bid][tid * bh.max_seq_len + current_step]; + int const new_beam_id = (new_id / vocab_size) % K; + int const new_word_id = new_id % vocab_size; + sequence_lengths[bb_index] = s_sequence_lengths[new_beam_id]; + if (new_word_id == bh.end_ids[bid]) + { + bh.finished[bb_index].setFinishedEOS(); + } + bh.parent_ids_tgt_ptr[bid][tid * bh.max_seq_len + current_step] = new_beam_id; + bh.output_ids_tgt_ptr[bid][tid * bh.max_seq_len + current_step] = new_word_id; + if (early_stopping && (bh.num_beams != nullptr && bh.num_beams[bh.ite * bh.local_batch_size + bid] == K) + || !early_stopping && bh.is_done[bid]) // TODO: simplify this condition + { + bh.is_done[bid] = true; + bh.finished[bb_index].setFinished(); } } } @@ -377,12 +338,10 @@ struct __align__(8) MD __device__ __forceinline__ MD reduce_md_op(MD a, MD b) { - bool a_bigger = (a.m > b.m); - MD bigger_m = a_bigger ? a : b; - MD smaller_m = a_bigger ? b : a; - MD res; - res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m); - res.m = bigger_m.m; + bool const is_a_bigger = a.m > b.m; + MD const bigger = is_a_bigger ? a : b; + MD const smaller = is_a_bigger ? b : a; + MD res{bigger.m, bigger.d + smaller.d * __expf(smaller.m - bigger.m)}; return res; } @@ -407,16 +366,13 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker T const* __restrict bias, float const* __restrict cum_log_probs, FinishedState const* __restrict finished, int* __restrict topk_id, T* __restrict topk_val, int vocab_size, int K, int const* __restrict end_ids) { - int const thread_id = threadIdx.x; - int const vector_id = blockIdx.x; + int const tid = threadIdx.x; + int const bid = blockIdx.x; T const MAX_T_VAL = (std::is_same::value) ? 
HALF_FLT_MAX : FLT_MAX; typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - // reposition log_probs to data for the current vector - log_probs += vector_id * vocab_size; - TopKMD partial; for (int i = 0; i < MAX_K; ++i) { @@ -426,46 +382,40 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker partial.md.m = -MAX_T_VAL; partial.md.d = 0.0F; - if (finished[vector_id].isFinished()) + if (finished[bid].isFinished()) { - for (int id = thread_id; id < vocab_size; id += THREADBLOCK_SIZE) + for (int id = tid; id < vocab_size; id += THREADBLOCK_SIZE) { - float elem = (id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; - MD new_elem{elem, 1.0F}; + float val = (id == end_ids[bid / K]) ? MAX_T_VAL : -MAX_T_VAL; + MD new_elem{val, 1.0F}; partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, id); + partial.topk.insert(val, id); } } else { - for (int id = thread_id; id < vocab_size; id += THREADBLOCK_SIZE) + T const* local_log_probs = log_probs + bid * vocab_size; + for (int id = tid; id < vocab_size; id += THREADBLOCK_SIZE) { - float elem = log_probs[id] + bias[id]; - MD new_elem{elem, 1.0F}; + float val = local_log_probs[id] + bias[id]; + MD new_elem{val, 1.0F}; partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, id); + partial.topk.insert(val, id); } } TopKMD total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op); - if (thread_id == 0) + if (tid == 0) { - topk_id += vector_id * K; - topk_val += vector_id * K; - cum_log_probs += vector_id; - - // float d_total_inverse = __fdividef(1.0F, total.md.d); - float d_total_log = logf(total.md.d); - for (int i = 0; i < MAX_K; ++i) + int* local_topk_id = topk_id + bid * K; + T const* local_topk_val = topk_val + bid * K; + float const d_total_log = logf(total.md.d); + float local_cum_log_probs = cum_log_probs[bid]; + for (int i = 0; i < K; ++i) { - // float val = __expf(total.topk.u[i] - total.md.m) * d_total_inverse; - float val = total.topk.u[i] - total.md.m - d_total_log; - if (i < K) - { - topk_id[i] = total.topk.p[i] + vector_id * vocab_size; // trtllm needs absolute id - topk_val[i] = val + cum_log_probs[0]; - } + local_topk_id[i] = total.topk.p[i] + bid * vocab_size; + local_topk_val[i] = total.topk.u[i] - total.md.m - d_total_log + local_cum_log_probs; } } } @@ -475,8 +425,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ T const* __restrict log_probs, T const* __restrict bias, FinishedState const* __restrict finished, float* __restrict tmp_buffer, int vocab_size, int K, int const* __restrict end_ids) { - int const thread_id = threadIdx.x; - int const vector_id = blockIdx.x; + int const tid = threadIdx.x; + int const bid = blockIdx.x; T const MAX_T_VAL = (std::is_same::value) ? 
HALF_FLT_MAX : FLT_MAX; int const PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; // one threadblock has multiple sections per vocab_size @@ -495,9 +445,6 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ float buf_s[PACKED_TOP_KMD_SIZE]; - // reposition log_probs to the data for the current vector - log_probs += vector_id * vocab_size; - for (int i = 0; i < MAX_K2; ++i) { partial.topk.p[i] = -1; @@ -506,27 +453,28 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ partial.md.m = -MAX_T_VAL; partial.md.d = 0.0F; - if (finished[vector_id].isFinished()) + if (finished[bid].isFinished()) { #pragma unroll 1 - for (int id = section_start + thread_id; id < section_end; id += THREADBLOCK_SIZE) + for (int id = section_start + tid; id < section_end; id += THREADBLOCK_SIZE) { - float elem = (id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; - MD new_elem{elem, 1.0F}; - partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, id); + float val = (id == end_ids[bid / K]) ? MAX_T_VAL : -MAX_T_VAL; + MD new_elem_md{val, 1.0F}; + partial.md = reduce_md_op(partial.md, new_elem_md); + partial.topk.insert(val, id); } } else { + T const* local_log_probs = log_probs + bid * vocab_size; #pragma unroll 1 - for (int id = section_start + thread_id; id < section_end; id += THREADBLOCK_SIZE) + for (int id = section_start + tid; id < section_end; id += THREADBLOCK_SIZE) { T b = bias == nullptr ? (T) 0.0f : bias[id]; - T elem = log_probs[id] + b; - MD new_elem{elem, 1.0F}; - partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, id); + T val = local_log_probs[id] + b; + MD new_elem_md{val, 1.0F}; + partial.md = reduce_md_op(partial.md, new_elem_md); + partial.topk.insert(val, id); } } @@ -536,11 +484,11 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ TopKMD total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op); #endif - if (thread_id == 0) + if (tid == 0) { for (int i = 0; i < 2 * K; i++) { - reinterpret_cast(buf_s)[i] = total.topk.p[i] + vector_id * vocab_size; // trtllm needs absolute id + reinterpret_cast(buf_s)[i] = total.topk.p[i] + bid * vocab_size; // trtllm needs absolute id buf_s[MAX_K2 + i] = total.topk.u[i]; } buf_s[2 * MAX_K2] = total.md.d; @@ -548,9 +496,10 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ } __syncthreads(); - for (int id = thread_id; id < PACKED_TOP_KMD_SIZE; id += THREADBLOCK_SIZE) + float* local_tmp_buffer = tmp_buffer + bid * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE; + for (int id = tid; id < PACKED_TOP_KMD_SIZE; id += THREADBLOCK_SIZE) { - tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + id] = buf_s[id]; + local_tmp_buffer[id] = buf_s[id]; } } @@ -559,8 +508,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ T const* __restrict log_probs, T const* __restrict bias, FinishedState const* __restrict finished, float* __restrict t, int vocab_size, int K, int const* __restrict end_ids, int const v_local) { - int const thread_id = threadIdx.x; - int const vector_id = blockIdx.x; + int const tid = threadIdx.x; + int const bid = blockIdx.x; T const MAX_T_VAL = (std::is_same::value) ? 
HALF_FLT_MAX : FLT_MAX; int const PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; // one threadblock has multiple sections per vocab_size @@ -570,13 +519,13 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ #if TOPK_FP16_STORAGE == 1 using cub_kvp = cub::KeyValuePair; - using BlockReduceTopK = cub::BlockReduce; #else using cub_kvp = cub::KeyValuePair; - using BlockReduceTopK = cub::BlockReduce; #endif + using BlockReduceTopK = cub::BlockReduce; using BlockReduceMD = cub::BlockReduce; + auto reduce_md_func = [](const MD& a, const MD& b) { return reduce_md_op(a, b); }; extern __shared__ char buf_smem_logprobs_[]; T* buf_smem_logprobs = reinterpret_cast(buf_smem_logprobs_); @@ -589,87 +538,83 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ typename BlockReduceTopK::TempStorage topk_smem; } temp_storage; - // reposition log_probs to the data for the current vector - log_probs += vector_id * vocab_size; - cub::ArgMax arg_max; cub_kvp partial_topk{vocab_size - 1, -MAX_T_VAL}; MD partial_md{-MAX_T_VAL, 0.0f}; - if (finished[vector_id].isFinished()) + if (finished[bid].isFinished()) { #pragma unroll 1 - for (int id = section_start + thread_id; id < section_end; id += THREADBLOCK_SIZE) + for (int id = section_start + tid; id < section_end; id += THREADBLOCK_SIZE) { - float elem = (id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; - buf_smem_logprobs[id - section_start] = elem; - MD new_elem{elem, 1.0F}; - partial_md = reduce_md_op(partial_md, new_elem); - + float const val = (id == end_ids[bid / K]) ? MAX_T_VAL : -MAX_T_VAL; int const smem_index = id - section_start; - cub_kvp new_elem_topk{smem_index, elem}; + buf_smem_logprobs[smem_index] = val; + MD new_elem_md{val, 1.0F}; + partial_md = reduce_md_op(partial_md, new_elem_md); + cub_kvp new_elem_topk{smem_index, val}; partial_topk = arg_max(partial_topk, new_elem_topk); - buf_smem_logprobs[smem_index] = elem; } } else { + T const* local_log_probs = log_probs + bid * vocab_size; #pragma unroll 1 - for (int id = section_start + thread_id; id < section_end; id += THREADBLOCK_SIZE) + for (int id = section_start + tid; id < section_end; id += THREADBLOCK_SIZE) { - T b = bias == nullptr ? (T) 0.0f : bias[id]; - T elem = log_probs[id] + b; - MD new_elem_md{elem, 1.0F}; - partial_md = reduce_md_op(partial_md, new_elem_md); - + T const b = bias == nullptr ? 
(T) 0.0f : bias[id]; + T const val = local_log_probs[id] + b; int const smem_index = id - section_start; - cub_kvp new_elem_topk{smem_index, elem}; + buf_smem_logprobs[smem_index] = val; + MD new_elem_md{val, 1.0F}; + partial_md = reduce_md_op(partial_md, new_elem_md); + cub_kvp new_elem_topk{smem_index, val}; partial_topk = arg_max(partial_topk, new_elem_topk); - buf_smem_logprobs[smem_index] = elem; } } __syncthreads(); for (int i = 0; i < 2 * K; ++i) { + // Pop the best choice from "total_topk" to "buf_s" per iteration cub_kvp total_topk = BlockReduceTopK(temp_storage.topk_smem).Reduce(partial_topk, arg_max); - if (threadIdx.x == 0) + if (tid == 0) { - reinterpret_cast(buf_s)[i] - = section_start + total_topk.key + vector_id * vocab_size; // trtllm needs absolute id + int const index = bid * vocab_size + section_start + total_topk.key; + reinterpret_cast(buf_s)[i] = index; buf_s[MAX_K2 + i] = total_topk.value; - buf_smem_logprobs[total_topk.key] = -MAX_T_VAL; + buf_smem_logprobs[total_topk.key] = -MAX_T_VAL; // delete the value of the best choice thread_requiring_update = total_topk.key % THREADBLOCK_SIZE; } __syncthreads(); - // Only one thread needs to update the old partial before the next block reduce. - // No need to do this in the last iteration. - if (thread_id == thread_requiring_update && i < 2 * K - 1) + if (tid == thread_requiring_update && i < 2 * K - 1) { + // The thread with the biggest element updates its partial_topk + // No need to do this in the last iteration partial_topk.key = vocab_size - 1; partial_topk.value = -MAX_T_VAL; - for (int tid = thread_id; tid < valid_smem_length; tid += THREADBLOCK_SIZE) + for (int index = tid; index < valid_smem_length; index += THREADBLOCK_SIZE) { - cub_kvp new_elem{tid, buf_smem_logprobs[tid]}; + cub_kvp new_elem{index, buf_smem_logprobs[index]}; partial_topk = arg_max(partial_topk, new_elem); } } } - auto reduce_md_func = [](const MD& a, const MD& b) { return reduce_md_op(a, b); }; MD total_md = BlockReduceMD(temp_storage.md_smem).Reduce(partial_md, reduce_md_func); - if (threadIdx.x == 0) + if (tid == 0) { buf_s[2 * MAX_K2] = total_md.d; buf_s[2 * MAX_K2 + 1] = total_md.m; } __syncthreads(); - for (int id = thread_id; id < PACKED_TOP_KMD_SIZE; id += THREADBLOCK_SIZE) + float* local_t = t + bid * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE; + for (int id = tid; id < PACKED_TOP_KMD_SIZE; id += THREADBLOCK_SIZE) { - t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + id] = buf_s[id]; + local_t[id] = buf_s[id]; } } @@ -678,8 +623,8 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta float const* __restrict temp_storage, float const* __restrict cum_log_probs, int* __restrict ids, T* __restrict vals, int K, int parts_per_beam, int const vocab_size) { - int const vector_id = blockIdx.x; - int const thread_id = threadIdx.x; + int const bid = blockIdx.x; + int const tid = threadIdx.x; T const MAX_T_VAL = (std::is_same::value) ? 
HALF_FLT_MAX : FLT_MAX; int const PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; @@ -698,16 +643,17 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta } shared_temp_storage; - temp_storage += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam; - cub::ArgMax arg_max; MD partial_md{-MAX_T_VAL, 0.0f}; cub_kvp total_topk{vocab_size - 1, -MAX_T_VAL}; + auto reduce_md_func = [](const MD& a, const MD& b) { return reduce_md_op(a, b); }; + // Load and unpack into registers through smem - for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE) + float const* local_temp_storage = temp_storage + bid * PACKED_TOP_KMD_SIZE * parts_per_beam; + for (int idx = tid; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE) { - buf_s[idx] = temp_storage[idx]; + buf_s[idx] = local_temp_storage[idx]; } __syncthreads(); @@ -717,13 +663,12 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta { cub_kvp partial_topk{vocab_size - 1, -MAX_T_VAL}; // Only threads responsible for a chunk will do the computation - if (threadIdx.x < parts_per_beam) + if (tid < parts_per_beam) { - float* b_s = buf_s + threadIdx.x * PACKED_TOP_KMD_SIZE; - for (int i = 0; i < K; ++i) + for (int i = 0; i < 2 * K; ++i) { - int current_index = threadIdx.x * PACKED_TOP_KMD_SIZE + i; - T current_value = b_s[MAX_K2 + i]; + int const current_index = tid * PACKED_TOP_KMD_SIZE + i; + T current_value = buf_s[current_index + MAX_K2]; cub_kvp new_elem = {current_index, current_value}; partial_topk = arg_max(partial_topk, new_elem); } @@ -732,7 +677,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta cub_kvp total_topk = BlockReduceTopK(shared_temp_storage.topk_smem).Reduce(partial_topk, arg_max); __syncthreads(); - if (threadIdx.x == 0) + if (tid == 0) { // Store kv pairs in shared mem buffer int temp_offset = total_topk.key; @@ -748,22 +693,17 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta } // Extract and reduce MD values across the chunks - if (threadIdx.x < parts_per_beam) + if (tid < parts_per_beam) { - float* b_s = buf_s + threadIdx.x * PACKED_TOP_KMD_SIZE; - partial_md.d = b_s[2 * MAX_K2]; - partial_md.m = b_s[2 * MAX_K2 + 1]; + partial_md.d = buf_s[tid * PACKED_TOP_KMD_SIZE + 2 * MAX_K2]; + partial_md.m = buf_s[tid * PACKED_TOP_KMD_SIZE + 2 * MAX_K2 + 1]; } __syncthreads(); - auto reduce_md_func = [](const MD& a, const MD& b) { return reduce_md_op(a, b); }; MD total_md = BlockReduceMD(shared_temp_storage.md_smem).Reduce(partial_md, reduce_md_func); - if (thread_id == 0) + if (tid == 0) { - ids += vector_id * 2 * K; - vals += vector_id * 2 * K; - cum_log_probs += vector_id; float d_total_log = logf(total_md.d); for (int i = 0; i < MAX_K2; ++i) @@ -771,8 +711,8 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta float val = (float) buf_smem_kv[i].value - total_md.m - d_total_log; if (i < 2 * K) { - ids[i] = buf_smem_kv[i].key; - vals[i] = (float) val + (float) cum_log_probs[0]; + ids[bid * 2 * K + i] = buf_smem_kv[i].key; + vals[bid * 2 * K + i] = val + cum_log_probs[bid]; } } } @@ -782,8 +722,7 @@ template void beam_online_softmax_topk_stage2_kernelLauncher(float const* temp_storage, float const* cum_log_probs, int* ids, T* vals, int batch_size, int beam_width, int parts_per_beam, cudaStream_t stream, int const vocab_size) { - // TODO: rewrite beam_online_softmax_topk_stage2_kernel to remove dependence - // of constant block size in oreder to 
reduce compilation time + // TODO: rewrite kernel to remove dependence of constant block size to reduce compilation time int const smem_stage2_size = parts_per_beam * (2 * MAX_K2 + 2) * sizeof(float); if (parts_per_beam <= 32) @@ -812,14 +751,14 @@ void beam_online_softmax_topk_stage2_kernelLauncher(float const* temp_storage, f template void topK_softMax_kernelLauncher(T const* log_probs, T const* bias, void* temp_storage, int const temp_storage_size, - BeamHypotheses& beam_hyps, cudaStream_t stream) + BeamHypotheses& bh, cudaStream_t stream) { - int const batch_size{beam_hyps.local_batch_size}; - int const beam_width{beam_hyps.beam_width}; - int const vocab_size{beam_hyps.vocab_size}; - int const* end_ids{beam_hyps.end_ids}; - float* cum_log_probs{beam_hyps.cum_log_probs_src}; - FinishedState const* finished{beam_hyps.finished}; + int const batch_size{bh.local_batch_size}; + int const beam_width{bh.beam_width}; + int const vocab_size{bh.vocab_size}; + int const* end_ids{bh.end_ids}; + float* cum_log_probs{bh.cum_log_probs_src}; + FinishedState const* finished{bh.finished}; int const items_per_thread = 1; int const block_sz = (MAX_K < 16) ? ((MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128) : 64; @@ -827,7 +766,7 @@ void topK_softMax_kernelLauncher(T const* log_probs, T const* bias, void* temp_s assert(temp_storage_size % 2 == 0); assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width * 2); // Input and current sequence lengths are needed for computation of length penalty - assert(beam_hyps.length_penalties == nullptr || beam_hyps.sequence_lengths_src != nullptr); + assert(bh.length_penalties == nullptr || bh.sequence_lengths_src != nullptr); int const topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4.) * 4; int* topk_id = reinterpret_cast(temp_storage); @@ -935,7 +874,7 @@ void topK_softMax_kernelLauncher(T const* log_probs, T const* bias, void* temp_s log_probs, bias, cum_log_probs, finished, topk_id, topk_val, vocab_size, beam_width, end_ids); #endif - // Keep 2*MAX_K candidates in case of k candidates finishes in one iteration + // Keep 2 * MAX_K candidates in case of k candidates finishes in one iteration int const candidates = beam_width * beam_width * 2; int const smem_size_batch_topk = sizeof(T) * candidates; if (smem_size_batch_topk >= (48 << 10)) @@ -945,13 +884,13 @@ void topK_softMax_kernelLauncher(T const* log_probs, T const* bias, void* temp_s } batch_topk_kernel - <<>>(topk_id, topk_val, beam_hyps, candidates); + <<>>(topk_id, topk_val, bh, candidates); sync_check_cuda_error(); } #define INSTANTIATE_BEAMSEARCH_K(T, MAX_K) \ template void topK_softMax_kernelLauncher(T const* log_probs, T const* bias, void* temp_storage, \ - int const temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream); + int const temp_storage_size, BeamHypotheses& bh, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/penaltyKernels.h b/cpp/tensorrt_llm/kernels/penaltyKernels.h index 7b3da736e..57acb635b 100644 --- a/cpp/tensorrt_llm/kernels/penaltyKernels.h +++ b/cpp/tensorrt_llm/kernels/penaltyKernels.h @@ -39,11 +39,11 @@ struct InvokeBatchApplyPenaltyParams float const* presencePenalties; float const* frequencyPenalties; bool const accumulateVocab; - size_t const batchSize; + runtime::SizeType const batchSize; runtime::SizeType const beamWidth; runtime::SizeType const maxSeqLen; - size_t const vocabSize; - size_t const vocabSizePadded; + runtime::SizeType const vocabSize; + runtime::SizeType 
const vocabSizePadded; runtime::TokenIdType const** outputIdsPtr; runtime::SizeType const** parentIdsPtr; runtime::SizeType const* inputLengths; diff --git a/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu b/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu index 473d9d17a..6bf8b0092 100644 --- a/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu +++ b/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu @@ -25,6 +25,8 @@ #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/kernels/samplingAirTopPKernels.h" +#include + #include #include @@ -74,7 +76,7 @@ struct alignas(128) Counter // For a row inside a batch, we may launch multiple thread blocks. This counter is // used to determine if the current block is the last running block. If so, this block // will execute scan() and chooseBucket(). - alignas(128) unsigned int finishedBlockCnt; + alignas(128) uint32_t finishedBlockCnt; }; /*******************************Functions*********************************/ @@ -118,7 +120,7 @@ __host__ __device__ int constexpr calcNumPasses() * significant (rightmost)). This way, we can skip some passes in the end at the cost of having an unsorted output. */ template -__device__ int constexpr calcsStartBit(int pass) +__device__ int constexpr calcStartBit(int pass) { int startBit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; if (startBit < 0) @@ -129,13 +131,231 @@ __device__ int constexpr calcsStartBit(int pass) } template -__device__ unsigned constexpr calcMask(int pass) +__device__ uint32_t constexpr calcMask(int pass) { static_assert(BitsPerPass <= 31); - int numBits = calcsStartBit(pass - 1) - calcsStartBit(pass); + int numBits = calcStartBit(pass - 1) - calcStartBit(pass); return (1 << numBits) - 1; } +template +__device__ constexpr uint32_t getNumTotalMantissa() +{ + if constexpr (std::is_same_v) + { + return 10; + } + else if constexpr (std::is_same_v) + { + return 23; + } +} + +template +__device__ uint32_t calcMantissa(T value); + +template <> +__device__ uint32_t calcMantissa(float value) +{ + union + { + uint32_t bits; + float value; + } input; + + input.value = value; + + constexpr uint32_t numTotalMantissa = getNumTotalMantissa(); + uint32_t mask = (1u << numTotalMantissa) - 1; + return input.bits & mask; +} + +__device__ uint32_t calcMantissa(half value) +{ + union + { + uint16_t bits; + half value; + } input; + + input.value = value; + + constexpr uint32_t numTotalMantissa = getNumTotalMantissa(); + uint32_t t = 0u | input.bits; + uint32_t mask = (1u << numTotalMantissa) - 1; + return t & mask; +} + +template +__device__ uint32_t calcExponent(T value); + +template <> +__device__ uint32_t calcExponent(float value) +{ + union + { + uint32_t bits; + float value; + } input; + + input.value = value; + + constexpr uint32_t numTotalMantissa = getNumTotalMantissa(); + uint32_t mask = (1u << numTotalMantissa) - 1; + return input.bits & ~mask; +} + +template <> +__device__ uint32_t calcExponent(half value) +{ + union + { + uint16_t bits; + half value; + } input; + + input.value = value; + + constexpr uint32_t numTotalMantissa = getNumTotalMantissa(); + uint32_t t = 0u | input.bits; + uint32_t mask = (1u << numTotalMantissa) - 1; + return t & ~mask; +} + +__device__ float calcHalfValue(uint32_t count, uint32_t exponent, uint32_t sign, uint64_t bitSum) +{ + constexpr uint32_t numTotalBits = 64; // The bit number of uint64_t + constexpr uint32_t numOffset = 16; // The bits number difference between float and half data type + constexpr 
uint32_t numTotalMantissaHalf + = getNumTotalMantissa(); // The bit number of mantissa for half data type + constexpr uint32_t numTotalMantissaFloat + = getNumTotalMantissa(); // The bit number of mantissa for float data type + + uint64_t extraInMatissa = (bitSum >> numTotalMantissaHalf); + + // Count the bit number for exceeding mantissa and the extra unwritten 1s + uint32_t numExtra = 0; + uint32_t numDeNorm = 0; + int numNorm = 0; + uint32_t mask = 0; + extraInMatissa = (exponent == 0) ? extraInMatissa : extraInMatissa + count; + numExtra = numTotalBits - __clzll(extraInMatissa); + numNorm = (exponent == 0) ? 0 : -1; + if (extraInMatissa == 0) + { + numDeNorm = numTotalMantissaHalf - (numTotalBits - __clzll(bitSum)); + } + exponent = exponent + ((numExtra + numNorm + 127 - 15 - numDeNorm) << numTotalMantissaHalf); + // As extra bits (extraInMatissa) need to be part of the mantissa, we have to move the current + // mantissa within the range of [0-23]bits. + // This is the only step cause precision loss + uint32_t mantissa; + if (extraInMatissa != 0) + { + int numMove = numTotalMantissaFloat - (numExtra - 1); + mask = (1u << (numExtra - 1)) - 1; + // As the first bit of extraInMatissa is the unwritten 1, + // we need to mask that to zero + extraInMatissa = extraInMatissa & mask; + if (numMove > 0) + { + extraInMatissa = extraInMatissa << numMove; + mask = (1u << numTotalMantissaHalf) - 1; + mantissa = (((bitSum & mask) << (numTotalMantissaFloat - numTotalMantissaHalf)) >> (numExtra - 1)) + | extraInMatissa; + } + else + { + mantissa = extraInMatissa >> (-1 * numMove); + } + } + else + { + mask = (1u << numTotalMantissaHalf) - 1; + mantissa = bitSum << (numDeNorm + 1); + mantissa = mantissa & mask; + mantissa = mantissa << (numTotalMantissaFloat - numTotalMantissaHalf); + } + + uint32_t bitFloat = (sign << numOffset) | (exponent << (numTotalMantissaFloat - numTotalMantissaHalf)) | mantissa; + return reinterpret_cast(bitFloat); +} + +__device__ float calcFloatValue(uint32_t count, uint32_t exponent, uint64_t bitSum) +{ + constexpr uint32_t numTotalBits = 64; + constexpr uint32_t numTotalMantissa = getNumTotalMantissa(); + uint64_t extraInMatissa = (bitSum >> numTotalMantissa); + // Count the bit number for exceeding mantissa and the extra unwritten 1s + uint32_t numExtra; + int numNorm = 0; + uint32_t mask = 0; + extraInMatissa = (exponent == 0) ? extraInMatissa : extraInMatissa + count; + numExtra = numTotalBits - __clzll(extraInMatissa); + numNorm = (exponent == 0) ? 0 : -1; + exponent = exponent + ((numExtra + numNorm) << numTotalMantissa); + // As extra integers need to be part of the mantissa, we have to move the current + // mantissa within the range of [0-23]bits. 
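+    // Hypothetical illustration (values not taken from this change): for two normalized addends that
+    // share one exponent, count = 2 and bitSum = 0x900000 give extraInMatissa = 3 and numExtra = 2,
+    // so the stored exponent grows by numExtra - 1 = 1 and the low bit of extraInMatissa, followed by
+    // the top 22 bits of bitSum's mantissa field, becomes the new 23-bit mantissa (0x480000 here).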
+ // This is the only step cause precision loss + uint32_t mantissa; + if (extraInMatissa != 0) + { + int numMove = numTotalMantissa - (numExtra - 1); + // As the first bit of extraInMatissa is the unwritten 1, + // we need to mask that to zero + mask = (1u << (numExtra - 1)) - 1; + extraInMatissa = extraInMatissa & mask; + if (numMove > 0) + { + extraInMatissa = extraInMatissa << numMove; + mask = (1u << numTotalMantissa) - 1; + mantissa = ((bitSum & mask) >> (numExtra - 1)) | extraInMatissa; + } + else + { + mantissa = extraInMatissa >> (-1 * numMove); + } + } + else + { + mantissa = bitSum; + } + uint32_t bitFloat = exponent | mantissa; + return reinterpret_cast(bitFloat); +} + +template +__device__ constexpr void calcAtomicAdd(HisT* dst, T value) +{ + if constexpr (isDeterministic) + { + uint32_t mantissa = calcMantissa(value); + if constexpr (std::is_same_v) + { + atomicAdd(dst, mantissa); + } + else + { + // Have to use reinterpret_cast() to convert uint64_t to "unsigned long long" + // Otherwise, the complication will report the follow error: + //"error: no instance of overloaded function "atomicAdd" matches the argument list + // argument types are: (uint64_t *, uint64_t)" + atomicAdd(reinterpret_cast(dst), static_cast(mantissa)); + } + } + else + { + if constexpr (std::is_same_v) + { + atomicAdd(dst, __half2float(value)); + } + else + { + atomicAdd(dst, value); + } + } +} + /** * Use CUB to twiddle bits. */ @@ -166,7 +386,7 @@ __device__ T twiddleOut(typename cub::Traits::UnsignedBits bits, bool selectM * Find the bucket based on the radix */ template -__device__ int calcBucket(T x, int startBit, unsigned mask, bool selectMin) +__device__ int calcBucket(T x, int startBit, uint32_t mask, bool selectMin) { static_assert(BitsPerPass <= sizeof(int) * 8 - 1, "BitsPerPass is too large that the result type could not be int"); return (twiddleIn(x, selectMin) >> startBit) & mask; @@ -308,19 +528,18 @@ __device__ void vectorizedProcess(size_t threadRank, size_t numThreads, T const* * Fused filtering of the current pass and building histogram for the next pass (see steps 4 & 1 in `airTopPSampling` * description). 
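+ * Rough sketch of the flow: on pass 0 each thread bins its chunk of the input into a shared-memory
+ * histogram (per-bucket value sums, or integer mantissa sums when isDeterministic is set) and the
+ * block then merges it into the global histogram with atomicAdd; on later passes only candidates
+ * whose high bits equal counter->kthValueBits are re-binned and compacted into outBuf/outIdxBuf.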
*/ -template +template __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* inIdxBuf, T* outBuf, IdxT* outIdxBuf, - int previousLen, Counter* counter, AccT* histogram, IdxT* countHistogram, int pass, - float* outputLogProbs, float* cumLogProbs, IdxT** ids, IdxT const* endIds, IdxT* sequenceLengths, - FinishedState* finishedOutput, int const batchId, int maxBatchSize, bool earlyStop) + int previousLen, Counter* counter, HisT* histogram, IdxT* countHistogram, HisT* histogramSmem, + IdxT* countHistogramSmem, int pass, float* outputLogProbs, float* cumLogProbs, IdxT** ids, IdxT const* endIds, + IdxT* sequenceLengths, FinishedState* finishedOutput, int const batchId, int maxBatchSize, bool earlyStop) { static_assert(std::is_same_v | std::is_same_v, "T needs to be either half or float"); static_assert(std::is_same_v, "AccT needs to be float"); int constexpr numBuckets = calcNumBuckets(); bool constexpr selectMin = false; - __shared__ AccT histogramSmem[numBuckets]; - __shared__ IdxT countHistogramSmem[numBuckets]; + for (IdxT i = threadIdx.x; i < numBuckets; i += blockDim.x) { histogramSmem[i] = 0; @@ -328,8 +547,8 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i } __syncthreads(); - int const startBit = calcsStartBit(pass); - unsigned const mask = calcMask(pass); + int const startBit = calcStartBit(pass); + uint32_t const mask = calcMask(pass); if (pass == 0) { @@ -337,18 +556,10 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i // parallel, i.e. the work is split along the input (both, in batches and // chunks of a single row). Later, the histograms are merged using // atomicAdd. - auto f = [selectMin, startBit, mask](T value, IdxT) + auto f = [selectMin, startBit, mask, histogramSmem, countHistogramSmem](T value, IdxT) { int bucket = calcBucket(value, startBit, mask, selectMin); - if constexpr (std::is_same_v) - { - atomicAdd(histogramSmem + bucket, __half2float(value)); - } - else - { - atomicAdd(histogramSmem + bucket, value); - } - + calcAtomicAdd(histogramSmem + bucket, value); atomicAdd(countHistogramSmem + bucket, static_cast(1)); }; vectorizedProcess(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, @@ -358,23 +569,33 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i { IdxT* pFilterCnt = &counter->filterCnt; auto const kthValueBits = counter->kthValueBits; - int const previousStartBit = calcsStartBit(pass - 1); + int const previousStartBit = calcStartBit(pass - 1); // See the remark above on the distributed execution of `f` using // vectorizedProcess. auto f = [inIdxBuf, outBuf, outIdxBuf, selectMin, startBit, mask, previousStartBit, kthValueBits, pFilterCnt, - outputLogProbs, cumLogProbs, ids, endIds, sequenceLengths, finishedOutput, batchId, maxBatchSize, - earlyStop](T value, IdxT i) + histogramSmem, countHistogramSmem, outputLogProbs, cumLogProbs, ids, endIds, sequenceLengths, + finishedOutput, batchId, maxBatchSize, earlyStop](T value, IdxT i) { auto const previousBits = (twiddleIn(value, selectMin) >> previousStartBit) << previousStartBit; if (previousBits == kthValueBits) { if (earlyStop) { + int const currentStep = sequenceLengths[batchId]; IdxT index = inIdxBuf ? 
inIdxBuf[i] : i; ids[batchId][currentStep] = index; - epilogue(value, index, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput, + float valueFloat; + if constexpr (std::is_same_v) + { + valueFloat = __half2float(value); + } + else + { + valueFloat = value; + } + epilogue(valueFloat, index, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput, batchId, maxBatchSize); } if (outBuf) @@ -385,15 +606,7 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i } int bucket = calcBucket(value, startBit, mask, selectMin); - if constexpr (std::is_same_v) - { - atomicAdd(histogramSmem + bucket, __half2float(value)); - } - else - { - atomicAdd(histogramSmem + bucket, value); - } - + calcAtomicAdd(histogramSmem + bucket, value); atomicAdd(countHistogramSmem + bucket, static_cast(1)); } }; @@ -412,7 +625,18 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i { if (histogramSmem[i] != 0) { - atomicAdd(histogram + i, histogramSmem[i]); + if constexpr ((isDeterministic) && (std::is_same_v) ) + { + // Have to use reinterpret_cast() to convert uint64_t to "unsigned long long" + // Otherwise, the complication will report the follow error: + //"error: no instance of overloaded function "atomicAdd" matches the argument list + // argument types are: (uint64_t *, uint64_t)" + atomicAdd(reinterpret_cast(histogram + i), histogramSmem[i]); + } + else + { + atomicAdd(histogram + i, histogramSmem[i]); + } } if (countHistogramSmem[i] != 0) { @@ -425,7 +649,7 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i * Replace histogram with its own prefix sum (step 2 in `airTopPSampling` description) */ template -__device__ void scan(IdxT volatile* histogram) +__device__ void scan(IdxT volatile* histogram, IdxT* histogramOut) { int constexpr numBuckets = calcNumBuckets(); if constexpr (numBuckets >= BlockSize) @@ -451,7 +675,7 @@ __device__ void scan(IdxT volatile* histogram) BlockScan(tempStorage.scan).InclusiveSum(threadData, threadData); __syncthreads(); - BlockStore(tempStorage.store).Store(histogram, threadData); + BlockStore(tempStorage.store).Store(histogramOut, threadData); } else { @@ -469,7 +693,7 @@ __device__ void scan(IdxT volatile* histogram) if (threadIdx.x < numBuckets) { - histogram[threadIdx.x] = threadData; + histogramOut[threadIdx.x] = threadData; } } } @@ -498,7 +722,7 @@ __device__ void chooseBucket( counter->sum = sum - prev; // how many values still are there to find counter->len = countHistogram[i]; // cur - prev; // number of values in next pass typename cub::Traits::UnsignedBits bucket = i; - int startBit = calcsStartBit(pass); + int startBit = calcStartBit(pass); counter->kthValueBits |= bucket << startBit; } } @@ -543,34 +767,159 @@ __device__ void epilogue(T const value, IdxT const index, float* outputLogProbs, * Find the target element. 
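+ * A sketch of the deterministic variant: counter->sum, the probability mass still to be covered, is
+ * divided by the pivot value to get how many pivot-valued entries are needed, and the kernel returns
+ * that occurrence counted in ascending index order, so the chosen id does not depend on which thread
+ * happens to write first.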
* (steps 4 in `airTopPSampling` description) */ -template +template __device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen, Counter* counter, float* outputLogProbs, float* cumLogProbs, IdxT** ids, IdxT const* endIds, IdxT* sequenceLengths, - FinishedState* finishedOutput, int const batchId, int maxBatchSize) + FinishedState* finishedOutput, int const batchId, int maxBatchSize, IdxT* lastIdxBuf, IdxT* countHistogram) { auto const kthValueBits = counter->kthValueBits; auto const equalValue = twiddleOut(kthValueBits, false); int const currentStep = sequenceLengths[batchId]; IdxT* outIdx = &ids[batchId][currentStep]; - if (threadIdx.x == 0) + + float equalValueFloat; + if constexpr (std::is_same_v) { - *outIdx = cuda::std::numeric_limits::max(); + equalValueFloat = __half2float(equalValue); } - __syncthreads(); + else + { + equalValueFloat = equalValue; + } + if constexpr (!isDeterministic) + { - for (IdxT i = threadIdx.x; i < currentLen; i += blockDim.x) + for (IdxT i = threadIdx.x; i < currentLen; i += blockDim.x) + { + if (inBuf[i] == equalValue) + { + *outIdx = inIdxBuf ? inIdxBuf[i] : i; + break; + } + } + } + else { - if (inBuf[i] == equalValue) + IdxT const bufLen = calcBufLen(counter->oriLen); + IdxT neededNumOfKth = counter->sum > 0 ? ceil(counter->sum / equalValueFloat) : 1; + + if (counter->len < neededNumOfKth) { - atomicMin(outIdx, inIdxBuf ? inIdxBuf[i] : i); + neededNumOfKth = counter->len; + } + + if (neededNumOfKth < bufLen) + { + for (int i = threadIdx.x; i < neededNumOfKth; i += blockDim.x) + { + lastIdxBuf[i] = cuda::std::numeric_limits::max(); + } + __threadfence_block(); + __syncthreads(); + + cuda::atomic_ref refLast(lastIdxBuf[neededNumOfKth - 1]); + + for (IdxT i = threadIdx.x; i < currentLen; i += blockDim.x) + { + if (inBuf[i] == equalValue) + { + IdxT newIdx = inIdxBuf ? inIdxBuf[i] : i; + if (newIdx < refLast.load(cuda::memory_order_relaxed)) + { + for (int j = 0; j < neededNumOfKth; j++) + { + IdxT preIdx = atomicMin_block(&lastIdxBuf[j], newIdx); + if (preIdx > newIdx) + { + newIdx = preIdx; + } + } + } + } + } + __syncthreads(); + if (threadIdx.x == 0) + { + *outIdx = refLast.load(cuda::memory_order_relaxed); + } + } + else + { + int numPass = calcNumPasses(); + int constexpr numBuckets = calcNumBuckets(); + __shared__ typename cub::Traits::UnsignedBits kthValueBitsIdx; + __shared__ IdxT neededNumOfKthSmem; + if (threadIdx.x == 0) + { + kthValueBitsIdx = 0; + neededNumOfKthSmem = neededNumOfKth; + } + __syncthreads(); + for (int pass = 0; pass < numPass; pass++) + { + for (IdxT i = threadIdx.x; i < numBuckets; i += blockDim.x) + { + countHistogram[i] = 0; + } + __syncthreads(); + + int preNeededNumOfKth = neededNumOfKthSmem; + int const startBit = calcStartBit(pass); + uint32_t const mask = calcMask(pass); + for (IdxT j = threadIdx.x; j < currentLen; j += blockDim.x) + { + if (inBuf[j] == equalValue) + { + IdxT newIdx = inIdxBuf ? inIdxBuf[j] : j; + bool isQualified = (pass == 0) ? 
true : false; + if (pass > 0) + { + int const previousStartBit = calcStartBit(pass - 1); + auto const previousBits = (twiddleIn(newIdx, true) >> previousStartBit) << previousStartBit; + if (previousBits == kthValueBitsIdx) + { + isQualified = true; + } + } + if (isQualified) + { + int bucket = calcBucket(newIdx, startBit, mask, true); + atomicAdd(countHistogram + bucket, static_cast(1)); + } + } + } // end histogram + __syncthreads(); + + scan(countHistogram, countHistogram); // prefix sum + __syncthreads(); + // Locate the bucket + for (int i = threadIdx.x; i < numBuckets; i += blockDim.x) + { + IdxT prev = (i == 0) ? 0 : countHistogram[i - 1]; + IdxT cur = countHistogram[i]; + // one and only one thread will satisfy this condition, so counter is + // written by only one thread + if (prev < preNeededNumOfKth && preNeededNumOfKth <= cur) + { + neededNumOfKthSmem = neededNumOfKthSmem - prev; + typename cub::Traits::UnsignedBits bucket = i; + kthValueBitsIdx |= bucket << startBit; + } + } + __syncthreads(); + } + if (threadIdx.x == 0) + { + *outIdx = twiddleOut(kthValueBitsIdx, true); + } } } __syncthreads(); if (threadIdx.x == 0) { - epilogue(equalValue, *outIdx, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput, batchId, - maxBatchSize); + epilogue(equalValueFloat, *outIdx, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput, + batchId, maxBatchSize); } } @@ -611,8 +960,9 @@ __device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen * rather than from `in_buf`. The benefit is that we can save the cost of writing candidates and * their indices. */ -template -__global__ void airTopPSampling(Counter* counters, AccT* histograms, IdxT* countHistograms, IdxT** ids, +template +__global__ void airTopPSampling(Counter* counters, HisT* histograms, IdxT* countHistograms, IdxT** ids, int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, IdxT const* endIds, int const maxBatchSize, bool const* skipDecode, int const pass, T* buf1, IdxT* idxBuf1, T* buf2, IdxT* idxBuf2, int32_t const* batchSlots) @@ -699,10 +1049,13 @@ __global__ void airTopPSampling(Counter* counters, AccT* histogra int constexpr numBuckets = calcNumBuckets(); auto histogram = histograms + batchId * numBuckets; auto countHistogram = countHistograms + batchId * numBuckets; + __shared__ HisT histogramSmem[numBuckets]; + __shared__ IdxT countHistogramSmem[numBuckets]; + AccT* histValueSmem = reinterpret_cast(histogramSmem); - filterAndHistogram(inBuf, inIdxBuf, outBuf, outIdxBuf, previousLen, counter, histogram, - countHistogram, pass, outputLogProbs, cumLogProbs, ids, endIds, sequenceLengths, finishedOutput, batchSlot, - maxBatchSize, earlyStop); + filterAndHistogram(inBuf, inIdxBuf, outBuf, outIdxBuf, + previousLen, counter, histogram, countHistogram, histogramSmem, countHistogramSmem, pass, outputLogProbs, + cumLogProbs, ids, endIds, sequenceLengths, finishedOutput, batchSlot, maxBatchSize, earlyStop); __syncthreads(); __threadfence(); @@ -710,7 +1063,7 @@ __global__ void airTopPSampling(Counter* counters, AccT* histogra bool isLastBlock = false; if (threadIdx.x == 0) { - unsigned int finished = atomicInc(&counter->finishedBlockCnt, gridDim.x - 1); + uint32_t finished = atomicInc(&counter->finishedBlockCnt, gridDim.x - 1); isLastBlock = (finished == (gridDim.x - 1)); } @@ -718,9 +1071,70 @@ __global__ void airTopPSampling(Counter* counters, AccT* histogra { if (earlyStop) { + if (threadIdx.x == 0) + { + 
// avoid duplicated epilgue() + counter->previousLen = 0; + counter->len = 0; + } return; } + if constexpr (isDeterministic) + { + for (int i = threadIdx.x; i < numBuckets; i += blockDim.x) + { + uint64_t value = (uint64_t) histogram[i]; + IdxT count = countHistogram[i]; + + if (count != 0) + { + uint32_t startBit = calcStartBit(pass); + float bucketValueFloat; + if constexpr (std::is_same_v) + { + // To acquire the summation in single-precision format, we need to get the original exponent + // value first counter->kthValueBits stores the bits selected by previous pass, which contains + // the bit corresponds to the exponent value + uint16_t bucketValue = counter->kthValueBits; + + // For the first pass, different bucket indices correspond to different exponents. + // The bucket index can be used to deduce the exponent. + if (pass == 0) + { + // Right shift the bucket index with startBit bits (5 bits for half-precision when pass==0), + // so that the bucket index fills the bit related to exponent. + bucketValue = i << startBit; + } + uint32_t exponent = calcExponent(twiddleOut(bucketValue, false)); + uint32_t mask = (1u << (sizeof(half) * CHAR_BIT - 1)) - 1; + uint32_t sign = exponent & (~mask); + exponent = exponent & mask; + float tmp = calcHalfValue((uint32_t) count, exponent, sign, value); + histValueSmem[i] = tmp; + } + else + { + // To acquire the summation in single-precision format, we need to get the original exponent + // value first + uint32_t bucketValue = counter->kthValueBits; + if (pass == 0) + { + // Right shift the bucket index with startBit bits (22 bits for single-precision when + // pass==0), so that the bucket index fills the bit related to exponent. + bucketValue = i << startBit; + } + bucketValueFloat = twiddleOut(bucketValue, false); + uint32_t exponent = calcExponent(bucketValueFloat); + histValueSmem[i] = calcFloatValue((uint32_t) count, exponent, value); + } + } + else + { + histValueSmem[i] = 0.0f; + } + } + } __shared__ IdxT maxBucket; if (pass > 0) { @@ -743,21 +1157,22 @@ __global__ void airTopPSampling(Counter* counters, AccT* histogra __syncthreads(); } - scan(histogram); + scan( + isDeterministic ? histValueSmem : reinterpret_cast(histogram), histValueSmem); __syncthreads(); if (pass == 0) { - currentSum = histogram[numBuckets - 1] * counter->p; + currentSum = histValueSmem[numBuckets - 1] * counter->p; } else { - if (currentSum > histogram[maxBucket]) + if (currentSum > histValueSmem[maxBucket]) { - currentSum = histogram[maxBucket]; + currentSum = histValueSmem[maxBucket]; } } - chooseBucket(counter, histogram, countHistogram, currentSum, pass); + chooseBucket(counter, histValueSmem, countHistogram, currentSum, pass); __syncthreads(); int constexpr numPasses = calcNumPasses(); @@ -779,12 +1194,18 @@ __global__ void airTopPSampling(Counter* counters, AccT* histogra if (pass == numPasses - 1) { - if constexpr (is_fused_filter) + // Used when isDeterministic==true + // idxBuf1 and idxBuf2 are ping-pong buffers used in previous iterations to store candidates. + // In the last pass (pass==2 for single-precision and pass==1 for half-precision), + // we reuse the buffer didn't store the candidates (idxBuf1 for single-precision and idxBuf2 for + // half-precision) to help find the correct index of the result. + IdxT* lastIdxBuf = (pass % 2 == 0) ? idxBuf1 + bufLen * batchId : idxBuf2 + bufLen * batchId; + if constexpr (isFusedFilter) { - lastFilter(outBuf ? outBuf : inBuf, outIdxBuf ? outIdxBuf : inIdxBuf, - outBuf ? 
currentLen : counter->oriLen, counter, outputLogProbs, cumLogProbs, ids, endIds, - sequenceLengths, finishedOutput, batchSlot, maxBatchSize); - + lastFilter(outBuf ? outBuf : inBuf, + outIdxBuf ? outIdxBuf : inIdxBuf, outBuf ? currentLen : counter->oriLen, counter, outputLogProbs, + cumLogProbs, ids, endIds, sequenceLengths, finishedOutput, batchSlot, maxBatchSize, lastIdxBuf, + countHistogramSmem); __syncthreads(); } } @@ -794,9 +1215,9 @@ __global__ void airTopPSampling(Counter* counters, AccT* histogra /** * Initialize the Counter and the histogram and countHistogram. */ -template +template __global__ void airTopPInitialize(Counter* counters, int const batchSize, int const len, T const* in, - IdxT const* inIdx, float const topP, float const* topPs, curandState_t* curandstate, AccT* histograms, + IdxT const* inIdx, float const topP, float const* topPs, curandState_t* curandstate, HisT* histograms, IdxT* countHistograms, int32_t const* batchSlots) { auto const batchIdx = blockIdx.x; @@ -828,7 +1249,7 @@ __global__ void airTopPInitialize(Counter* counters, int const ba } int constexpr numBuckets = calcNumBuckets(); - AccT* histogram = histograms + batchIdx * numBuckets; + HisT* histogram = histograms + batchIdx * numBuckets; for (int i = threadIdx.x; i < numBuckets; i += BlockSize) { histogram[i] = 0; @@ -848,15 +1269,28 @@ __global__ void airTopPInitialize(Counter* counters, int const ba /* * Calculate the number of blocks based on the batchSize and len to avoid tailing effect. */ -template -unsigned calcAirTopPBlockNum(int batchSize, IdxT len, int smCnt) +template +uint32_t calcAirTopPBlockNum(int batchSize, int len, int smCnt, bool isDeterministic) { + int constexpr BitsPerPass = 11; + int constexpr BlockSize = 512; int constexpr VECTORIZED_READ_SIZE = 16; static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); + TLLM_CHECK_WITH_INFO( + smCnt > 0, "AIR Top-P needs the count of multiprocessor to calculate the proper block dimension settings"); int activeBlocks; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &activeBlocks, airTopPSampling, BlockSize, 0); + if (isDeterministic) + { + using HisT = std::conditional_t, uint64_t, uint32_t>; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &activeBlocks, airTopPSampling, BlockSize, 0); + } + else + { + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &activeBlocks, airTopPSampling, BlockSize, 0); + } activeBlocks *= smCnt; IdxT bestNumBlocks = 0; @@ -892,15 +1326,17 @@ unsigned calcAirTopPBlockNum(int batchSize, IdxT len, int smCnt) return bestNumBlocks; } -template +template [[nodiscard]] std::vector getAirTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize) { + using HisT + = std::conditional_t, uint64_t, uint32_t>, float>; int constexpr BitsPerPass = 11; int constexpr numBuckets = calcNumBuckets(); IdxT const bufLen = calcBufLen(vocabSize); size_t countersSize = sizeof(Counter) * batchSize; - size_t histogramsSize = sizeof(AccT) * numBuckets * batchSize; + size_t histogramsSize = sizeof(HisT) * numBuckets * batchSize; size_t countHistogramsSize = sizeof(IdxT) * numBuckets * batchSize; size_t buf1Size = sizeof(T) * bufLen * batchSize; size_t idxBuf1Size = sizeof(IdxT) * bufLen * batchSize; @@ -913,18 +1349,27 @@ template return sizes; } -template std::vector getAirTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize); -template std::vector getAirTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize); +template std::vector getAirTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize); +template std::vector 
getAirTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize); +template std::vector getAirTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize); +template std::vector getAirTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize); -template -void invokeBatchAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, +template +void invokeAirTopPSamplingWithDeterministicPara(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots) { + using HisT + = std::conditional_t, uint64_t, uint32_t>, float>; + static_assert(std::is_same_v | std::is_same_v, "T needs to be either half or float"); static_assert(std::is_same_v, "AccT needs to be float"); + TLLM_CHECK_WITH_INFO(((std::is_same_v) &&(vocabSizePadded < pow(2, 22)) && isDeterministic) + || ((std::is_same_v) &&(vocabSizePadded < pow(2, 41)) && isDeterministic) || (~isDeterministic), + "For Deterministic AIR Top-P, the maximum vocab_size we support is pow(2,22) for half-precision and pow(2,41) " + "for single-precision"); IdxT const vocabSize = vocabSizePadded; int constexpr BitsPerPass = 11; @@ -933,14 +1378,14 @@ void invokeBatchAirTopPSampling(void* workspace, int** outputIds, int* sequenceL int constexpr THREADS_PER_CTA_TOP_P_INIT = 1024; Counter* counters = nullptr; - AccT* histograms = nullptr; + HisT* histograms = nullptr; IdxT* countHistograms = nullptr; T* buf1 = nullptr; IdxT* idxBuf1 = nullptr; T* buf2 = nullptr; IdxT* idxBuf2 = nullptr; - auto const workspaceSizes = getAirTopPWorkspaceSizes(batchSize, vocabSize); + auto const workspaceSizes = getAirTopPWorkspaceSizes(batchSize, vocabSize); std::vector alignedPointers; calcAlignedPointers(alignedPointers, workspace, workspaceSizes); @@ -952,27 +1397,46 @@ void invokeBatchAirTopPSampling(void* workspace, int** outputIds, int* sequenceL buf2 = static_cast(alignedPointers[5]); idxBuf2 = static_cast(alignedPointers[6]); - airTopPInitialize + airTopPInitialize <<>>(counters, batchSize, vocabSize, logProbs, nullptr, maxTopP, topPs, curandstate, histograms, countHistograms, batchSlots); - sync_check_cuda_error(); dim3 grid(blockNum, batchSize); // Sample with Top P given sorted tokens int constexpr numPasses = calcNumPasses(); - auto kernel = airTopPSampling; + auto kernel = airTopPSampling; for (int pass = 0; pass < numPasses; ++pass) { if (pass == numPasses - 1) { - kernel = airTopPSampling; + kernel = airTopPSampling; } kernel<<>>(counters, histograms, countHistograms, outputIds, sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, endIds, maxBatchSize, skipDecode, pass, buf1, idxBuf1, buf2, idxBuf2, batchSlots); - sync_check_cuda_error(); + } +} + +template +void invokeBatchAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, + FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, + T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, + int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, + bool const* skipDecode, int32_t const* batchSlots, bool isDeterministic) +{ + if (isDeterministic) + { + 
invokeAirTopPSamplingWithDeterministicPara(workspace, outputIds, sequenceLength, finishedInput, + finishedOutput, cumLogProbs, outputLogProbs, logProbs, curandstate, batchSize, maxBatchSize, + vocabSizePadded, endIds, maxTopP, topPs, stream, blockNum, skipDecode, batchSlots); + } + else + { + invokeAirTopPSamplingWithDeterministicPara(workspace, outputIds, sequenceLength, finishedInput, + finishedOutput, cumLogProbs, outputLogProbs, logProbs, curandstate, batchSize, maxBatchSize, + vocabSizePadded, endIds, maxTopP, topPs, stream, blockNum, skipDecode, batchSlots); } } @@ -980,49 +1444,57 @@ template void invokeBatchAirTopPSampling(void* workspace, int** outputIds, int* FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, float const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, - int blockNum, bool const* skipDecode, int32_t const* batchSlots); + int blockNum, bool const* skipDecode, int32_t const* batchSlots, bool isDeterministic); template void invokeBatchAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, half const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, - int blockNum, bool const* skipDecode, int32_t const* batchSlots); + int blockNum, bool const* skipDecode, int32_t const* batchSlots, bool isDeterministic); template void invokeAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, - float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots) + float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots, + bool isDeterministic) { invokeBatchAirTopPSampling(workspace, outputIds, sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, logProbs, curandstate, batchSize, maxBatchSize, vocabSizePadded, endIds, topP, nullptr, stream, - blockNum, skipDecode, batchSlots); + blockNum, skipDecode, batchSlots, isDeterministic); } template void invokeAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, float const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const topP, cudaStream_t stream, int blockNum, - bool const* skipDecode, int32_t const* batchSlots); + bool const* skipDecode, int32_t const* batchSlots, bool isDeterministic); template void invokeAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, half const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const topP, cudaStream_t stream, int blockNum, - bool const* skipDecode, int32_t const* 
batchSlots); - -template unsigned calcAirTopPBlockNum(int batchSize, int len, int smCnt); -template unsigned calcAirTopPBlockNum(int batchSize, int len, int smCnt); + bool const* skipDecode, int32_t const* batchSlots, bool isDeterministic); template -size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded) +size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded, bool isDeterministic) { - auto const workspaceSizes = getAirTopPWorkspaceSizes(batchSize, vocabSizePadded); + std::vector workspaceSizes; + if (isDeterministic == true) + { + workspaceSizes = getAirTopPWorkspaceSizes(batchSize, vocabSizePadded); + } + else + { + workspaceSizes = getAirTopPWorkspaceSizes(batchSize, vocabSizePadded); + } return calcAlignedSize(workspaceSizes, 256); } -template size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded); -template size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded); +template size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded, bool isDeterministic); +template size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded, bool isDeterministic); +template uint32_t calcAirTopPBlockNum(int batchSize, int len, int smCnt, bool isDeterministic); +template uint32_t calcAirTopPBlockNum(int batchSize, int len, int smCnt, bool isDeterministic); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.h b/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.h index 0d31a9094..f06c2bd8b 100644 --- a/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.h +++ b/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.h @@ -61,37 +61,41 @@ namespace kernels //! \param blockNum The appropriate block configuration calculated based on the number of multiprocessors, occupancy, //! batchSize and vocabSizePadded //! \param skipDecode input buffer [batchSize]. Flags whether to skip decoding per request +//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool +//! \param isDeterministic bool, optional. Default value is false. +//! When isDeterministic==true, the result is reproducible. template void invokeBatchAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, - bool const* skipDecode, int32_t const* batchSlots); + bool const* skipDecode, int32_t const* batchSlots, bool isDeterministic = false); //! \brief Specialization of invokeBatchAirTopPSampling with topPs=nullptr template void invokeAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, - float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots); + float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots, + bool isDeterministic = false); //! \brief Calculate the number of blocks based on the number of multiprocessors, batchSize and vocabSize. //! \tparam T the data type of value -//! 
\tparam IdxT the data type of index -//! \tparam AccT the data type of variables related to accumulation -//! \tparam BitsPerPass the number of bits for each pass. Can be 8 or 11. Use 11 for default. -//! \tparam BlockSize the block size //! \param batchSize //! \param len the number of candidates for each case //! \param smCnt number of multiprocessors on device -template -unsigned calcAirTopPBlockNum(int batchSize, IdxT len, int smCnt); +//! \param isDeterministic bool, optional. Default value is false. +//! When isDeterministic==true, the result is reproducible. +template +uint32_t calcAirTopPBlockNum(int batchSize, int len, int smCnt, bool isDeterministic = false); //! \brief Returns workspace size in bytes needed for sampling Air TopP computation //! \param batchSize batch size //! \param vocabSizePadded size of padded vocab +//! \param isDeterministic bool, optional. Default value is false. +//! When isDeterministic==true, the result is reproducible. template -[[nodiscard]] size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded); +[[nodiscard]] size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded, bool isDeterministic = false); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu b/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu index 9fc1b982b..3b4f3e330 100644 --- a/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu +++ b/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu @@ -131,12 +131,12 @@ __global__ void topKStage1(T const* __restrict logProbs, T const* const* __restr } template -__global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTmpValBuf, int** ids, +__global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTmpValBuf, int** idsPtrs, int* ids, int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, int const maxTopK, int const* topKs, float const topP, float const* topPs, - curandState_t* curandstate, int const* endIds, int const vocabSize, bool const* skipDecode, int const* batchSlots, - int maxBatchSize, bool const normalizeLogProbs, bool const logitHasProbs, int const* tokensPerStep, - int const maxTokensPerStep, bool returnAllTopK) + float* outputLogProbs, int maxTopK, int const* topKs, float topP, float const* topPs, curandState_t* curandstate, + int const* endIds, int vocabSize, bool const* skipDecode, int const* batchSlots, int maxBatchSize, + bool normalizeLogProbs, bool logitHasProbs, int const* tokensPerStep, int maxTokensPerStep, int maxSeqLen, + bool returnAllTopK) { bool const IS_FP16 = std::is_same::value; T const MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; @@ -164,12 +164,12 @@ __global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTm typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; __shared__ typename BlockReduce::TempStorage tempStorage; extern __shared__ char array[]; - __shared__ float s_sum; - T* s_val = topKTmpValBuf + (batchIdx * maxTokensPerStep + tokenIdx) * stride; - auto* s_id = reinterpret_cast(array); + __shared__ float sSum; + T* sVal = topKTmpValBuf + (batchIdx * maxTokensPerStep + tokenIdx) * stride; + auto* sId = reinterpret_cast(array); if (tid == 0) { - s_sum = 0.0f; + sSum = 0.0f; } TopK_2 partial; @@ -182,7 +182,7 @@ __global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTm return; } - auto s_val2 = reinterpret_cast(s_id + k); + auto sVal2 = reinterpret_cast(sId + k); float maxLogit; for (int ite = 0; ite < k; ite++) { @@ -190,7 +190,7 @@ __global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTm #pragma unroll for (int i = tid; i < size; i += BLOCK_SIZE_) { - partial.insert((float) s_val[i], i); + partial.insert((float) sVal[i], i); } TopK_2 total = BlockReduce(tempStorage).Reduce(partial, reduce_topk_op_2); @@ -201,8 +201,8 @@ __global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTm { maxLogit = total.u; } - s_id[ite] = total.p; - s_val[total.p] = -MAX_T_VAL; + sId[ite] = total.p; + sVal[total.p] = -MAX_T_VAL; // when cumLogProbs are computed, topKTmpValBuf (logits_buf_) are // already pre-processed by softmax_kernel @@ -210,30 +210,31 @@ __global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTm { total.u = __expf(total.u - maxLogit); } - s_val2[ite] = total.u; - s_sum += total.u; + sVal2[ite] = total.u; + sSum += total.u; } __syncthreads(); } if (tid == 0) { - auto randNum = static_cast(curand_uniform(curandstate + batchSlot) * probThreshold * s_sum); + auto randNum = static_cast(curand_uniform(curandstate + batchSlot) * probThreshold * sSum); + auto* outputIdsRequestPtr = idsPtrs == nullptr ? ids + batchSlot * maxSeqLen : idsPtrs[batchSlot]; for (int ki = 0; ki < k; ki++) { - auto expLogit = s_val2[ki]; + auto expLogit = sVal2[ki]; randNum = randNum - expLogit; if (randNum <= 0.0f || ki == k - 1 || returnAllTopK) { - auto idx = s_id[ki]; - // If s_id is -1 here we force output token to the last from vocabulary to get vivid indicator of smth + auto idx = sId[ki]; + // If sId is -1 here we force output token to the last from vocabulary to get vivid indicator of smth // going wrong for the debug auto outputId = idx != -1 ? topKTmpIdBuf[(batchIdx * maxTokensPerStep + tokenIdx) * stride + idx] % vocabSize : vocabSize - 1; - auto const curSeqLen = sequenceLengths[batchSlot]; - auto outIdx = returnAllTopK ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx; - ids[batchSlot][outIdx] = outputId; + auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot]; + auto const outIdx = returnAllTopK ? 
tokenIdx * maxTopK + ki : curSeqLen + tokenIdx; + outputIdsRequestPtr[outIdx] = outputId; // cum log prob is not supported with returnAllTopK if (!returnAllTopK) { @@ -248,11 +249,11 @@ __global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTm { // 'outputLogProbs' is the probability induced by the top-k sampling: // NOT normalized (same way as OpenAI does): - // log_prob = log P(i | i is in top-k) = log(expLogit) + // log_prob = log P(i | i is in vocab) = log(expLogit) // normalized: // log_prob = log P(i | i is in top-k) = log(expLogit / sum) outputLogProbs[curSeqLen * maxBatchSize + batchSlot] - = normalizeLogProbs ? logProb - logf(s_sum) : logProb; + = normalizeLogProbs ? logProb - logf(sSum) : logProb; } } break; @@ -262,7 +263,7 @@ __global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTm if (maxTokensPerStep == 1 && !returnAllTopK && sequenceLengths != nullptr && finishedOutput != nullptr) { int const seqLen = sequenceLengths[batchSlot]; - if (ids[batchSlot][seqLen] == endIds[batchSlot]) + if (outputIdsRequestPtr[seqLen] == endIds[batchSlot]) { finishedOutput[batchSlot].setFinishedEOS(); // Do not increase seq len when EOS is generated. Seq len should always contain only tokens to be @@ -292,19 +293,19 @@ __global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTm dim3 block(BLOCK_SIZE_2_); \ topKStage2Sampling \ <<>>(topKTmpIdBuf, topKTmpValBuf, \ - ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, maxTopK, topKs, \ - topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, maxBatchSize, \ - normalizeLogProbs, logitsHasProbs, tokensPerStep, maxTokensPerStep, returnAllTopK); \ + idsPtrs, ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, \ + maxTopK, topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, maxBatchSize, \ + normalizeLogProbs, logitsHasProbs, tokensPerStep, maxTokensPerStep, maxSeqLen, returnAllTopK); \ } \ } while (0) template -void invokeBatchTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtrs, int** ids, +void invokeBatchTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtrs, int** idsPtrs, int* ids, int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, curandState_t* curandstate, int const maxTopK, int const* topKs, float const topP, - float const* topPs, int const vocabSize, int const* endIds, int const* batchSlots, cudaStream_t stream, - int const batchSize, int maxBatchSize, int const* tokensPerStep, int const maxTokensPerStep, bool const* skipDecode, - bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK) + float* outputLogProbs, curandState_t* curandstate, int maxTopK, int const* topKs, float topP, float const* topPs, + int vocabSize, int const* endIds, int const* batchSlots, cudaStream_t stream, int batchSize, int maxBatchSize, + int const* tokensPerStep, int maxTokensPerStep, int maxSeqLen, bool const* skipDecode, bool normalizeLogProbs, + bool logitsHasProbs, bool returnAllTopK) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -360,45 +361,45 @@ void invokeBatchTopKSampling(void* workspace, T const* logProbs, T const* const* #undef CASE_K template void invokeBatchTopKSampling(void* workspace, float const* logProbs, float const* const* logProbsPtrs, - int** ids, int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, - 
float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, int const maxTopK, int const* topKs, - float const topP, float const* topPs, int const vocabSizePadded, int const* endIds, int const* batchSlots, - cudaStream_t stream, int const batchSize, int maxBatchSize, int const* tokensPerStep, int const maxTokensPerStep, + int** idsPtrs, int* ids, int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, + float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, int maxTopK, int const* topKs, float topP, + float const* topPs, int vocabSizePadded, int const* endIds, int const* batchSlots, cudaStream_t stream, + int batchSize, int maxBatchSize, int const* tokensPerStep, int maxTokensPerStep, int maxSeqLen, bool const* skipDecode, bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK); -template void invokeBatchTopKSampling(void* workspace, half const* logProbs, half const* const* logProbsPtrs, int** ids, - int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, curandState_t* curandstate, int const maxTopK, int const* topKs, float const topP, - float const* topPs, int const vocabSizePadded, int const* endIds, int const* batchSlots, cudaStream_t stream, - int const batchSize, int maxBatchSize, int const* tokensPerStep, int const maxTokensPerStep, bool const* skipDecode, - bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK); +template void invokeBatchTopKSampling(void* workspace, half const* logProbs, half const* const* logProbsPtrs, + int** idsPtrs, int* ids, int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, + float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, int maxTopK, int const* topKs, float topP, + float const* topPs, int vocabSizePadded, int const* endIds, int const* batchSlots, cudaStream_t stream, + int batchSize, int maxBatchSize, int const* tokensPerStep, int maxTokensPerStep, int maxSeqLen, + bool const* skipDecode, bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK); template -void invokeTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtrs, int** ids, +void invokeTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtrs, int** idsPtrs, int* ids, int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, curandState_t* curandstate, int const topK, float const topP, int const vocabSizePadded, - int const* endIds, int const* batchSlots, cudaStream_t stream, int const batchSize, int maxBatchSize, - int const* tokensPerStep, int const maxTokensPerStep, bool const* skipDecode, bool normalizeLogProbs, - bool logitsHasProbs, bool returnAllTopK) + float* outputLogProbs, curandState_t* curandstate, int topK, float topP, int vocabSizePadded, int const* endIds, + int const* batchSlots, cudaStream_t stream, int batchSize, int maxBatchSize, int const* tokensPerStep, + int maxTokensPerStep, int maxSeqLen, bool const* skipDecode, bool normalizeLogProbs, bool logitsHasProbs, + bool returnAllTopK) { - invokeBatchTopKSampling(workspace, logProbs, logProbsPtrs, ids, sequenceLengths, finishedInput, finishedOutput, - cumLogProbs, outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, batchSlots, - stream, batchSize, maxBatchSize, tokensPerStep, maxTokensPerStep, skipDecode, normalizeLogProbs, logitsHasProbs, - returnAllTopK); + 
invokeBatchTopKSampling(workspace, logProbs, logProbsPtrs, idsPtrs, ids, sequenceLengths, finishedInput, + finishedOutput, cumLogProbs, outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, + batchSlots, stream, batchSize, maxBatchSize, tokensPerStep, maxTokensPerStep, maxSeqLen, skipDecode, + normalizeLogProbs, logitsHasProbs, returnAllTopK); } -template void invokeTopKSampling(void* workspace, float const* logProbs, float const* const* logProbsPtrs, int** ids, - int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, curandState_t* curandstate, int const topK, float const topP, int const vocabSizePadded, - int const* endIds, int const* batchSlots, cudaStream_t stream, int const batchSize, int maxBatchSize, - int const* tokensPerStep, int const maxTokensPerStep, bool const* skipDecode, bool normalizeLogProbs, +template void invokeTopKSampling(void* workspace, float const* logProbs, float const* const* logProbsPtrs, + int** idsPtrs, int* ids, int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, + float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, int topK, float topP, int vocabSizePadded, + int const* endIds, int const* batchSlots, cudaStream_t stream, int batchSize, int maxBatchSize, + int const* tokensPerStep, int maxTokensPerStep, int maxSeqLen, bool const* skipDecode, bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK); -template void invokeTopKSampling(void* workspace, half const* logProbs, half const* const* logProbsPtrs, int** ids, - int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, curandState_t* curandstate, int const topK, float const topP, int const vocabSizePadded, - int const* endIds, int const* batchSlots, cudaStream_t stream, int const batchSize, int maxBatchSize, - int const* tokensPerStep, int const maxTokensPerStep, bool const* skipDecode, bool normalizeLogProbs, +template void invokeTopKSampling(void* workspace, half const* logProbs, half const* const* logProbsPtrs, int** idsPtrs, + int* ids, int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, + float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, int const topK, float topP, + int vocabSizePadded, int const* endIds, int const* batchSlots, cudaStream_t stream, int batchSize, int maxBatchSize, + int const* tokensPerStep, int maxTokensPerStep, int maxSeqLen, bool const* skipDecode, bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK); } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/samplingTopKKernels.h b/cpp/tensorrt_llm/kernels/samplingTopKKernels.h index facdb55ed..dfccfb284 100644 --- a/cpp/tensorrt_llm/kernels/samplingTopKKernels.h +++ b/cpp/tensorrt_llm/kernels/samplingTopKKernels.h @@ -19,6 +19,7 @@ #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/kernels/decodingCommon.h" +#include "tensorrt_llm/runtime/common.h" #include #include @@ -27,6 +28,9 @@ namespace tensorrt_llm { namespace kernels { + +static constexpr uint32_t TOP_K_MAX = 1024; + // clang-format off //! \brief Given logProbs, performs top K **and** top P sampling at the same time. Fills sampled tokens to outputIds. //! Computes sequenceLength, finished state, cumLogProbs inplace. @@ -41,7 +45,10 @@ namespace kernels //! 
logProbs must contain **just** probabilities instead of log probabilities. //! \param logProbsPtr input buffer [batchSize][vocabSizePadded] array of pointers to logits. If nullptr, logProbs is used. //! Only maxTokensPerStep == 1 is supported. -//! \param outputIds output buffer [maxBatchSize][maxSeqLen]. Contains point32_ters to rows with output tokens per request +//! \param outputIdsPtrs output buffer [maxBatchSize][maxSeqLen], optional. Contains pointers to rows with output tokens per request. +//! If nullptr, outputIds must be provided. +//! \param outputIds output buffer [maxBatchSize, maxSeqLen], optional. Tensor to store output tokens. +//! Not used if outputIdsPtrs != nullptr //! \param sequenceLength input/output buffer [maxBatchSize]. Current sequence length of the request up to, but excluding endId token //! \param finishedInput input buffer [maxBatchSize]. If true, request exits early. //! \param finishedOutput output buffer [maxBatchSize]. Set flag if sequence has finished (if finished || outputId == endId). @@ -69,37 +76,42 @@ namespace kernels //! \param tokensPerStep input buffer [maxBatchSize], optional. Number of tokens per step for each request. //! It is assumed that all requests have maxTokensPerStep tokens per step if nullptr. //! \param maxTokensPerStep maximum number of tokens per computed per step +//! \param maxSeqLen maximum sequence length of outputIds //! \param skipDecode input buffer [maxBatchSize]. Flags whether to skip decoding per request //! \param normalizeLogProbs when set to True outputLogProbs are normalized to TopK //! \param logitsHasProbs flag to highlight that logProbs contains probabilities //! \param returnAllTopK flag to return all selectedTopK results // clang-format on template -void invokeBatchTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtr, int32_t** ids, - int32_t* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, curandState_t* curandstate, const int32_t maxTopK, int32_t const* topKs, float const topP, - float const* topPs, const int32_t vocabSizePadded, int32_t const* endIds, int32_t const* batchSlots, - cudaStream_t stream, const int32_t batchSize, int maxBatchSize, int32_t const* tokensPerStep, - const int32_t maxTokensPerStep, bool const* skipDecode, bool normalizeLogProbs, bool logitsHasProbs, +void invokeBatchTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtr, + runtime::TokenIdType** outputIdsPtrs, runtime::TokenIdType* outputIds, runtime::SizeType* sequenceLengths, + FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, + curandState_t* curandstate, runtime::SizeType maxTopK, runtime::SizeType const* topKs, float topP, + float const* topPs, runtime::SizeType vocabSizePadded, runtime::TokenIdType const* endIds, + runtime::SizeType const* batchSlots, cudaStream_t stream, runtime::SizeType batchSize, + runtime::SizeType maxBatchSize, runtime::SizeType const* tokensPerStep, runtime::SizeType maxTokensPerStep, + runtime::SizeType maxSeqLen, bool const* skipDecode, bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK); //! 
\brief Specialization of invokeBatchTopKSampling with topPs=nullptr and topKs=nullptr template -void invokeTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtr, int32_t** outputIds, - int32_t* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, curandState_t* curandstate, const int32_t topK, float const topP, - const int32_t vocabSizePadded, int32_t const* endIds, int32_t const* batchSlots, cudaStream_t stream, - const int32_t batchSize, int maxBatchSize, int32_t const* tokensPerStep, const int32_t maxTokensPerStep, - bool const* skipDecode, bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK); +void invokeTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtr, + runtime::TokenIdType** outputIdsPtrs, runtime::TokenIdType* outputIds, runtime::SizeType* sequenceLength, + FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, + curandState_t* curandstate, runtime::SizeType topK, float topP, runtime::SizeType vocabSizePadded, + runtime::TokenIdType const* endIds, runtime::SizeType const* batchSlots, cudaStream_t stream, + runtime::SizeType batchSize, int maxBatchSize, runtime::SizeType const* tokensPerStep, + runtime::SizeType maxTokensPerStep, runtime::SizeType maxSeqLen, bool const* skipDecode, bool normalizeLogProbs, + bool logitsHasProbs, bool returnAllTopK); template -[[nodiscard]] std::vector getTopKWorkspaceSizes( - int32_t batchSize, int32_t maxTokensPerStep, int32_t maxTopK, int32_t vocabSizePadded) +[[nodiscard]] std::vector getTopKWorkspaceSizes(runtime::SizeType batchSize, runtime::SizeType maxTokensPerStep, + runtime::SizeType maxTopK, runtime::SizeType vocabSizePadded) { - int32_t constexpr maxBlockPerBeam = 8; + runtime::SizeType constexpr maxBlockPerBeam = 8; auto const tempLogProbsBufSize = sizeof(T) * batchSize * maxTokensPerStep * vocabSizePadded; // type T auto const topKTmpIdsBufSize - = sizeof(int32_t) * batchSize * maxTokensPerStep * maxTopK * maxBlockPerBeam; // type int + = sizeof(runtime::SizeType) * batchSize * maxTokensPerStep * maxTopK * maxBlockPerBeam; // type int auto const topKTmpValBufSize = sizeof(T) * batchSize * maxTokensPerStep * maxTopK * maxBlockPerBeam; // type T return {tempLogProbsBufSize, topKTmpIdsBufSize, topKTmpValBufSize}; @@ -111,8 +123,8 @@ template //! \param maxTopK maximum among all topKs K for topK sampling //! 
\param vocabSizePadded size of padded vocab template -[[nodiscard]] size_t getTopKWorkspaceSize( - int32_t batchSize, int32_t maxTokensPerStep, int32_t maxTopK, int32_t vocabSizePadded) +[[nodiscard]] size_t getTopKWorkspaceSize(runtime::SizeType batchSize, runtime::SizeType maxTokensPerStep, + runtime::SizeType maxTopK, runtime::SizeType vocabSizePadded) { auto const workspaceSizes = getTopKWorkspaceSizes(batchSize, maxTokensPerStep, maxTopK, vocabSizePadded); return tensorrt_llm::common::calcAlignedSize(workspaceSizes, 256); diff --git a/cpp/tensorrt_llm/kernels/selectiveScan.cu b/cpp/tensorrt_llm/kernels/selectiveScan.cu index 8ecde3a4f..5a71f5169 100644 --- a/cpp/tensorrt_llm/kernels/selectiveScan.cu +++ b/cpp/tensorrt_llm/kernels/selectiveScan.cu @@ -80,7 +80,6 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa input_t* z = reinterpret_cast(params.z_ptr); weight_t* dt_bias = reinterpret_cast(params.delta_bias_ptr); bool dt_softplus = params.delta_softplus; - int num_tokens = params.seqlen; int num_channels = params.dim; // static const int STAGES = 12; @@ -102,9 +101,14 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa int const channel = blockIdx.x * blockDim.x + threadIdx.x; int const sample = blockIdx.y; // batch id + int num_tokens; + int start_token_idx; + start_token_idx = sample * params.seqlen; + num_tokens = params.last_token_ids_ptr[sample]; + int const seq_loops = (num_tokens + SEQ_UNROLL - 1) / SEQ_UNROLL; - int const input_matrix_row_id = sample * num_tokens; + int const input_matrix_row_id = start_token_idx; if (threadIdx.y == 1) { diff --git a/cpp/tensorrt_llm/kernels/selectiveScan.h b/cpp/tensorrt_llm/kernels/selectiveScan.h index d93f4d407..0f911546a 100644 --- a/cpp/tensorrt_llm/kernels/selectiveScan.h +++ b/cpp/tensorrt_llm/kernels/selectiveScan.h @@ -59,6 +59,7 @@ struct SSMParamsBase void* __restrict__ out_ptr; void* __restrict__ x_ptr; void* __restrict__ z_ptr; + int const* __restrict__ last_token_ids_ptr; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h index 95b8e9a77..db0762351 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h @@ -18,10 +18,8 @@ #include #include #include -#include -#if defined(ENABLE_BF16) #include -#endif +#include #include #include #include @@ -30,79 +28,74 @@ namespace tensorrt_llm { namespace kernels { -enum class WeightOnlyQuantType +namespace weight_only { - Int4b, - Int8b -}; -enum class WeightOnlyType -{ - PerChannel, - GroupWise -}; - -struct WeightOnlyPerChannel; -template -struct WeightOnlyGroupWise; - -enum class WeightOnlyActivationFunctionType +enum class KernelType { - Gelu, - Relu, - Identity, - InvalidType + FP16Int4Groupwise, + BF16Int4Groupwise, + FP16Int8PerChannel, + BF16Int8PerChannel, + FP16Int4PerChannel, + BF16Int4PerChannel }; -enum class WeightOnlyActivationType -{ - FP16, - BF16, - FP8 -}; +template +struct kernel_type_traits; +#define KERNEL_TYPE_TRAITS_REGISTRY(KT, _isGroupwise, _isInt4) \ + template <> \ + struct kernel_type_traits \ + { \ + static constexpr bool isGroupwise = _isGroupwise; \ + static constexpr bool isInt4 = _isInt4; \ + }; +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int4Groupwise, true, true); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int4Groupwise, true, true); 
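// For example, the groupwise registrations above yield
//   kernel_type_traits<KernelType::FP16Int4Groupwise>::isGroupwise == true
//   kernel_type_traits<KernelType::FP16Int4Groupwise>::isInt4 == true
// while the per-channel registrations below report isGroupwise == false. A minimal compile-time check:
static_assert(kernel_type_traits<KernelType::FP16Int4Groupwise>::isGroupwise
        && kernel_type_traits<KernelType::FP16Int4Groupwise>::isInt4,
    "groupwise int4 kernel types must report groupwise int4 traits");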
+KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8PerChannel, false, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int8PerChannel, false, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int4PerChannel, false, true); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int4PerChannel, false, true); +#undef KERNEL_TYPE_TRAITS_REGISTRY -struct WeightOnlyParams +struct Params { - // ActType is fp16 or bf16 - using ActType = void; - using WeiType = uint8_t; - - uint8_t const* qweight; - ActType const* scales; - ActType const* zeros; - ActType const* in; - ActType const* act_scale; - ActType const* bias; - ActType* out; - int const m; - int const n; - int const k; - int const group_size; - WeightOnlyQuantType quant_type; - WeightOnlyType weight_only_type; - WeightOnlyActivationFunctionType act_func_type; - WeightOnlyActivationType act_type; + using Pointer = void*; + using ConstPointer = void const*; + Pointer act; + Pointer act_scale; + Pointer weight; + Pointer scales; + Pointer zeros; + Pointer bias; + Pointer out; + float alpha; + int m; + int n; + int k; + int groupsize; + KernelType type; + bool apply_alpha_in_advance; - WeightOnlyParams(uint8_t const* _qweight, ActType const* _scales, ActType const* _zeros, ActType const* _in, - ActType const* _act_scale, ActType const* _bias, ActType* _out, int const _m, int const _n, int const _k, - int const _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type, - const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type) - : qweight(_qweight) - , scales(_scales) - , zeros(_zeros) - , in(_in) - , act_scale(_act_scale) - , bias(_bias) + Params(ConstPointer _act, ConstPointer _act_scale, ConstPointer _weight, ConstPointer _scales, ConstPointer _zeros, + ConstPointer _bias, Pointer _out, float _alpha, int _m, int _n, int _k, int _groupsize, KernelType _type, + bool _apply_alpha_in_advance = false) + : act(const_cast(_act)) + , act_scale(const_cast(_act_scale)) + , weight(const_cast(_weight)) + , scales(const_cast(_scales)) + , zeros(const_cast(_zeros)) + , bias(const_cast(_bias)) , out(_out) + , alpha(_alpha) , m(_m) , n(_n) , k(_k) - , group_size(_group_size) - , quant_type(_quant_type) - , weight_only_type(_weight_only_type) - , act_func_type(_act_func_type) - , act_type(_act_type) + , groupsize(_groupsize) + , type(_type) + , apply_alpha_in_advance(_apply_alpha_in_advance) { } }; +} // namespace weight_only } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/converter.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/converter.h new file mode 100644 index 000000000..13a226c93 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/converter.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h" + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +template +struct I2FConverter; + +template +struct I2FConverter +{ + static_assert(std::is_same_v || std::is_same_v); + static_assert(WElemBits == 4 || WElemBits == 8); + using CutlassAType = std::conditional_t, cutlass::half_t, cutlass::bfloat16_t>; + using CutlassWType = std::conditional_t; + static constexpr int kConvertCount = 32 / WElemBits; + using Converter = cutlass::FastInterleavedAndBiasedNumericArrayConverter; + using CvtSrcType = typename Converter::source_type; + using CvtResType = typename Converter::result_type; + + template + __device__ __forceinline__ static void convert(void* src, void* dst) + { + static_assert(N % kConvertCount == 0); +#pragma unroll + for (int ii = 0; ii < N / kConvertCount; ++ii) + { + reinterpret_cast(dst)[ii] = Converter::convert(reinterpret_cast(src)[ii]); + } + } +}; + +template +struct I2FConverter +{ + static_assert(std::is_same_v || std::is_same_v); + static_assert(WElemBits == 4 || WElemBits == 8); + using CutlassAType = std::conditional_t, cutlass::half_t, cutlass::bfloat16_t>; + using CutlassWType = std::conditional_t; + static constexpr int kConvertCount = 32 / WElemBits; + using Converter = cutlass::NumericArrayConverter; + using CvtSrcType = typename Converter::source_type; + using CvtResType = typename Converter::result_type; + + template + __device__ __forceinline__ static void convert(void* src, void* dst) + { + static_assert(N % kConvertCount == 0); +#pragma unroll + for (int ii = 0; ii < N / kConvertCount; ++ii) + { + reinterpret_cast(dst)[ii] = Converter::convert(reinterpret_cast(src)[ii]); + } + } +}; + +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h new file mode 100644 index 000000000..fe807d452 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +struct FP16DetailsA +{ + using Type = half; + using Type2 = half2; + static constexpr int kElemBits = 16; +}; + +struct BF16DetailsA +{ + using Type = __nv_bfloat16; + using Type2 = __nv_bfloat162; + static constexpr int kElemBits = 16; +}; + +struct Int8DetailsW +{ + static constexpr int kElemBits = 8; +}; + +struct Int4DetailsW +{ + static constexpr int kElemBits = 4; +}; + +template +struct ColumnMajor +{ + using DetailsA = TypeDetailsA; + using DetailsW = TypeDetailsW; + using AccessTypeA = float4; + using AccessTypeW = int; + static constexpr int kAccessSize = 128; + static constexpr int kStepK = kAccessSize / TypeDetailsA::kElemBits; + static constexpr int kTileSize = 64; + static constexpr int kInterleave = 1; + + struct Mapper + { + __device__ __forceinline__ int operator()(int i) + { + return i; + } + }; +}; + +template +struct ColumnMajorInterleaved +{ + using DetailsA = TypeDetailsA; + using DetailsW = TypeDetailsW; + using AccessTypeA = float4; + using AccessTypeW = int4; + static constexpr int kAccessSize = 128; + static constexpr int kStepK = kAccessSize / TypeDetailsW::kElemBits; + static constexpr int kTileSize = 64; + static constexpr int kInterleave = 128 * 8 / (kTileSize * TypeDetailsW::kElemBits); + + struct Mapper + { + __device__ __forceinline__ int operator()(int i) + { + return (i % 8) / 2 * kInterleave * 2 + i % 2 + i / 8 * 2; + } + }; +}; + +template class LayoutDeatils_, + bool UseInterleavedConverter> +struct KernelDetails +{ + using TypeDetailsA = TypeDetailsA_; + using TypeDetailsW = TypeDetailsW_; + using LayoutDeatils = LayoutDeatils_; + using AccessTypeA = typename LayoutDeatils::AccessTypeA; + using AccessTypeW = typename LayoutDeatils::AccessTypeW; + static constexpr int kWarpSize = 32; + static constexpr int kStepK = LayoutDeatils::kStepK; + static constexpr int kAccessNumA = kStepK * TypeDetailsA::kElemBits / (sizeof(AccessTypeA) * 8); + static constexpr int kAccessNumW = kStepK * TypeDetailsW::kElemBits / (sizeof(AccessTypeW) * 8); + static constexpr int kInterleave = LayoutDeatils::kInterleave; + static constexpr int kThreadsPerInterleavedTile = LayoutDeatils::kTileSize / kStepK; + static constexpr int kElemsPerByteW = 8 / TypeDetailsW::kElemBits; + static constexpr bool kUseInterleavedConverter = UseInterleavedConverter; +}; +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h deleted file mode 100644 index 1aabeb8ee..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -template -struct SupportedLayout -{ - static constexpr bool value = false; -}; - -template <> -struct SupportedLayout> -{ - static constexpr bool value = true; -}; - -template <> -struct SupportedLayout> -{ - static constexpr bool value = true; -}; - -template <> -struct SupportedLayout -{ - static constexpr bool value = true; -}; - -template -bool isEnabled() -{ - using Layout = typename cutlass::gemm::kernel::LayoutDetailsB::Layout; - return SupportedLayout::value; -} - -template -bool isEnabledForArch(int arch) -{ - if (arch >= 70 && arch < 75) - { - return false; - } - else if (arch >= 75 && arch < 80) - { - return isEnabled(); - } - else if (arch >= 80 && arch < 90) - { - return isEnabled(); - } - else if (arch >= 90) - { - return isEnabled(); - } - else - { - TLLM_CHECK_WITH_INFO(false, "Unsupported Arch"); - return false; - } -} - -inline bool isWeightOnlyBatchedGemvEnabled(WeightOnlyQuantType qtype) -{ - int const arch = tensorrt_llm::common::getSMVersion(); - if (qtype == WeightOnlyQuantType::Int4b) - { - return isEnabledForArch(arch); - } - else if (qtype == WeightOnlyQuantType::Int8b) - { - return isEnabledForArch(arch); - } - else - { - TLLM_CHECK_WITH_INFO(false, "Unsupported WeightOnlyQuantType"); - return false; - } -} -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 72eab1a4e..29dfcbdc1 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -16,539 +16,135 @@ #pragma once #include "tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h" +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/converter.h" +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h" #include "tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h" namespace tensorrt_llm { namespace kernels { -template -struct ActTypeDetails; - -template <> -struct ActTypeDetails -{ - using CutlassType = cutlass::half_t; - using Vec2 = half2; - - __device__ __forceinline__ static Vec2 to_vec2(half v) - { - return __half2half2(v); - } -}; -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) -template <> -struct ActTypeDetails<__nv_bfloat16> -{ - using CutlassType = cutlass::bfloat16_t; - using Vec2 = __nv_bfloat162; - - __device__ __forceinline__ static Vec2 to_vec2(__nv_bfloat16 v) - { - return __bfloat162bfloat162(v); - } -}; -#endif - -template -struct ConverterSelector -{ - static_assert(QType == WeightOnlyQuantType::Int4b || QType == WeightOnlyQuantType::Int8b); - - using WeiType = std::conditional_t; - static constexpr int kConvertCount = QType == WeightOnlyQuantType::Int4b ? 
8 : 4; - using Converter - = cutlass::FastInterleavedAndBiasedNumericArrayConverter::CutlassType, WeiType, - kConvertCount>; -}; - -template -struct WeightOnlyDetails; - -template -struct WeightOnlyDetails -{ - // Every four rows of the original weights are interleaved into a row with stride of 64, so if each thread - // processes 32 elements(for int4, we can use ldg.128 to load weights), then every group of two adjacent threads - // will alternately process four different row weights - // for example - // every 256 consecutive int4 elements [256*i, 256*(i+1)-1] of row N under interleave layout, - // the first 64 are from [64*i, 64*(i+1)-1] of row 4N before interleaving, - // and the second 64 are from [64*i, 64*(i+1)-1] of row 4N+1 before interleaving, and so on. - // So if each thread loads 32 int4 elements, then the elements of each 2 adjacent threads of each 8 - // consecutive threads will come from row 4N ~ 4N+3 respectively before interleaving. - static constexpr int kElemBits = 4; - static constexpr int kInterleave = 4; - static constexpr int kStride = 64; - - // The index remapping here is to counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm - // input 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ... 31 - // weight 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 - static constexpr int kShuffleSize = 32; - static constexpr int kShuffleBasicTile = 2; - static constexpr int kShuffleContinuous = 4; - static constexpr int kShuffleStrided = 4; - - // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the - // corresponding address in shared memory - template - __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave]) - { -#pragma unroll - for (int i = 0; i < Num; ++i) - { - res[i] += __shfl_xor_sync(~0, res[i], 16); - res[i] += __shfl_xor_sync(~0, res[i], 8); - res[i] += __shfl_xor_sync(~0, res[i], 1); - } - __syncthreads(); - int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize; - if (lane == 0 || lane == 2 || lane == 4 || lane == 6) - { -#pragma unroll - for (int i = 0; i < Num; ++i) - { - sm[warp][i * kInterleave + lane / 2] = res[i]; - } - } - __syncthreads(); - } -}; - -template -struct WeightOnlyDetails +namespace weight_only { - // Every two rows of the original weights are interleaved into a row with stride of 64, so if each thread - // processes 16 elements(for int8, we can use ldg.128 to load weights), then every group of four adjacent threads - // will alternately process two different row weights - // for example - // every 128 consecutive int8 elements [128*i, 128*(i+1)-1] of row N under interleave layout, - // the first 64 are from [64*i, 64*(i+1)-1] of row 2N before interleaving, - // and the last 64 are from [64*i, 64*(i+1)-1] of row 2N+1 before interleaving. - // So if each thread loads 16 int8 elements, then the elements of the first four and last four threads of each 8 - // consecutive threads will come from row 2N and row 2N+1 respectively before interleaving. 
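// Concretely, with kInterleave == 2 and kStride == 64: interleaved element 70 of row N lies in the
// second half of its 128-element group (group i == 0), so it maps to element 6 (= 70 - 64) of
// original row 2N + 1, while interleaved element 5 maps to element 5 of original row 2N.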
- static constexpr int kElemBits = 8; - static constexpr int kInterleave = 2; - static constexpr int kStride = 64; - - // The index remapping here is to counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm - // input 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 - // weight 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 - static constexpr int kShuffleSize = 16; - static constexpr int kShuffleBasicTile = 2; - static constexpr int kShuffleContinuous = 2; - static constexpr int kShuffleStrided = 4; - - // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the - // corresponding address in shared memory - template - __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave]) - { -#pragma unroll - for (int i = 0; i < Num; ++i) - { - res[i] += __shfl_xor_sync(~0, res[i], 16); - res[i] += __shfl_xor_sync(~0, res[i], 8); - res[i] += __shfl_xor_sync(~0, res[i], 2); - res[i] += __shfl_xor_sync(~0, res[i], 1); - } - __syncthreads(); - int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize; - if (lane == 0 || lane == 4) - { -#pragma unroll - for (int i = 0; i < Num; ++i) - { - sm[warp][i * kInterleave + lane / 4] = res[i]; - } - } - __syncthreads(); - } -}; - -template -struct WeightOnlyKernelDetails +template +__global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* scales, TypeA* zeros, TypeA* bias, + TypeA* out, float alpha, int m, int n, int k) { - using Layout = WeightOnlyDetails; - - static constexpr int kElemBits = Layout::kElemBits; - static constexpr int kInterleave = Layout::kInterleave; - static constexpr int kStride = Layout::kStride; - - static constexpr int kShuffleSize = Layout::kShuffleSize; - static constexpr int kShuffleBasicTile = Layout::kShuffleBasicTile; - static constexpr int kShuffleContinuous = Layout::kShuffleContinuous; - static constexpr int kShuffleStrided = Layout::kShuffleStrided; - - // The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int4/8s_inplace - // Input int8 data layout - // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) - // - // Converted fp16/bf16 data layout - // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits) - - // Input int8 data layout - // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) + // clang-format off + // ArgType ArgName DataType Shape Layout // - // Converted fp16/bf16 data layout - // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits) - static constexpr int kConvertCount = ConverterSelector::kConvertCount; - using Converter = typename ConverterSelector::Converter; - - // Use ldg128 load data from global memory - static constexpr int kAccessSize = 128; - using AccessType = uint4; - - static constexpr int kElemsPerByte = 8 / kElemBits; - static constexpr int kElemsPerThread = kAccessSize / kElemBits; - static constexpr int kBytePerThread = kElemsPerThread / kElemsPerByte; - static constexpr int kThreadsNumPerTile = kStride / kElemsPerThread; - static constexpr int kThreadsNumPerInterleave = kThreadsNumPerTile * kInterleave; - - static constexpr int kConvertIters = kElemsPerThread / kConvertCount; - - // Each thread loads 16(int8b)/32(int4b) quantized weight elements each time through ldg128 - // So more times of ldg128 are needed to load the same number of fp16/bf16 activation elements. 
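// For example, with fp16/bf16 activations one ldg.128 covers 128 / 16 == 8 activation elements, so
// covering the 16 (int8) or 32 (int4) weight elements handled per thread takes 2 or 4 activation
// loads; this is exactly what kActivationElemNumPerAccess and kActivationAccessNum below evaluate to.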
- static constexpr int kActivationElemNumPerAccess = kAccessSize / (sizeof(ActType) * 8); - static constexpr int kActivationAccessNum = kElemsPerThread / kActivationElemNumPerAccess; -}; - -template -struct WeightOnlyProperties; - -template <> -struct WeightOnlyProperties -{ - static constexpr bool kIsFineGrained = false; - static constexpr int kGroupSize = 0; -}; - -template -struct WeightOnlyProperties> -{ - static constexpr bool kIsFineGrained = true; - static constexpr int kGroupSize = GS; -}; - -template -struct WeightOnlyScaleLoader -{ - using ElemType = ActType; - using Details = WeightOnlyKernelDetails; - static constexpr bool kIsFineGrained = WeightOnlyProperties::kIsFineGrained; - static constexpr int kGroupSize = WeightOnlyProperties::kGroupSize; - -private: - ElemType const* _scales; - ElemType const* _zeros; - int _stride; - int _offset; - -public: - __device__ __forceinline__ WeightOnlyScaleLoader( - ElemType const* scales, ElemType const* zeros, int initial_offset, int stride) - : _scales(scales) - , _zeros(zeros) - , _stride(stride) - { - _scales += initial_offset; - if constexpr (Zero) - { - _zeros += initial_offset; - } - // Calculate the k dimension index of the element processed by the current thread of layout before interleave - // Used to load scales and zeros in groupwise weight only quant - _offset = threadIdx.x / Details::kThreadsNumPerInterleave * Details::kStride - + (threadIdx.x % Details::kThreadsNumPerTile) * Details::kElemsPerThread; - } - - __device__ __forceinline__ void load(ElemType& scale, ElemType& zero, int nid) - { - int offset = nid * Details::kInterleave; - if constexpr (kIsFineGrained) - { - offset += _offset / kGroupSize * _stride; - } - scale = _scales[offset]; - if constexpr (Zero) - { - zero = _zeros[offset]; - } - else - { - zero = static_cast(0.f); - } - } - - __device__ __forceinline__ void advance() + // input act fp16/bf16 [m, k] RowMajor + // input act_scale fp16/bf16 [1, k] RowMajor + // input weight int4b/int8b [k, n] ColumnMajor or ColumnMajorInterleaved + // input scales fp16/bf16 [k / GroupSize, n] or [1, n] RowMajor + // input zeros fp16/bf16 [k / GroupSize, n] or [1, n] RowMajor + // input bias fp16/bf16 [1, n] RowMajor + // output out fp16/bf16 [m, n] RowMajor + // clang-format on + using AccessTypeA = typename Details::AccessTypeA; + using AccessTypeW = typename Details::AccessTypeW; + + static constexpr bool Mandatory = true; + static constexpr int StepK = Details::kStepK; + static constexpr int CtaK = StepK * Threads; + static_assert(CtaN % 2 == 0); + if constexpr (GroupSize != 0) { - _offset += BlockSize * Details::kElemsPerThread / Details::kInterleave; + static_assert((CtaK / Details::kInterleave) % GroupSize == 0); } - __device__ __forceinline__ int offset() + int const origin_k = k, interleaved_k = k * Details::kInterleave; + + int const tile_id_m = blockIdx.x, tile_id_n = blockIdx.y, tid = threadIdx.x; + int const offset_m = tile_id_m * CtaM, interleaved_offset_n = tile_id_n * CtaN; + int const real_offset_n = interleaved_offset_n * Details::kInterleave + + ((tid * StepK / Details::LayoutDeatils::kTileSize) % Details::kInterleave); + int const real_offset_k + = (tid * StepK / (Details::kInterleave * Details::LayoutDeatils::kTileSize)) * Details::LayoutDeatils::kTileSize + + ((tid * StepK) % Details::LayoutDeatils::kTileSize); + + GMemIterator act_iterator( + act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); + GMemIterator act_scale_iterator( + act_scale, real_offset_k, CtaK / 
Details::kInterleave, 0); + GMemIterator weight_iterator(weight, + (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, + interleaved_k / Details::kElemsPerByteW); + GMemIterator scales_iterator(scales, + (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + GMemIterator zeros_iterator(zeros, + (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + + out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; + if constexpr (EnableBias) { - return _offset; + bias += tile_id_n * CtaN * Details::kInterleave; } -}; - -template class ActOp, - bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize> -__device__ void weight_only_batched_gemv(uint8_t const* qweight, ActType const* scales, ActType const* zeros, - ActType const* in, ActType const* act_scale, ActType const* bias, ActType* out, int const n, int const k) -{ - static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0)); - using ActType2 = typename ActTypeDetails::Vec2; - using Details = WeightOnlyKernelDetails; - - using Converter = typename Details::Converter; - using AccType = typename Details::AccessType; - using CvtSrcType = typename Converter::source_type; - using CvtResType = typename Converter::result_type; - using ScaleLoader = WeightOnlyScaleLoader; - extern __shared__ uint8_t shmem[]; - constexpr int Interleave = Details::kInterleave; - constexpr int WarpSize = 32; - constexpr int Num = Batch * NPerBlock; - int const tid = threadIdx.x; - int const bid = blockIdx.x; - int const n_start_id = bid * NPerBlock * Interleave; - // Calculate the n-dimensional index of the data processed by the current thread in the interleave tile - int const interleave_n_id = (tid / Details::kThreadsNumPerTile) % Interleave; - qweight += n_start_id * k / Details::kElemsPerByte; - ScaleLoader scale_loader(scales, zeros, n_start_id + interleave_n_id, n); + TypeA tile_acc[CtaM * CtaN]; + fill(tile_acc, static_cast(0.f)); - float(*sm)[Num * Interleave] = reinterpret_cast(shmem); - - // In order to take advantage of hfma2, we use fp16/bf16 for accumulation within threads and fp32 for accumulation - // between threads. 
- ActType accumulator[Num]; - for (int i = 0; i < Num; ++i) - { - accumulator[i] = static_cast(0.f); - } - - // Iteration in k dimensions - for (int local_k = tid * Details::kElemsPerThread; local_k < k * Interleave; - local_k += BlockSize * Details::kElemsPerThread) + for (int idx_k = tid * StepK, iter = 0; idx_k < interleaved_k; idx_k += CtaK, ++iter) { - ActType weights_f16[Details::kElemsPerThread * NPerBlock]; - ActType scale[NPerBlock], zero[NPerBlock]; -#pragma unroll - for (int idx = 0; idx < NPerBlock; ++idx) - { - // Load quantized weight and scales/zeros - uint8_t weights_quantized[Details::kBytePerThread]; - load(weights_quantized, - qweight + idx * Interleave * k / Details::kElemsPerByte + local_k / Details::kElemsPerByte); - scale_loader.load(scale[idx], zero[idx], idx); - ActType weights_vec[Details::kElemsPerThread]; + TypeA vec_act_scale[StepK]; + TypeA vec_scale[CtaN], vec_zero[CtaN]; + TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK]; + uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; #pragma unroll - for (int i = 0; i < Details::kConvertIters; ++i) - { - // Use cutlass::FastInterleavedAndBiasedNumericArrayConverter for I2F type conversion - assign(weights_vec + i * Details::kConvertCount, - Converter::convert(*reinterpret_cast( - weights_quantized + i * Details::kConvertCount / Details::kElemsPerByte))); - } -#pragma unroll - for (int i = 0; i < Details::kShuffleContinuous; ++i) - { -#pragma unroll - for (int j = 0; j < Details::kShuffleStrided; ++j) - { - // Dequantize the weights and arrange the shuffled elements back to the correct order in the - // register array - ActType2 v = *reinterpret_cast(weights_vec + i * Details::kShuffleBasicTile - + j * Details::kShuffleContinuous * Details::kShuffleBasicTile); - v = __hfma2( - v, ActTypeDetails::to_vec2(scale[idx]), ActTypeDetails::to_vec2(zero[idx])); - weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile - + j * Details::kShuffleBasicTile + 0) - * NPerBlock - + idx] - = v.x; - weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile - + j * Details::kShuffleBasicTile + 1) - * NPerBlock - + idx] - = v.y; - } - } - } - ActType act_scale_v[Details::kElemsPerThread]; - if constexpr (ActScale) + for (int i = 0; i < CtaN; ++i) { -#pragma unroll - for (int idx = 0; idx < Details::kActivationAccessNum; ++idx) - { - load(act_scale_v + idx * Details::kActivationElemNumPerAccess, - act_scale + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess); - } + scales_iterator.load(vec_scale + i, iter, i); + zeros_iterator.load(vec_zero + i, iter, i); } + act_scale_iterator.load(vec_act_scale, iter); #pragma unroll - for (int b = 0; b < Batch; ++b) + for (int i = 0; i < CtaN; ++i) { - ActType in_v[Details::kElemsPerThread]; -#pragma unroll - for (int idx = 0; idx < Details::kActivationAccessNum; ++idx) - { - // load activation elements - load(in_v + idx * Details::kActivationElemNumPerAccess, - in + b * k + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess); - if constexpr (ActScale) - { -#pragma unroll - for (int i = 0; i < Details::kActivationElemNumPerAccess; i += 2) - { - *reinterpret_cast(in_v + idx * Details::kActivationElemNumPerAccess + i) = __hmul2( - *reinterpret_cast(in_v + idx * Details::kActivationElemNumPerAccess + i), - *reinterpret_cast(act_scale_v + idx * Details::kActivationElemNumPerAccess + i)); - } - } - } - // Perform vector inner product and accumulate - if constexpr (NPerBlock == 1) - { - ActType2 v = 
ActTypeDetails::to_vec2(static_cast(0.f)); -#pragma unroll - for (int y = 0; y < Details::kElemsPerThread; y += 2) - { - v = __hfma2( - *reinterpret_cast(weights_f16 + y), *reinterpret_cast(in_v + y), v); - } - accumulator[b] += __hadd(v.x, v.y); - } - else - { -#pragma unroll - for (int x = 0; x < NPerBlock / 2; ++x) - { -#pragma unroll - for (int y = 0; y < Details::kElemsPerThread; ++y) - { - *reinterpret_cast(accumulator + b * NPerBlock + x * 2) - = __hfma2(*reinterpret_cast(weights_f16 + y * NPerBlock + x * 2), - ActTypeDetails::to_vec2(in_v[y]), - *reinterpret_cast(accumulator + b * NPerBlock + x * 2)); - } - } - } + weight_iterator.load(tile_w_quantized, iter, i); + dequantize( + tile_w, tile_w_quantized, vec_scale + i, vec_zero + i, alpha); + pack_to_vec2(tile_w_pack2, tile_w, i); } - scale_loader.advance(); - } - float reses[Num]; #pragma unroll - for (int i = 0; i < Num; ++i) - { - reses[i] = static_cast(accumulator[i]); - } - - // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the - // corresponding address in shared memory - Details::Layout::sync(reses, sm); - - // Each thread is responsible for the accumulation and store to global memory of one element - for (int i = tid; i < Num * Interleave; i += BlockSize) - { - int nid = i % (NPerBlock * Interleave); - float v = 0.f; - for (int j = 0; j < BlockSize / WarpSize; ++j) - { - v += sm[j][i]; - } - float bias_v = 0.f; - if constexpr (Bias) + for (int i = 0; i < CtaM; ++i) { - bias_v = static_cast(bias[n_start_id + nid]); + act_iterator.load(tile_a, iter, i); + apply_scale(tile_a, vec_act_scale); + mma(tile_acc + i * CtaN, tile_w_pack2, tile_a); } - int b = i / NPerBlock / Interleave; - out[b * n + n_start_id + nid] = static_cast(ActOp::apply(v + bias_v)); } + epilogue(out, n, tile_acc, bias, alpha); } -template class ActOp, - bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize> -__global__ void weight_only_batched_gemv_wrapper(uint8_t const* qweight, ActType const* scales, ActType const* zeros, - ActType const* in, ActType const* act_scale, ActType const* bias, ActType* out, int const n, int const k) +template +void exec_kernel(Params& params, cudaStream_t s) { - if constexpr (std::is_same_v) + using T = typename Details::TypeDetailsA::Type; + if (params.m % CtaM || params.n % (CtaN * Details::kInterleave)) { - weight_only_batched_gemv(qweight, scales, zeros, in, act_scale, bias, out, n, k); + throw std::runtime_error("launch failed"); } -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - else if (std::is_same_v) - { - weight_only_batched_gemv(qweight, scales, zeros, in, act_scale, bias, out, n, k); - } -#endif + dim3 grid(params.m / CtaM, params.n / (CtaN * Details::kInterleave)); + dim3 block(Threads); + // clang-format off + kernel<<>>( + reinterpret_cast(params.act), + reinterpret_cast(params.act_scale), + reinterpret_cast(params.weight), + reinterpret_cast(params.scales), + reinterpret_cast(params.zeros), + reinterpret_cast(params.bias), + reinterpret_cast(params.out), + params.alpha, + params.m, params.n, params.k + ); + // clang-format on } -template class ActOp, bool Zero, bool Bias, - int NPerBlock, int Batch, int BlockSize> -struct WeightOnlyBatchedGemvKernelLauncher -{ - static void run(WeightOnlyParams const& params, cudaStream_t stream) - { - if (params.act_type == WeightOnlyActivationType::FP16) - { - constexpr int kInterleave = WeightOnlyDetails::kInterleave; - dim3 grid(params.n / NPerBlock / kInterleave); - 
dim3 block(BlockSize); - int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave; - if (params.act_scale != nullptr) - { - weight_only_batched_gemv_wrapper<<>>(params.qweight, - reinterpret_cast(params.scales), reinterpret_cast(params.zeros), - reinterpret_cast(params.in), reinterpret_cast(params.act_scale), - reinterpret_cast(params.bias), reinterpret_cast(params.out), params.n, - params.k); - } - else - { - weight_only_batched_gemv_wrapper<<>>(params.qweight, - reinterpret_cast(params.scales), reinterpret_cast(params.zeros), - reinterpret_cast(params.in), reinterpret_cast(params.act_scale), - reinterpret_cast(params.bias), reinterpret_cast(params.out), params.n, - params.k); - } - } -#if defined(ENABLE_BF16) - else if (params.act_type == WeightOnlyActivationType::BF16) - { - constexpr int kInterleave = WeightOnlyDetails::kInterleave; - dim3 grid(params.n / NPerBlock / kInterleave); - dim3 block(BlockSize); - int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave; - if (params.act_scale != nullptr) - { - weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, true, - NPerBlock, Batch, BlockSize><<>>(params.qweight, - reinterpret_cast<__nv_bfloat16 const*>(params.scales), - reinterpret_cast<__nv_bfloat16 const*>(params.zeros), - reinterpret_cast<__nv_bfloat16 const*>(params.in), - reinterpret_cast<__nv_bfloat16 const*>(params.act_scale), - reinterpret_cast<__nv_bfloat16 const*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out), - params.n, params.k); - } - else - { - weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, false, - NPerBlock, Batch, BlockSize><<>>(params.qweight, - reinterpret_cast<__nv_bfloat16 const*>(params.scales), - reinterpret_cast<__nv_bfloat16 const*>(params.zeros), - reinterpret_cast<__nv_bfloat16 const*>(params.in), - reinterpret_cast<__nv_bfloat16 const*>(params.act_scale), - reinterpret_cast<__nv_bfloat16 const*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out), - params.n, params.k); - } - } -#endif - } -}; +} // namespace weight_only } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h new file mode 100644 index 000000000..7c3a5fffe --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h" +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +// TODO: +// Using a mechanism similar to the gemm config profiler, dynamically search for the optimal configuration during the +// build engine process. 
+template +void dispatcher(Params& params, cudaStream_t s) +{ +#define DISPATCHER_FOR_M(target_m, CtaM, CtaN, Threads) \ + do \ + { \ + if (params.m == target_m) \ + { \ + exec_kernel(params, s); \ + return; \ + } \ + } while (0); + if constexpr (EnableZero) + { + // clang-format off + DISPATCHER_FOR_M(1, 1, 4, 128); + DISPATCHER_FOR_M(2, 2, 4, 128); + DISPATCHER_FOR_M(3, 3, 4, 128); + DISPATCHER_FOR_M(4, 4, 4, 128); + // clang-format on + } + else + { + // clang-format off + DISPATCHER_FOR_M(1, 1, 8, 128); + DISPATCHER_FOR_M(2, 2, 8, 128); + DISPATCHER_FOR_M(3, 3, 8, 128); + DISPATCHER_FOR_M(4, 4, 8, 128); + // clang-format on + } + throw std::runtime_error("unsupported m"); +#undef DISPATCHER_FOR_M +} + +template +void check_alpha(Params& params, cudaStream_t s) +{ + if (params.apply_alpha_in_advance && params.alpha != 1.f) + { + dispatcher(params, s); + } + else + { + dispatcher(params, s); + } +} + +template +void check_pointer(Params& params, cudaStream_t s) +{ + if constexpr (GroupSize == 0) + { + check_alpha(params, s); + } + else + { + if (params.act_scale && params.zeros && params.bias) + { + check_alpha(params, s); + } + else if (params.act_scale && params.zeros && !params.bias) + { + check_alpha(params, s); + } + else if (params.act_scale && !params.zeros && params.bias) + { + check_alpha(params, s); + } + else if (!params.act_scale && params.zeros && params.bias) + { + check_alpha(params, s); + } + else if (!params.act_scale && !params.zeros && params.bias) + { + check_alpha(params, s); + } + else if (params.act_scale && !params.zeros && !params.bias) + { + check_alpha(params, s); + } + else if (!params.act_scale && params.zeros && !params.bias) + { + check_alpha(params, s); + } + else + { + check_alpha(params, s); + } + } +} + +template +void select_gs(Params& params, cudaStream_t s) +{ + if constexpr (isGroupwise) + { + if (params.groupsize == 64) + { + check_pointer(params, s); + } + else if (params.groupsize == 128) + { + check_pointer(params, s); + } + } + else + { + if (params.groupsize == 0) + { + check_pointer(params, s); + } + } +} + +#define INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(KType, A, B, Layout, ConverterInterleave) \ + template void select_gs::isGroupwise, KernelDetails>( \ + Params & params, cudaStream_t s); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4GroupwiseColumnMajorFalse.cu similarity index 70% rename from cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.h rename to cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4GroupwiseColumnMajorFalse.cu index 44bf7c5fa..8b42b4e60 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4GroupwiseColumnMajorFalse.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,7 @@ * limitations under the License. 
*/ -#pragma once -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/common.h" +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" namespace tensorrt_llm { @@ -23,8 +22,7 @@ namespace kernels { namespace weight_only { - -void kernel_launcher(Params& params, cudaStream_t s); -} +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false); +} // namespace weight_only } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4GroupwiseColumnMajorInterleavedTrue.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4GroupwiseColumnMajorInterleavedTrue.cu new file mode 100644 index 000000000..285f7bdf3 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4GroupwiseColumnMajorInterleavedTrue.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4PerChannelColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4PerChannelColumnMajorFalse.cu new file mode 100644 index 000000000..50baec563 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4PerChannelColumnMajorFalse.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajor, false); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4PerChannelColumnMajorInterleavedTrue.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4PerChannelColumnMajorInterleavedTrue.cu new file mode 100644 index 000000000..415241edf --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int4PerChannelColumnMajorInterleavedTrue.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8PerChannelColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8PerChannelColumnMajorFalse.cu new file mode 100644 index 000000000..ee0d4fce1 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8PerChannelColumnMajorFalse.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int8PerChannel, BF16DetailsA, Int8DetailsW, ColumnMajor, false); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8PerChannelColumnMajorInterleavedTrue.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8PerChannelColumnMajorInterleavedTrue.cu new file mode 100644 index 000000000..0f3e1b6a4 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8PerChannelColumnMajorInterleavedTrue.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int8PerChannel, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4GroupwiseColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4GroupwiseColumnMajorFalse.cu new file mode 100644 index 000000000..3360fea3c --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4GroupwiseColumnMajorFalse.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4GroupwiseColumnMajorInterleavedTrue.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4GroupwiseColumnMajorInterleavedTrue.cu new file mode 100644 index 000000000..282eb8a9c --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4GroupwiseColumnMajorInterleavedTrue.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorFalse.cu new file mode 100644 index 000000000..83a4f9dda --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorFalse.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajor, false); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorInterleavedTrue.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorInterleavedTrue.cu new file mode 100644 index 000000000..0dfce78e9 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorInterleavedTrue.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorTrue.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorTrue.cu new file mode 100644 index 000000000..48ddf72d1 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int4PerChannelColumnMajorTrue.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajor, true); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorFalse.cu new file mode 100644 index 000000000..b9c964858 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorFalse.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, false); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorInterleavedTrue.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorInterleavedTrue.cu new file mode 100644 index 000000000..3231efd99 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorInterleavedTrue.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorTrue.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorTrue.cu new file mode 100644 index 000000000..030878565 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8PerChannelColumnMajorTrue.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace weight_only +{ +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, true); +} // namespace weight_only +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.cu deleted file mode 100644 index 75ef3f158..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.cu +++ /dev/null @@ -1,234 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h" -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -template class ActOp, bool Zero, bool Bias, - int N_PER_BLOCK, int BATCH, int BLOCK_SIZE> -struct WeightOnlyBatchedGemvKernelLauncher -{ - static void run(WeightOnlyParams const& params, cudaStream_t stream); -}; - -template class ActOp, int N_PER_BLOCK, - int BATCH, int BLOCK_SIZE> -void select_zero_bias(WeightOnlyParams const& params, cudaStream_t stream) -{ - if (params.zeros && params.bias) - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - } - else if (params.zeros && !params.bias) - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - } - else if (!params.zeros && params.bias) - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - } - else - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - } -} - -template -void select_activation(WeightOnlyParams const& params, cudaStream_t stream) -{ - switch (params.act_func_type) - { - // Currently, activation function is not called in the plugin -#if 0 - case WeightOnlyActivationFunctionType::Gelu: - { - select_zero_bias(params, stream); - break; - } - case WeightOnlyActivationFunctionType::Relu: - { - select_zero_bias(params, stream); - break; - } -#endif - case WeightOnlyActivationFunctionType::Identity: - { - select_zero_bias(params, stream); - break; - } - default: - { - throw std::runtime_error("Use unsupported activation"); - break; - } - } -} - -template -void select_quant_type(WeightOnlyParams const& params, cudaStream_t stream) -{ - if (params.quant_type == WeightOnlyQuantType::Int4b) - { - select_activation(params, stream); - } - else if (params.quant_type == WeightOnlyQuantType::Int8b) - { - select_activation(params, stream); - } - else - { - throw std::runtime_error("Unknown QuantType"); - } -} - -template -void select_groupwise_weight_only(WeightOnlyParams const& params, cudaStream_t stream) -{ - if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 64) - { - select_quant_type, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); - } - else if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 128) - { - select_quant_type, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); - } - else - { - throw std::runtime_error("Only support groupwise weight only for gs=64/128"); - } -} - -void weight_only_batched_gemv_launcher(WeightOnlyParams const& params, cudaStream_t stream) -{ - assert(params.act_func_type == WeightOnlyActivationFunctionType::Identity); - assert(params.weight_only_type == WeightOnlyType::GroupWise - || (params.weight_only_type == WeightOnlyType::PerChannel && params.bias == nullptr - && params.zeros == nullptr)); - if (params.weight_only_type == WeightOnlyType::PerChannel) - { - if (params.quant_type == WeightOnlyQuantType::Int4b) - { - switch (params.m) - { - case 1: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 2: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 3: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 4: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - default: - { - throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); - break; - } - } - } - else if (params.quant_type == WeightOnlyQuantType::Int8b) - { - switch (params.m) - { - case 1: - { - 
WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 2: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 3: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 4: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - default: - { - throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); - break; - } - } - } - } - else if (params.weight_only_type == WeightOnlyType::GroupWise) - { - switch (params.m) - { - case 1: - { - select_groupwise_weight_only<2, 1, 256>(params, stream); - break; - } - case 2: - { - select_groupwise_weight_only<2, 2, 256>(params, stream); - break; - } - case 3: - { - select_groupwise_weight_only<2, 3, 128>(params, stream); - break; - } - case 4: - { - select_groupwise_weight_only<2, 4, 128>(params, stream); - break; - } - default: - { - throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); - break; - } - } - } -} -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h index 65498c612..75e2e3a89 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h @@ -13,13 +13,99 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #pragma once +#include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h" +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h" namespace tensorrt_llm { namespace kernels { -void weight_only_batched_gemv_launcher(WeightOnlyParams const& params, cudaStream_t stream); +namespace weight_only +{ +template +void select_gs(Params& params, cudaStream_t s); + +inline void kernel_launcher(int arch, Params& params, cudaStream_t s) +{ +#define EXEC(KType, A, B, Layout, ConverterInterleave) \ + if (params.type == KType) \ + { \ + select_gs::isGroupwise, KernelDetails>( \ + params, s); \ + return; \ + } + if (arch >= 70 && arch < 75) + { + EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, true); + EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajor, true); + } + else if (arch >= 75 && arch < 80) + { + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + } + else if (arch >= 80 && arch < 90) + { + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::BF16Int8PerChannel, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + } + else if (arch >= 90) + { + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false); + EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false); + 
EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, false); + EXEC(KernelType::BF16Int8PerChannel, BF16DetailsA, Int8DetailsW, ColumnMajor, false); + EXEC(KernelType::FP16Int4PerChannel, FP16DetailsA, Int4DetailsW, ColumnMajor, false); + EXEC(KernelType::BF16Int4PerChannel, BF16DetailsA, Int4DetailsW, ColumnMajor, false); + } +#undef EXEC +} + +inline bool is_supported(int arch, KernelType kernel_type) +{ +#define SUPPORT(Type) \ + if (kernel_type == Type) \ + return true; + if (arch >= 70 && arch < 75) + { + SUPPORT(KernelType::FP16Int8PerChannel); + SUPPORT(KernelType::FP16Int4PerChannel); + } + else if (arch >= 75 && arch < 80) + { + SUPPORT(KernelType::FP16Int4Groupwise); + SUPPORT(KernelType::FP16Int8PerChannel); + SUPPORT(KernelType::FP16Int4PerChannel); + } + else if (arch >= 80 && arch < 90) + { + SUPPORT(KernelType::FP16Int4Groupwise); + SUPPORT(KernelType::BF16Int4Groupwise); + SUPPORT(KernelType::FP16Int8PerChannel); + SUPPORT(KernelType::BF16Int8PerChannel); + SUPPORT(KernelType::FP16Int4PerChannel); + SUPPORT(KernelType::BF16Int4PerChannel); + } + else if (arch >= 90) + { + SUPPORT(KernelType::FP16Int4Groupwise); + SUPPORT(KernelType::BF16Int4Groupwise); + SUPPORT(KernelType::FP16Int8PerChannel); + SUPPORT(KernelType::BF16Int8PerChannel); + SUPPORT(KernelType::FP16Int4PerChannel); + SUPPORT(KernelType::BF16Int4PerChannel); + } + return false; +#undef SUPPORT } +} // namespace weight_only +} // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/common.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/common.h deleted file mode 100644 index b5fbb6da7..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/common.h +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace weight_only -{ -enum class KernelType -{ - W4A16, - W4A8 -}; - -struct Params -{ - void* act; - void* act_scale; - void* weight; - void* scales; - void* zeros; - void* bias; - void* out; - float alpha; - int m; - int n; - int k; - int groupsize; - KernelType type; - - Params(void* _act, void* _act_scale, void* _weight, void* _scales, void* _zeros, void* _bias, void* _out, - float _alpha, int _m, int _n, int _k, int _groupsize, KernelType _type) - : act(_act) - , act_scale(_act_scale) - , weight(_weight) - , scales(_scales) - , zeros(_zeros) - , bias(_bias) - , out(_out) - , alpha(_alpha) - , m(_m) - , n(_n) - , k(_k) - , groupsize(_groupsize) - , type(_type) - { - } -}; -} // namespace weight_only -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernel.h deleted file mode 100644 index 560a7cfc1..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernel.h +++ /dev/null @@ -1,391 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/common.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace weight_only -{ - -struct ConverterI4ToF16 -{ - __device__ __forceinline__ static void convert(uint32_t& src, uint4& dst) - { - uint32_t* r = reinterpret_cast(&dst); - uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; -#pragma unroll - for (int ii = 0; ii < 4; ++ii) - { - asm volatile( - "{\n" - " prmt.b32 %0, %1, %2, %3;\n" - "}\n" - : "=r"(r[ii]) - : "r"(src), "n"(0), "r"(prmt_indices[ii])); - } - static constexpr uint32_t xor_mask = 0x64806408; - static constexpr uint32_t and_mask = 0xFFF0FF0F; - static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; -#pragma unroll - for (int ii = 0; ii < 4; ++ii) - { - asm volatile( - "{\n" - " lop3.b32 %0, %0, %1, %2, %3;\n" - "}\n" - : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); - } - static constexpr uint32_t hfma_bias_rep = 0xD480E408; - static constexpr uint32_t hfma_scale_rep = 0x2C003C00; - - half2 const& hfma_bias = reinterpret_cast(hfma_bias_rep); - half2 const& hfma_scale = reinterpret_cast(hfma_scale_rep); -#pragma unroll - for (int ii = 0; ii < 4; ++ii) - { - __half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); - fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); - } - } - - template - __device__ __forceinline__ static void convert(void* src, void* dst) - { - static_assert(N == 8 || N == 16); - convert(reinterpret_cast(src)[0], reinterpret_cast(dst)[0]); - if constexpr (N == 16) - { - convert(reinterpret_cast(src)[1], reinterpret_cast(dst)[1]); - } - } -}; - -template -__device__ __forceinline__ void load(void* dst, TSrc* src, int stride) -{ - if constexpr (Enable) - { -#pragma unroll - for (int ii = 0; ii < N; ++ii) - { - reinterpret_cast(dst)[ii] = reinterpret_cast(src + ii * stride)[0]; - } - } -} - -template -__device__ __forceinline__ void apply_scale(void* act, void* act_scale) -{ - static_assert(K % 2 == 0); - static constexpr int VecK = K / 2; - if constexpr (Enable) - { - half2* pa = reinterpret_cast(act); - half2* pb = reinterpret_cast(act_scale); -#pragma unroll - for (int m = 0; m < M; ++m) - { -#pragma unroll - for (int k = 0; k < VecK; ++k) - { - pa[m * VecK + k] = __hmul2(pa[m * VecK + k], pb[k]); - } - } - } -} - -template -__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros, half alpha) -{ - using Converter = ConverterI4ToF16; - static_assert(K % 2 == 0); - static constexpr int VecK = K / 2; -#pragma unroll - for (int n = 0; n < N; ++n) - { - ConverterI4ToF16::convert( - reinterpret_cast(quantized_w) + n * K / 2, reinterpret_cast(w) + n * K); - half2 vec_scale = __half2half2(reinterpret_cast(scales)[n] * alpha); - half2 vec_zero = __half2half2(__float2half_rn(0.f)); - if constexpr (EnableZero) - { - vec_zero = __half2half2(reinterpret_cast(zeros)[n] * alpha); - } -#pragma unroll - for (int k = 0; k < VecK; ++k) - { - reinterpret_cast(w)[n * VecK + k] - = __hfma2(reinterpret_cast(w)[n * VecK + k], vec_scale, vec_zero); - } - } -} - -template -__device__ __forceinline__ void pack_to_vec2(void* dst, void* src) -{ -#pragma unroll - for (int n = 0; n < N; n += 2) - { -#pragma unroll - for (int k = 0; k < K; ++k) - { - reinterpret_cast(dst)[n * K + k * 2] = reinterpret_cast(src)[n * K + k]; - reinterpret_cast(dst)[n * K + k * 2 + 1] = reinterpret_cast(src)[(n + 1) * K + k]; - } - } -} - -template -__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act) -{ - static_assert(N % 2 == 
0); - static constexpr int VecN = N / 2; -#pragma unroll - for (int m = 0; m < M; ++m) - { -#pragma unroll - for (int n = 0; n < VecN; ++n) - { -#pragma unroll - for (int k = 0; k < K; ++k) - { - reinterpret_cast(acc)[m * VecN + n] = __hfma2(reinterpret_cast(w_pack2)[n * K + k], - __half2half2(reinterpret_cast(act)[m * K + k]), reinterpret_cast(acc)[m * VecN + n]); - } - } - } -} - -template -__device__ __forceinline__ T warp_reduce_sum(T& val) -{ - val += __shfl_xor_sync(~0, val, 16); - val += __shfl_xor_sync(~0, val, 8); - val += __shfl_xor_sync(~0, val, 4); - val += __shfl_xor_sync(~0, val, 2); - val += __shfl_xor_sync(~0, val, 1); - return val; -} - -template -__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias) -{ - static constexpr int WarpSize = 32; - static constexpr int WarpNum = Threads / WarpSize; - static constexpr int AlignShmemSize = (CtaM * CtaN + 31) / 32 * 32; - static_assert(Threads % WarpSize == 0); - __shared__ float shmem[AlignShmemSize * WarpNum]; - int tid = threadIdx.x; - int warp_id = tid / WarpSize, lane_id = tid % WarpSize; -#pragma unroll - for (int m = 0; m < CtaM; ++m) - { -#pragma unroll - for (int n = 0; n < CtaN; ++n) - { - float v = __half2float(reinterpret_cast(tile_acc)[m * CtaN + n]); - v = warp_reduce_sum(v); - if (lane_id == 0) - { - shmem[warp_id * AlignShmemSize + m * CtaN + n] = v; - } - } - } - __syncthreads(); -#pragma unroll - for (int ii = tid; ii < CtaM * CtaN; ii += Threads) - { - int m = ii / CtaN, n = ii % CtaN; - float val = 0.f, v_bias = 0.f; - if constexpr (EnableBias) - { - v_bias = static_cast(reinterpret_cast(bias)[n]); - } -#pragma unroll - for (int jj = 0; jj < WarpNum; ++jj) - { - val += shmem[jj * AlignShmemSize + ii]; - } - reinterpret_cast(out)[m * stride + n] = __float2half_rn(val + v_bias); - } -} - -template -__device__ __forceinline__ void fill(void* tile, half v) -{ -#pragma unroll - for (int ii = 0; ii < N; ++ii) - { - reinterpret_cast(tile)[ii] = v; - } -} - -struct Fp16Details -{ - using ActDataType = half; - static constexpr int StepK = 8; - using AccessTypeAct = float4; - using AccessTypeActScale = float4; - using AccessTypeW = float; - - template - __device__ __forceinline__ static void load_act(void* dst, void* src, int stride) - { - load(dst, reinterpret_cast(src), stride); - } -}; - -struct Fp8Details -{ - using ActDataType = __nv_fp8_e4m3; - static constexpr int StepK = 8; - using AccessTypeAct = float2; - using AccessTypeActScale = float4; - using AccessTypeW = float; - - __device__ __forceinline__ static void conversion(void* dst, void* src) - { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -#pragma unroll - for (int ii = 0; ii < StepK / 4; ++ii) - { - asm volatile( - "{\n" - ".reg .b16 lo, hi;\n" - "mov.b32 {lo, hi}, %2;\n" - "cvt.rn.f16x2.e4m3x2 %0, lo;\n" - "cvt.rn.f16x2.e4m3x2 %1, hi;\n" - "}\n" - : "=r"(reinterpret_cast(dst)[ii * 2]), "=r"(reinterpret_cast(dst)[ii * 2 + 1]) - : "r"(reinterpret_cast(src)[ii])); - } -#else -#pragma unroll - for (int ii = 0; ii < StepK; ++ii) - { - reinterpret_cast(dst)[ii] = static_cast(reinterpret_cast(src)[ii]); - } -#endif - } - - template - __device__ __forceinline__ static void load_act(void* dst, void* src, int stride) - { - ActDataType vec[CtaM * StepK]; - load(vec, reinterpret_cast(src), stride); -#pragma unroll - for (int ii = 0; ii < CtaM; ++ii) - { - conversion(reinterpret_cast(dst) + ii * StepK, vec + ii * StepK); - } - } -}; - -template -__global__ void kernel(typename Details::ActDataType* act, half* act_scale, 
uint8_t* weight, half* scales, half* zeros, - half* bias, half* out, float alpha, int m, int n, int k) -{ - // ArgType ArgName DataType Shape Layout - // - // input act fp16 [m, k] RowMajor - // input act_scale fp16 [1, k] RowMajor - // input weight int4b [k, n] ColumnMajor - // input scales fp16 [k / GroupSize, n] RowMajor - // input zeros fp16 [k / GroupSize, n] RowMajor - // input bias fp16 [1, n] RowMajor - // output out fp16 [m, n] RowMajor - - using AccessTypeActScale = typename Details::AccessTypeActScale; - using AccessTypeW = typename Details::AccessTypeW; - static constexpr int StepK = Details::StepK; - - static constexpr bool Mandatory = true; - static constexpr int CtaK = StepK * Threads; - static_assert(CtaN % 2 == 0); - - int const m_tile_id = blockIdx.x, n_tile_id = blockIdx.y, tid = threadIdx.x; - int const m_offset = m_tile_id * CtaM, n_offset = n_tile_id * CtaN; - - act += m_offset * k; - weight += n_offset * k / 2; - scales += n_offset; - zeros += n_offset; - bias += n_offset; - out += m_offset * n + n_offset; - - half tile_a[StepK * CtaM], tile_w[StepK * CtaN], tile_w_pack2[StepK * CtaN]; - half tile_acc[CtaM * CtaN]; - fill(tile_acc, __float2half_rn(0.f)); - - for (int idx_k = tid * StepK; idx_k < k; idx_k += CtaK) - { - half vec_act_scale[StepK]; - half vec_scale[CtaN], vec_zero[CtaN]; - uint8_t tile_w_quantized[StepK * CtaN / 2]; - // Load Data - Details::load_act(tile_a, act + idx_k, k); - load(vec_act_scale, act_scale + idx_k, 0); - load(tile_w_quantized, weight + idx_k / 2, k / 2); - load(vec_scale, scales + idx_k / GroupSize * n, 1); - load(vec_zero, zeros + idx_k / GroupSize * n, 1); - // Dequantize Data - // W4A8 checkpoints have larger activation and weight values. In order to prevent the warp-level FP16 - // accumulator from overflow, the multiplication of alpha is moved from epilogue to dequantize - apply_scale(tile_a, vec_act_scale); - dequantize(tile_w, tile_w_quantized, vec_scale, vec_zero, __float2half_rn(alpha)); - // Rearrange - pack_to_vec2(tile_w_pack2, tile_w); - // MMA - mma(tile_acc, tile_w_pack2, tile_a); - } - // Epilogue - epilogue(out, n, tile_acc, bias); -} - -template -void exec_kernel(Params& params, cudaStream_t s) -{ - if (params.m % CtaM || params.n % CtaN) - { - throw std::runtime_error("launch failed"); - } - dim3 grid(params.m / CtaM, params.n / CtaN); - dim3 block(Threads); - // clang-format off - kernel<<>>( - reinterpret_cast(params.act), - reinterpret_cast(params.act_scale), - reinterpret_cast(params.weight), - reinterpret_cast(params.scales), - reinterpret_cast(params.zeros), - reinterpret_cast(params.bias), - reinterpret_cast(params.out), - params.alpha, - params.m, params.n, params.k - ); - // clang-format on -} - -} // namespace weight_only -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.cu deleted file mode 100644 index 0a4953945..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.cu +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
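The dequantize step in the kernel above folds the global alpha into the per-group scale and zero instead of applying it in the epilogue, relying on the identity sum_k a_k * (q_k * alpha*s + alpha*z) = alpha * sum_k a_k * (q_k * s + z). A minimal host-only check of that identity, with made-up values and no dependence on the kernel code, could look like this:

// Standalone host-side check (not part of the kernel) of the identity used above:
// applying alpha to scales/zeros before the dot product gives the same result as
// scaling the accumulated sum afterwards. All values below are made up for illustration.
#include <cassert>
#include <cmath>
#include <cstdio>

int main()
{
    float const alpha = 0.125f, s = 0.02f, z = -0.16f;
    float const a[4] = {1.5f, -2.0f, 0.75f, 3.0f};
    int const q[4] = {7, -3, 1, -8};

    float fused = 0.f, late = 0.f;
    for (int k = 0; k < 4; ++k)
    {
        fused += a[k] * (q[k] * (alpha * s) + alpha * z); // alpha folded into dequantize
        late += a[k] * (q[k] * s + z);                    // alpha applied in the epilogue
    }
    late *= alpha;
    assert(std::fabs(fused - late) < 1e-6f);
    std::printf("fused=%f late=%f\n", fused, late);
    return 0;
}

Keeping the products in the fused form means every partial sum the FP16 accumulator sees is already scaled down by alpha, which is exactly the overflow concern the comment in the kernel calls out for W4A8 checkpoints.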
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernel.h" -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace weight_only -{ -#define DISPATCHER_FOR_M(target_m, CtaM, CtaN, Threads) \ - do \ - { \ - if (params.m == target_m) \ - { \ - exec_kernel(params, s); \ - return; \ - } \ - } while (0); - -template -void dispatcher(Params& params, cudaStream_t s) -{ - // clang-format off - DISPATCHER_FOR_M(1, 1, 8, 128); - DISPATCHER_FOR_M(2, 2, 4, 128); - DISPATCHER_FOR_M(3, 3, 16, 128); - DISPATCHER_FOR_M(4, 4, 16, 128); - DISPATCHER_FOR_M(5, 5, 16, 128); - DISPATCHER_FOR_M(6, 6, 16, 128); - DISPATCHER_FOR_M(7, 7, 16, 128); - DISPATCHER_FOR_M(8, 8, 16, 128); - DISPATCHER_FOR_M(9, 9, 8, 128); - DISPATCHER_FOR_M(10, 10, 8, 128); - DISPATCHER_FOR_M(11, 11, 8, 128); - DISPATCHER_FOR_M(12, 12, 8, 128); - DISPATCHER_FOR_M(13, 13, 8, 128); - DISPATCHER_FOR_M(14, 14, 8, 128); - DISPATCHER_FOR_M(15, 15, 8, 128); - DISPATCHER_FOR_M(16, 16, 8, 128); - // clang-format on - throw std::runtime_error("unsupported m"); -} - -template -void check_pointer(Params& params, cudaStream_t s) -{ - if (params.act_scale && params.zeros && params.bias) - { - dispatcher(params, s); - } - else if (params.act_scale && params.zeros && !params.bias) - { - dispatcher(params, s); - } - else if (params.act_scale && !params.zeros && params.bias) - { - dispatcher(params, s); - } - else if (!params.act_scale && params.zeros && params.bias) - { - dispatcher(params, s); - } - else if (!params.act_scale && !params.zeros && params.bias) - { - dispatcher(params, s); - } - else if (params.act_scale && !params.zeros && !params.bias) - { - dispatcher(params, s); - } - else if (!params.act_scale && params.zeros && !params.bias) - { - dispatcher(params, s); - } - else - { - dispatcher(params, s); - } -} - -template -void select_gs(Params& params, cudaStream_t s) -{ - if (params.groupsize == 64) - { - check_pointer(params, s); - } - else if (params.groupsize == 128) - { - check_pointer(params, s); - } -} - -void kernel_launcher(Params& params, cudaStream_t s) -{ - if (params.type == KernelType::W4A16) - { - select_gs(params, s); - } - else if (params.type == KernelType::W4A8) - { - select_gs(params, s); - } -} -} // namespace weight_only -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index 60d92fc5b..ddaf34cd1 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -15,85 +15,297 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include - -#include "cutlass/cutlass.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h" +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h" namespace tensorrt_llm { namespace kernels { +namespace weight_only +{ +template +struct ConverterWrapper +{ + using TypeDetailsA 
= typename Details::TypeDetailsA; + using TypeDetailsW = typename Details::TypeDetailsW; + static constexpr bool kUseInterleavedConverter = Details::kUseInterleavedConverter; + using Converter = I2FConverter; +}; -__forceinline__ __device__ float copysignf_pos(float a, float b) +template +struct MathWrapper { - float r; - r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); - return r; -} +}; -__inline__ __device__ float tanh_opt(float x) +template <> +struct MathWrapper { -#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000) - float r; - asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x)); - return r; + using Type = typename FP16DetailsA::Type; + using Type2 = typename FP16DetailsA::Type2; + + __device__ __forceinline__ static Type2 to_vec2(Type const& v) + { + return __half2half2(v); + } + + __device__ __forceinline__ static Type2 fma2(Type2 const& a, Type2 const& b, Type2 const& c) + { + return __hfma2(a, b, c); + } + + __device__ __forceinline__ static Type2 mul2(Type2 const& a, Type2 const& b) + { + return __hmul2(a, b); + } +}; + +template <> +struct MathWrapper + +{ + using Type = typename BF16DetailsA::Type; + using Type2 = typename BF16DetailsA::Type2; + + __device__ __forceinline__ static Type2 to_vec2(Type const& v) + { +#if ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) && defined(ENABLE_BF16)) + return __bfloat162bfloat162(v); +#else + uint32_t val = 0; + Type2 ret = reinterpret_cast(val); + return ret; +#endif + } + + __device__ __forceinline__ static Type2 fma2(Type2 const& a, Type2 const& b, Type2 const& c) + { +#if ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) && defined(ENABLE_BF16)) + return __hfma2(a, b, c); +#else + return to_vec2(static_cast(0.f)); +#endif + } + + __device__ __forceinline__ static Type2 mul2(Type2 const& a, Type2 const& b) + { +#if ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) && defined(ENABLE_BF16)) + return __hmul2(a, b); #else - float const exp_val = -1.f * fabs(2 * x); - return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); + return to_vec2(static_cast(0.f)); #endif + } +}; + +template +__device__ __forceinline__ void apply_scale(void* act, void* act_scale) +{ + using Type2 = typename MathWrapper::Type2; + static_assert(K % 2 == 0); + static constexpr int VecK = K / 2; + if constexpr (Enable) + { + Type2* pa = reinterpret_cast(act); + Type2* pb = reinterpret_cast(act_scale); +#pragma unroll + for (int m = 0; m < M; ++m) + { +#pragma unroll + for (int k = 0; k < VecK; ++k) + { + pa[m * VecK + k] = MathWrapper::mul2(pa[m * VecK + k], pb[k]); + } + } + } } -template -struct GeluActivation +template +__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros, float alpha) { - static __device__ __forceinline__ T apply(T const& val) + using Type = typename MathWrapper::Type; + using Type2 = typename MathWrapper::Type2; + using Converter = typename ConverterWrapper
::Converter; + static_assert(K % 2 == 0); + static constexpr int VecK = K / 2; +#pragma unroll + for (int n = 0; n < N; ++n) { - float const cdf = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (val + 0.044715f * val * val * val)))); - return val * cdf; + Converter::convert(reinterpret_cast(quantized_w) + n * K / Details::kElemsPerByteW, + reinterpret_cast(w) + n * K); + Type2 vec_scale, vec_zero; + if constexpr (ApplyAlphaInAdvance) + { + vec_scale = MathWrapper::to_vec2( + reinterpret_cast(scales)[n] * static_cast(alpha)); + vec_zero = MathWrapper::to_vec2(static_cast(0.f)); + if constexpr (EnableZero) + { + vec_zero = MathWrapper::to_vec2( + reinterpret_cast(zeros)[n] * static_cast(alpha)); + } + } + else + { + vec_scale = MathWrapper::to_vec2(reinterpret_cast(scales)[n]); + vec_zero = MathWrapper::to_vec2(static_cast(0.f)); + if constexpr (EnableZero) + { + vec_zero = MathWrapper::to_vec2(reinterpret_cast(zeros)[n]); + } + } +#pragma unroll + for (int k = 0; k < VecK; ++k) + { + reinterpret_cast(w)[n * VecK + k] = MathWrapper::fma2( + reinterpret_cast(w)[n * VecK + k], vec_scale, vec_zero); + } } -}; +} -template -struct ReluActivation +template +__device__ __forceinline__ void pack_to_vec2(void* dst, void* src, int n) { - static __device__ __forceinline__ T apply(T const& val) + using Type = typename MathWrapper::Type; + typename Details::LayoutDeatils::Mapper mapper; + int n0 = n & ~0x1, n1 = n & 0x1; + for (int k = 0; k < K; ++k) { - return val > static_cast(0.0f) ? val : static_cast(0.0f); + int physical_idx = mapper(k); + reinterpret_cast(dst)[n0 * K + k * 2 + n1] = reinterpret_cast(src)[physical_idx]; } -}; +} -template -struct IdentityActivation +template +__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act) { - static __device__ __forceinline__ T apply(T const& val) + using Type = typename MathWrapper::Type; + using Type2 = typename MathWrapper::Type2; + static_assert(N % 2 == 0); + static constexpr int VecN = N / 2; +#pragma unroll + for (int m = 0; m < M; ++m) { - return val; +#pragma unroll + for (int n = 0; n < VecN; ++n) + { +#pragma unroll + for (int k = 0; k < K; ++k) + { + reinterpret_cast(acc)[m * VecN + n] + = MathWrapper::fma2(reinterpret_cast(w_pack2)[n * K + k], + MathWrapper::to_vec2(reinterpret_cast(act)[m * K + k]), + reinterpret_cast(acc)[m * VecN + n]); + } + } } -}; +} -template -__device__ __forceinline__ void load(T0* dst, T1* src, size_t offset = 0) +template +__device__ __forceinline__ T warp_reduce_sum(T& val) { - *reinterpret_cast(dst) = *(reinterpret_cast(src) + offset); + static_assert(Interleave * ThreadsPerInterleavedTile == 8); + val += __shfl_xor_sync(~0, val, 16); + val += __shfl_xor_sync(~0, val, 8); + if (Interleave != 2 && Interleave != 4) + val += __shfl_xor_sync(~0, val, 4); + if (Interleave != 4) + val += __shfl_xor_sync(~0, val, 2); + val += __shfl_xor_sync(~0, val, 1); + return val; } -template -__device__ __forceinline__ void assign(T* dst, AssignType const& val) +template +__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias, float alpha) { - *reinterpret_cast(dst) = val; + using Type = typename MathWrapper::Type; + static constexpr int Interleave = Details::kInterleave; + static constexpr int ThreadsPerInterleavedTile = Details::kThreadsPerInterleavedTile; + static constexpr int WarpSize = Details::kWarpSize; + static constexpr int WarpNum = Threads / WarpSize; + static_assert(Threads % WarpSize == 0); + __shared__ float shmem[CtaM * CtaN * Interleave * WarpNum]; + int tid = 
threadIdx.x; + int warp_id = tid / WarpSize, lane_id = tid % WarpSize; +#pragma unroll + for (int m = 0; m < CtaM; ++m) + { +#pragma unroll + for (int n = 0; n < CtaN; ++n) + { + float v = static_cast(reinterpret_cast(tile_acc)[m * CtaN + n]); + v = warp_reduce_sum(v); + if (lane_id < Interleave * ThreadsPerInterleavedTile && lane_id % ThreadsPerInterleavedTile == 0) + { + shmem[warp_id * CtaM * CtaN * Interleave + m * CtaN * Interleave + n * Interleave + + lane_id / ThreadsPerInterleavedTile] + = v; + } + } + } + __syncthreads(); +#pragma unroll + for (int ii = tid; ii < CtaM * CtaN * Interleave; ii += Threads) + { + int m = ii / (CtaN * Interleave), n = ii % (CtaN * Interleave); + float val = 0.f, v_bias = 0.f; + if constexpr (EnableBias) + { + v_bias = static_cast(reinterpret_cast(bias)[n]); + } +#pragma unroll + for (int jj = 0; jj < WarpNum; ++jj) + { + val += shmem[jj * CtaM * CtaN * Interleave + ii]; + } + if constexpr (ApplyAlphaInAdvance) + { + reinterpret_cast(out)[m * stride + n] = static_cast(val + v_bias); + } + else + { + reinterpret_cast(out)[m * stride + n] = static_cast(alpha * val + v_bias); + } + } } -template -__device__ __forceinline__ void store(T0* src, T1* dst, size_t offset = 0) +template +__device__ __forceinline__ void fill(void* tile, T v) { - *(reinterpret_cast(dst) + offset) = *reinterpret_cast(src); +#pragma unroll + for (int ii = 0; ii < N; ++ii) + { + reinterpret_cast(tile)[ii] = v; + } } + +template +class GMemIterator +{ +public: + __device__ __forceinline__ GMemIterator(T* addr, int offset, int step, int stride) + : addr_(Enable ? (addr + offset) : nullptr) + , step_(step) + , stride_(stride) + { + } + + __device__ __forceinline__ void load(void* dst, int iter, int ii = 0) + { + if constexpr (Enable) + { +#pragma unroll + for (int jj = 0; jj < Continuous; ++jj) + { + reinterpret_cast(dst)[jj] = reinterpret_cast(addr_ + iter * step_ + ii * stride_)[jj]; + } + } + } + +private: + T* addr_; + int step_; + int stride_; +}; +} // namespace weight_only } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu deleted file mode 100644 index 1f8a5d175..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
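The GMemIterator added to utility.h above encapsulates a strided global-memory read: the base pointer is offset once at construction, each load(dst, iter, ii) copies Continuous vector-sized chunks from addr_ + iter * step_ + ii * stride_, and the whole access compiles away when Enable is false. A rough host-side model of that access pattern, using hypothetical names and plain scalar copies rather than the real template signature, might be:

#include <cstdio>
#include <vector>

// Host-side stand-in for the access pattern only; the struct and parameter names
// here are illustrative and do not mirror the actual GMemIterator template list.
template <typename T, bool Enable, int Continuous>
struct StridedIterModel
{
    T* addr;
    int step;
    int stride;

    StridedIterModel(T* base, int offset, int step_, int stride_)
        : addr(Enable ? base + offset : nullptr)
        , step(step_)
        , stride(stride_)
    {
    }

    void load(T* dst, int iter, int ii = 0) const
    {
        if constexpr (Enable)
        {
            // Copy one "Continuous" chunk starting at addr + iter*step + ii*stride.
            for (int jj = 0; jj < Continuous; ++jj)
            {
                dst[jj] = addr[iter * step + ii * stride + jj];
            }
        }
    }
};

int main()
{
    std::vector<float> act(64);
    for (int i = 0; i < 64; ++i)
    {
        act[i] = static_cast<float>(i);
    }
    // Walk a row in chunks of 4, the way the GEMV main loop steps through k.
    StridedIterModel<float, true, 4> it(act.data(), /*offset=*/16, /*step=*/4, /*stride=*/0);
    float tile[4];
    for (int iter = 0; iter < 4; ++iter)
    {
        it.load(tile, iter);
        std::printf("iter %d: %.0f %.0f %.0f %.0f\n", iter, tile[0], tile[1], tile[2], tile[3]);
    }
    return 0;
}

In the real kernels the Enable flag lets optional operands (activation scales, zeros, bias) share one code path without paying for loads that a particular instantiation does not need.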
- */ - -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 1, 256>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 1, 256>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu deleted file mode 100644 index 94d0339d9..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 1, 256>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 1, 256>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu deleted file mode 100644 index 07b15e77c..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 2, 256>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 2, 256>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu deleted file mode 100644 index 913ffe82b..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 2, 256>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 2, 256>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu deleted file mode 100644 index 6b10d355b..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 3, 128>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 3, 128>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu deleted file mode 100644 index 269d963e5..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 3, 128>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 3, 128>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu deleted file mode 100644 index 01236bbc7..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 4, 128>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 4, 128>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu deleted file mode 100644 index 8c4e870e4..000000000 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 4, 128>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 4, 128>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu index 2aeec3864..97f2d4a03 100644 --- a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu +++ b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu @@ -79,8 +79,8 @@ void update_indir_cache_kernelLauncher(int* tgt_indir_cache, int const* src_indi } template -BaseBeamSearchLayer::BaseBeamSearchLayer( - size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, std::shared_ptr allocator) +BaseBeamSearchLayer::BaseBeamSearchLayer(runtime::SizeType vocab_size, runtime::SizeType vocab_size_padded, + cudaStream_t stream, std::shared_ptr allocator) : BaseLayer(stream, std::move(allocator), nullptr) , vocab_size_(vocab_size) , vocab_size_padded_(vocab_size_padded) @@ -115,7 +115,7 @@ void BaseBeamSearchLayer::freeBuffer() } template -void BaseBeamSearchLayer::allocateBuffer(size_t batch_size) +void BaseBeamSearchLayer::allocateBuffer(runtime::SizeType batch_size) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); mIsAllocateBuffer = true; @@ -123,7 +123,7 @@ void BaseBeamSearchLayer::allocateBuffer(size_t batch_size) } template -void BaseBeamSearchLayer::setupBase(size_t batch_size, SetupParams const& setupParams) +void BaseBeamSearchLayer::setupBase(runtime::SizeType batch_size, SetupParams const& setupParams) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); allocateBuffer(batch_size); diff --git a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h index 4b92a534d..b7c07ea1f 100644 --- a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h +++ b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h @@ -21,6 +21,7 @@ #include "tensorrt_llm/kernels/decodingCommon.h" #include "tensorrt_llm/layers/baseLayer.h" #include "tensorrt_llm/layers/decodingParams.h" +#include "tensorrt_llm/runtime/common.h" #include @@ -42,8 +43,8 @@ class BaseBeamSearchLayer : public BaseLayer public: using SetupParams = DecodingSetupParams; - BaseBeamSearchLayer( - size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, std::shared_ptr allocator); + BaseBeamSearchLayer(runtime::SizeType vocab_size, runtime::SizeType vocab_size_padded, cudaStream_t stream, + std::shared_ptr allocator); BaseBeamSearchLayer(BaseBeamSearchLayer const& beam_search_layer); @@ -54,8 +55,9 @@ class BaseBeamSearchLayer : public BaseLayer class ForwardParams : public SoftmaxParams { public: - ForwardParams(int step, int ite, tc::Tensor logits, tc::Tensor endIds, tc::Tensor 
src_cache_indirection, - int max_attention_window, int sink_token_length, int max_seq_len) + ForwardParams(runtime::SizeType step, runtime::SizeType ite, tc::Tensor logits, tc::Tensor endIds, + tc::Tensor src_cache_indirection, runtime::SizeType max_attention_window, + runtime::SizeType sink_token_length, runtime::SizeType max_seq_len) : SoftmaxParams(step, ite, std::move(logits), std::move(endIds)) , src_cache_indirection{std::move(src_cache_indirection)} , max_attention_window{max_attention_window} @@ -65,9 +67,9 @@ class BaseBeamSearchLayer : public BaseLayer } // mandatory parameters - int max_attention_window; - int sink_token_length; - int max_seq_len; + runtime::SizeType max_attention_window; + runtime::SizeType sink_token_length; + runtime::SizeType max_seq_len; tc::Tensor src_cache_indirection; // [local_batch_size, beam_width, max_seq_len] // optional parameters @@ -98,18 +100,18 @@ class BaseBeamSearchLayer : public BaseLayer protected: // meta data - size_t vocab_size_; - size_t vocab_size_padded_; + runtime::SizeType vocab_size_; + runtime::SizeType vocab_size_padded_; size_t topk_softmax_workspace_size_; void* topk_softmax_workspace_ = nullptr; virtual void invokeSoftMax(BeamSearchOutputParams& outputs, SoftmaxParams const& params) = 0; - void setupBase(size_t batch_size, SetupParams const& setupParams); + void setupBase(runtime::SizeType batch_size, SetupParams const& setupParams); private: - void allocateBuffer(size_t batch_size); + void allocateBuffer(runtime::SizeType batch_size); void freeBuffer(); }; diff --git a/cpp/tensorrt_llm/layers/baseSamplingLayer.cpp b/cpp/tensorrt_llm/layers/baseSamplingLayer.cpp index abfb2b6b1..ebab3f82f 100644 --- a/cpp/tensorrt_llm/layers/baseSamplingLayer.cpp +++ b/cpp/tensorrt_llm/layers/baseSamplingLayer.cpp @@ -27,13 +27,14 @@ using namespace tensorrt_llm::common; using namespace tensorrt_llm::kernels; +using namespace tensorrt_llm::runtime; namespace tensorrt_llm { namespace layers { template -BaseSamplingLayer::BaseSamplingLayer(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, +BaseSamplingLayer::BaseSamplingLayer(SizeType maxBatchSize, SizeType vocabSize, SizeType vocabSizePadded, cudaStream_t stream, std::shared_ptr allocator, cudaDeviceProp* prop) : BaseLayer(stream, std::move(allocator), prop) , mMaxBatchSize(maxBatchSize) diff --git a/cpp/tensorrt_llm/layers/baseSamplingLayer.h b/cpp/tensorrt_llm/layers/baseSamplingLayer.h index 5e0c45811..c200a95ad 100644 --- a/cpp/tensorrt_llm/layers/baseSamplingLayer.h +++ b/cpp/tensorrt_llm/layers/baseSamplingLayer.h @@ -23,6 +23,7 @@ #include "tensorrt_llm/kernels/penaltyTypes.h" #include "tensorrt_llm/layers/baseLayer.h" #include "tensorrt_llm/layers/decodingParams.h" +#include "tensorrt_llm/runtime/common.h" namespace tc = tensorrt_llm::common; @@ -40,12 +41,12 @@ class BaseSamplingLayer : public BaseLayer class SetupParams : public DecodingSetupParams { public: - std::optional> runtime_top_k; // [1] or [batchSize] on cpu - std::optional> runtime_top_p; // [1] or [batchSize] on cpu - std::optional> randomSeed; // [1] or [batchSize] on cpu - std::optional> top_p_decay; // [batchSize], must between [0, 1] - std::optional> top_p_min; // [batchSize], must between [0, 1] - std::optional> top_p_reset_ids; // [batchSize] + std::optional> runtime_top_k; // [1] or [batchSize] on cpu + std::optional> runtime_top_p; // [1] or [batchSize] on cpu + std::optional> randomSeed; // [1] or [batchSize] on cpu + std::optional> top_p_decay; // [batchSize], must between [0, 1] + std::optional> 
top_p_min; // [batchSize], must between [0, 1] + std::optional> top_p_reset_ids; // [batchSize] std::optional normalize_log_probs; }; @@ -80,8 +81,8 @@ class BaseSamplingLayer : public BaseLayer //! \param allocator shared pointer to IAllocator object that will be use to alloc and free tensors //! \param prop [optional] cudaDeviceProp // clang-format on - BaseSamplingLayer(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, cudaStream_t stream, - std::shared_ptr allocator, cudaDeviceProp* prop); + BaseSamplingLayer(runtime::SizeType maxBatchSize, runtime::SizeType vocabSize, runtime::SizeType vocabSizePadded, + cudaStream_t stream, std::shared_ptr allocator, cudaDeviceProp* prop); ~BaseSamplingLayer() override = default; @@ -107,7 +108,7 @@ class BaseSamplingLayer : public BaseLayer //! \param batchSlots input tensor [batchSize], address map of the new requests, in pinned memory //! \param setupParams setup sampling parameters per request // clang-format on - virtual void setup(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) = 0; + virtual void setup(runtime::SizeType batchSize, int32_t const* batchSlots, SetupParams const& setupParams) = 0; size_t getWorkspaceSize() const { @@ -120,9 +121,9 @@ class BaseSamplingLayer : public BaseLayer } protected: - size_t mMaxBatchSize; - size_t mVocabSize; - size_t mVocabSizePadded; + runtime::SizeType mMaxBatchSize; + runtime::SizeType mVocabSize; + runtime::SizeType mVocabSizePadded; size_t mSamplingWorkspaceSize = 0; size_t mAllocatedSize = 0; diff --git a/cpp/tensorrt_llm/layers/decodingParams.h b/cpp/tensorrt_llm/layers/decodingParams.h index 6f7132442..7182914ac 100644 --- a/cpp/tensorrt_llm/layers/decodingParams.h +++ b/cpp/tensorrt_llm/layers/decodingParams.h @@ -76,6 +76,10 @@ class DecodingOutputParams std::optional parent_ids; // [max_seq_len, batch_size * beam_width], necessary in beam search tc::Tensor output_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len] + + // Medusa params + std::optional nextDraftTokens; // [batch_size, max_draft_tokens_per_step] + std::optional acceptedLengths; // [batch_size] }; } // namespace tensorrt_llm::layers diff --git a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp index 9476a11c4..c5f7b0180 100644 --- a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp +++ b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp @@ -71,9 +71,10 @@ bool hasDiffRuntimeArgs(DecodingSetupParams const& params) } // namespace template -DynamicDecodeLayer::DynamicDecodeLayer(DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, - size_t vocabSize, size_t vocabSizePadded, cudaStream_t stream, std::shared_ptr allocator, - cudaDeviceProp* cudaDeviceProp) +DynamicDecodeLayer::DynamicDecodeLayer(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth, + SizeType vocabSize, SizeType vocabSizePadded, cudaStream_t stream, std::shared_ptr allocator, + cudaDeviceProp* cudaDeviceProp, std::optional maxTokensPerStep, + std::optional maxNumMedusaHeads) : BaseLayer(stream, std::move(allocator)) , mDecodingMode(mode) , mMaxBatchSize(maxBatchSize) @@ -81,9 +82,14 @@ DynamicDecodeLayer::DynamicDecodeLayer(DecodingMode const& mode, size_t maxBa , mVocabSize(vocabSize) , mVocabSizePadded(vocabSizePadded) , mCudaDeviceProp(cudaDeviceProp) + , mMaxTokensPerStep(maxTokensPerStep.value_or(1)) + , mMaxNumMedusaHeads(maxNumMedusaHeads.value_or(0)) { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", 
__PRETTY_FUNCTION__); + initialize(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template @@ -95,22 +101,30 @@ DynamicDecodeLayer::DynamicDecodeLayer(DynamicDecodeLayer const& dynamicDecod , mVocabSize(dynamicDecodeLayer.mVocabSize) , mVocabSizePadded(dynamicDecodeLayer.mVocabSizePadded) , mCudaDeviceProp(dynamicDecodeLayer.mCudaDeviceProp) + , mMaxTokensPerStep(dynamicDecodeLayer.mMaxTokensPerStep) + , mMaxNumMedusaHeads(dynamicDecodeLayer.mMaxNumMedusaHeads) { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + initialize(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template DynamicDecodeLayer::~DynamicDecodeLayer() { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + freeBuffer(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void DynamicDecodeLayer::initialize() { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); mIdsPtrHost = runtime::BufferManager::pinned(ITensor::makeShape({}), runtime::TRTDataType::value); mLogitsPtrsHost = runtime::BufferManager::pinned(ITensor::makeShape({}), runtime::TRTDataType::value); @@ -132,12 +146,15 @@ void DynamicDecodeLayer::initialize() mConfiguredBeamWidth = mMaxBeamWidth; initializeLayers(); } + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void DynamicDecodeLayer::allocateBuffer() { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + mZeroParentIdsDevice = mAllocator->reMalloc(mZeroParentIdsDevice, sizeof(int*) * 2 * mMaxBatchSize, false); mTemperatureDevice = mAllocator->reMalloc(mTemperatureDevice, sizeof(float) * mMaxBatchSize, false); mRepetitionPenaltyDevice = mAllocator->reMalloc(mRepetitionPenaltyDevice, sizeof(float) * mMaxBatchSize, false); @@ -145,13 +162,16 @@ void DynamicDecodeLayer::allocateBuffer() mFrequencyPenaltyDevice = mAllocator->reMalloc(mFrequencyPenaltyDevice, sizeof(float) * mMaxBatchSize, false); mMinLengthDevice = mAllocator->reMalloc(mMinLengthDevice, sizeof(int32_t) * mMaxBatchSize, false); mRuntimeLogitsDevice = mAllocator->reMalloc( - mRuntimeLogitsDevice, sizeof(T) * mMaxBatchSize * mMaxBeamWidth * mVocabSizePadded, false); + mRuntimeLogitsDevice, sizeof(T) * mMaxBatchSize * mMaxTokensPerStep * mMaxBeamWidth * mVocabSizePadded, false); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void DynamicDecodeLayer::freeBuffer() { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + mAllocator->free((void**) &mZeroParentIdsDevice); if (mPenaltyWorkspaceDevice != nullptr) { @@ -167,12 +187,16 @@ void DynamicDecodeLayer::freeBuffer() mAllocator->free((void**) (&mFrequencyPenaltyDevice)); mAllocator->free((void**) (&mMinLengthDevice)); mAllocator->free((void**) (&mRuntimeLogitsDevice)); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void DynamicDecodeLayer::initializeLayers() { - const size_t workspaceSize = sizeof(int) * mMaxBatchSize * mConfiguredBeamWidth * mVocabSize; + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto const workspaceSize = sizeof(SizeType) * mMaxBatchSize * mMaxTokensPerStep * mConfiguredBeamWidth * mVocabSize; mPenaltyWorkspaceDevice = mAllocator->reMalloc(mPenaltyWorkspaceDevice, workspaceSize, false); if (mDecodingMode.isTopKorTopP()) @@ -186,15 +210,23 @@ void DynamicDecodeLayer::initializeLayers() = std::make_unique>(mVocabSize, mVocabSizePadded, mStream, mAllocator); mPenaltyWorkspacePrevDevice = 
mAllocator->reMalloc(mPenaltyWorkspacePrevDevice, workspaceSize, false); } + else if (mDecodingMode.isMedusa()) + { + mMedusaDecodingLayer = std::make_unique>( + mMaxBatchSize, mVocabSize, mVocabSizePadded, mMaxTokensPerStep, mMaxNumMedusaHeads, mStream, mAllocator); + } else { - TLLM_CHECK_WITH_INFO(false, "Decoding mode is none of the supported {TopK, TopP, TopKTopP, BeamSearch}"); + TLLM_CHECK_WITH_INFO( + false, "Decoding mode is none of the supported {TopK, TopP, TopKTopP, BeamSearch, Medusa}"); } + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void DynamicDecodeLayer::setup( - size_t batchSize, size_t beamWidth, int32_t const* batchSlots, SetupParams const& setupParams) + SizeType batchSize, SizeType beamWidth, SizeType const* batchSlots, SetupParams const& setupParams) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -210,9 +242,9 @@ void DynamicDecodeLayer::setup( TLLM_CHECK_WITH_INFO((mConfiguredBeamWidth == 1 && beamWidth == 1) || (mConfiguredBeamWidth > 1 && beamWidth > 1 && beamWidth <= mConfiguredBeamWidth), - "Decoder is configured with beam width %d, but %lu was given", mConfiguredBeamWidth, beamWidth); + "Decoder is configured with beam width %d, but %d was given", mConfiguredBeamWidth, beamWidth); TLLM_CHECK_WITH_INFO(mConfiguredBeamWidth <= mMaxBeamWidth, - "Decoder is created with max beam width %lu, but %d was given", mMaxBeamWidth, mConfiguredBeamWidth); + "Decoder is created with max beam width %d, but %d was given", mMaxBeamWidth, mConfiguredBeamWidth); setupLayers(batchSize, beamWidth, batchSlots, setupParams); @@ -223,13 +255,13 @@ void DynamicDecodeLayer::setup( template void DynamicDecodeLayer::setupLayers( - size_t batchSize, size_t beamWidth, int32_t const* batchSlots, SetupParams const& setupParams) + SizeType batchSize, SizeType beamWidth, int32_t const* batchSlots, SetupParams const& setupParams) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - if (beamWidth == 1) + if (mDecodingMode.isTopKorTopP()) { // sampling layers TLLM_CHECK_WITH_INFO( - mDecodingMode.isTopKorTopP(), "beamWidth == 1 is given, but decoder is not configured as TopK or TopP"); + beamWidth == 1, "Decoding mode is TopK and/or TopP, but beamWidth != 1 (%d != 1)", beamWidth); typename TopPSamplingLayer::SetupParams samplingParams; samplingParams.runtime_top_k = setupParams.runtime_top_k; @@ -243,10 +275,9 @@ void DynamicDecodeLayer::setupLayers( mSamplingLayer->setup(batchSize, batchSlots, samplingParams); } - else + else if (mDecodingMode.isBeamSearch()) { // beam search layer - TLLM_CHECK_WITH_INFO( - mDecodingMode.isBeamSearch(), "beamWidth > 1 is given, but decoder is not configured as BeamSearch"); + TLLM_CHECK_WITH_INFO(beamWidth > 1, "Decoding mode is beam search, but beamWidth <= 1 (%d <= 1)", beamWidth); typename OnlineBeamSearchLayer::SetupParams beamSearchParams; beamSearchParams.beam_search_diversity_rate = setupParams.beam_search_diversity_rate; @@ -256,11 +287,26 @@ void DynamicDecodeLayer::setupLayers( mHasDiffRuntimeArgs = hasDiffRuntimeArgs(beamSearchParams); mOnlineBeamSearchDecode->setup(batchSize, beamSearchParams); } + else if (mDecodingMode.isMedusa()) + { + typename MedusaDecodingLayer::MedusaSetupParams medusaSetupParams; + medusaSetupParams.runtimeTopK = setupParams.runtime_top_k; + medusaSetupParams.runtimeHeadsTopK = setupParams.topKMedusaHeads; + medusaSetupParams.tokensPerStep = setupParams.tokensPerStep; + medusaSetupParams.randomSeed = setupParams.randomSeed; + mMedusaDecodingLayer->setup(batchSize, batchSlots, medusaSetupParams); + } + 
else + { + TLLM_CHECK_WITH_INFO( + false, "Decoding mode is none of the supported {TopK, TopP, TopKTopP, BeamSearch, Medusa}"); + } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template -void DynamicDecodeLayer::setupPenalties(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) +void DynamicDecodeLayer::setupPenalties( + SizeType batchSize, SizeType const* batchSlots, SetupParams const& setupParams) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); std::vector batchSlotsVec(batchSize); @@ -312,31 +358,33 @@ void DynamicDecodeLayer::forward(OutputParams& outputs, ForwardParams const& TLLM_CHECK_WITH_INFO( outputs.sequence_length.has_value(), "sequence_length tensor is mandatory in DynamicDecoderLayer."); - size_t batchSize = 0; - size_t beamWidth = 0; - size_t vocabSize = 0; + SizeType batchSize = 0; + SizeType beamWidth = 0; + SizeType vocabSize = 0; auto const maxSeqLen = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1]; if (params.logits) { auto const& logitsShape = params.logits->shape; - TLLM_CHECK(logitsShape.size() == 3); + TLLM_CHECK(logitsShape.size() == 3 || logitsShape.size() == 4); batchSize = logitsShape[0]; - beamWidth = logitsShape[1]; - vocabSize = logitsShape[2]; + auto const idxOffset = logitsShape.size() - 3; + beamWidth = logitsShape[idxOffset + 1]; + vocabSize = logitsShape[idxOffset + 2]; } else { TLLM_CHECK(params.logits_vec->size()); auto const& logitsShape = params.logits_vec.value()[0].shape; - TLLM_CHECK(logitsShape.size() == 3); + TLLM_CHECK(logitsShape.size() == 3 || logitsShape.size() == 4); + auto const idxOffset = logitsShape.size() - 3; batchSize = params.logits_vec->size(); - beamWidth = logitsShape[1]; - vocabSize = logitsShape[2]; + beamWidth = logitsShape[idxOffset + 1]; + vocabSize = logitsShape[idxOffset + 2]; } TLLM_CHECK_WITH_INFO((mConfiguredBeamWidth == 1 && beamWidth == 1) || (mConfiguredBeamWidth > 1 && beamWidth > 1 && beamWidth <= mConfiguredBeamWidth), - "Decoder is configured with beam width %d, but %lu was given", mConfiguredBeamWidth, beamWidth); + "Decoder is configured with beam width %d, but %d was given", mConfiguredBeamWidth, beamWidth); if (!mLogitsPtrsHost->data()) { @@ -358,7 +406,9 @@ void DynamicDecodeLayer::forward(OutputParams& outputs, ForwardParams const& prepareIdsPtrs(outputs, batchSlotsHost, batchSize, beamWidth, maxSeqLen); auto logits = Tensor(MEMORY_GPU, std::is_same_v ? 
DataType::TYPE_FP32 : DataType::TYPE_FP16, - {batchSize, beamWidth, mVocabSizePadded}, mRuntimeLogitsDevice); + {static_cast(batchSize), static_cast(mMaxTokensPerStep), static_cast(beamWidth), + static_cast(mVocabSizePadded)}, + mRuntimeLogitsDevice); // Apply penalties applyPenalties(outputs, params, batchSlotsHost, batchSlots, batchSize, beamWidth, maxSeqLen); @@ -373,8 +423,8 @@ void DynamicDecodeLayer::forward(OutputParams& outputs, ForwardParams const& checkStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, mStream); // Copy nextIds and transpose logits when needed - prepareOutputData( - outputs, params, mIdsPtrHost, batchSlots, batchSize, mMaxBatchSize, beamWidth, maxSeqLen, mCyclicStep, mStream); + prepareOutputData(outputs, params, mIdsPtrHost, batchSlots, batchSize, mMaxBatchSize, beamWidth, maxSeqLen, + mMaxTokensPerStep, mCyclicStep, mStream); mCyclicStep += 1; @@ -384,7 +434,7 @@ void DynamicDecodeLayer::forward(OutputParams& outputs, ForwardParams const& template void DynamicDecodeLayer::layersForward(Tensor& logits, OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen) + int32_t const* batchSlots, SizeType batchSize, SizeType beamWidth, SizeType maxSeqLen) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto const ite = params.ite; @@ -395,10 +445,9 @@ void DynamicDecodeLayer::layersForward(Tensor& logits, OutputParams& outputs, auto const localBatchSize = static_cast(params.local_batch_size); // dynamic decode GPT - if (beamWidth > 1) + if (mDecodingMode.isBeamSearch()) { - TLLM_CHECK_WITH_INFO( - mDecodingMode.isBeamSearch(), "beamWidth > 1 is given, but decoder is not configured as BeamSearch"); + TLLM_CHECK_WITH_INFO(beamWidth > 1, "Decoding mode is beam search, but beamWidth <= 1 (%d <= 1)", beamWidth); TLLM_CHECK_WITH_INFO( params.src_cache_indirection.has_value(), "src_cache_indirection is mandatory in beam search."); TLLM_CHECK_WITH_INFO( @@ -409,13 +458,13 @@ void DynamicDecodeLayer::layersForward(Tensor& logits, OutputParams& outputs, // Because we still not support batch beam search now, so we need to compute // one by one if there are different runtime arguments. - const size_t dynamic_decode_batch_size = mHasDiffRuntimeArgs ? 1 : localBatchSize; - int const dynamic_decode_total_iteration = localBatchSize / dynamic_decode_batch_size; + size_t const dynamic_decode_batch_size = mHasDiffRuntimeArgs ? 
1 : localBatchSize; + auto const dynamic_decode_total_iteration = localBatchSize / dynamic_decode_batch_size; for (uint32_t dynamic_ite = 0; dynamic_ite < dynamic_decode_total_iteration; ++dynamic_ite) { - int const dynamic_id_offset = dynamic_ite * dynamic_decode_batch_size * beamWidth; - int const dynamic_decode_vocab_size_units_offset = dynamic_id_offset * mVocabSizePadded; + auto const dynamic_id_offset = dynamic_ite * dynamic_decode_batch_size * beamWidth; + auto const dynamic_decode_vocab_size_units_offset = dynamic_id_offset * mVocabSizePadded; auto const logits_offset = logits.slice( {dynamic_decode_batch_size, logits.shape[1], logits.shape[2]}, dynamic_decode_vocab_size_units_offset); @@ -454,14 +503,14 @@ void DynamicDecodeLayer::layersForward(Tensor& logits, OutputParams& outputs, } // end of dynamic_ite std::swap(mPenaltyWorkspaceDevice, mPenaltyWorkspacePrevDevice); } - else + else if (mDecodingMode.isTopKorTopP()) { // beamWidth == 1 TLLM_CHECK_WITH_INFO( - mDecodingMode.isTopKorTopP(), "beamWidth == 1 is given, but decoder is not configured as TopK or TopP"); + beamWidth == 1, "Decoding mode is TopK and/or TopP, but beamWidth != 1 (%d != 1)", beamWidth); // In sampling, we have supported batch sampling. So, we always compute all // sentences once. - Tensor const logits_slice{logits.slice({localBatchSize, beamWidth, logits.shape[2]}, 0)}; + Tensor const logits_slice{logits.slice({localBatchSize, static_cast(beamWidth), logits.shape[2]}, 0)}; Tensor const end_id_slice{endIds.slice({localBatchSize}, 0)}; typename BaseSamplingLayer::ForwardParams decode_input_tensors{ step, ite, logits_slice, end_id_slice, static_cast(maxSeqLen)}; @@ -471,7 +520,8 @@ void DynamicDecodeLayer::layersForward(Tensor& logits, OutputParams& outputs, if (params.input_lengths) { auto& input_lengths = params.input_lengths.value(); - decode_input_tensors.input_lengths = input_lengths.slice({localBatchSize, beamWidth}, 0); + decode_input_tensors.input_lengths + = input_lengths.slice({localBatchSize, static_cast(beamWidth)}, 0); } decode_input_tensors.batch_slots = params.batch_slots; @@ -498,12 +548,31 @@ void DynamicDecodeLayer::layersForward(Tensor& logits, OutputParams& outputs, // Run TopK + TopP decode layers. 
mSamplingLayer->forward(decode_outputs, decode_input_tensors); } + else if (mDecodingMode.isMedusa()) + { + TLLM_CHECK_WITH_INFO(beamWidth == 1, "Decoding mode is Medusa, but beamWidth != 1 (%d != 1)", beamWidth); + + typename MedusaDecodingLayer::MedusaForwardParams medusaInputParams(logits, endIds); + medusaInputParams.finished = outputs.finished.value(); + medusaInputParams.batch_slots = params.batch_slots; + medusaInputParams.paths = params.paths.value(); + medusaInputParams.medusaLogits = params.medusaLogits.value(); + + DecodingOutputParams medusaOutputParams(outputs.output_ids); + medusaOutputParams.sequence_length = outputs.sequence_length.value(); + medusaOutputParams.finished = outputs.finished.value(); + medusaOutputParams.nextDraftTokens = outputs.nextDraftTokens.value(); + medusaOutputParams.acceptedLengths = outputs.acceptedLengths.value(); + + mMedusaDecodingLayer->forward(medusaOutputParams, medusaInputParams); + } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void DynamicDecodeLayer::applyPenalties(OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlotsHost, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen) + int32_t const* batchSlotsHost, int32_t const* batchSlots, SizeType batchSize, SizeType beamWidth, + SizeType maxSeqLen) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -514,7 +583,7 @@ void DynamicDecodeLayer::applyPenalties(OutputParams& outputs, ForwardParams if (params.logits_vec) { TLLM_CHECK_WITH_INFO(params.logits_vec->size() == batchSize, - "Logits vector size (%lu) is not equal to the batchSize (%lu)", params.logits_vec->size(), batchSize); + "Logits vector size (%lu) is not equal to the batchSize (%d)", params.logits_vec->size(), batchSize); logitsPtrsHostData[bi] = params.logits_vec.value()[bi].template getPtr(); } else @@ -545,8 +614,7 @@ void DynamicDecodeLayer::applyPenalties(OutputParams& outputs, ForwardParams #undef GET_PENALTIES - constexpr int32_t maxTokensPerStep = 1; - int32_t* tokensPerStep = nullptr; + auto const tokensPerStep = params.tokensPerStep ? 
params.tokensPerStep->template getPtr() : nullptr; InvokeBatchApplyPenaltyParams penaltyParams{reinterpret_cast(logitsPtrsHostData), mRuntimeLogitsDevice, embeddingBias, mPenaltyWorkspaceDevice, mPenaltyWorkspacePrevDevice, temperatures, repetitionPenalties, presencePenalties, frequencyPenalties, @@ -554,7 +622,7 @@ void DynamicDecodeLayer::applyPenalties(OutputParams& outputs, ForwardParams static_cast(beamWidth), static_cast(maxSeqLen), mVocabSize, mVocabSizePadded, outputs.output_ids_ptr.template getPtr(), outputs.parent_ids_ptr.template getPtr(), inputLengths, outputs.sequence_length->template getPtr(), minLengths, - params.end_ids.template getPtr(), batchSlots, maxTokensPerStep, tokensPerStep, mStream}; + params.end_ids.template getPtr(), batchSlots, mMaxTokensPerStep, tokensPerStep, mStream}; invokeBatchApplyPenalty(penaltyParams); sync_check_cuda_error(); @@ -563,11 +631,17 @@ void DynamicDecodeLayer::applyPenalties(OutputParams& outputs, ForwardParams template void DynamicDecodeLayer::banWords(Tensor& logits, OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded, + int32_t const* batchSlots, SizeType batchSize, SizeType beamWidth, SizeType maxSeqLen, SizeType vocabSizePadded, cudaStream_t stream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + if (mDecodingMode.isMedusa()) + { + // Do not support Ban Words for Medusa + return; + } + banRepeatNGrams(logits, outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, vocabSizePadded, stream); banBadWords(logits, outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, vocabSizePadded, stream); @@ -576,7 +650,7 @@ void DynamicDecodeLayer::banWords(Tensor& logits, OutputParams& outputs, Forw template void DynamicDecodeLayer::banRepeatNGrams(Tensor& logits, OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded, + int32_t const* batchSlots, SizeType batchSize, SizeType beamWidth, SizeType maxSeqLen, SizeType vocabSizePadded, cudaStream_t stream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -597,18 +671,19 @@ void DynamicDecodeLayer::banRepeatNGrams(Tensor& logits, OutputParams& output template void DynamicDecodeLayer::banBadWords(Tensor& logits, OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded, + int32_t const* batchSlots, SizeType batchSize, SizeType beamWidth, SizeType maxSeqLen, SizeType vocabSizePadded, cudaStream_t stream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto const maxBadWordsLength = params.max_bad_words_len; if (maxBadWordsLength) { - int32_t const** badWordsPtr = params.bad_words_ptr->template getPtr(); - int32_t const* badWordsLens = params.bad_words_lengths->template getPtr(); + auto const** badWordsPtr = params.bad_words_ptr->template getPtr(); + auto const* badWordsLens = params.bad_words_lengths->template getPtr(); - invokeBanBadWords((T*) logits.template getPtr(), outputs.output_ids_ptr.template getPtr(), - beamWidth > 1 ? outputs.parent_ids_ptr.template getPtr() : nullptr, batchSlots, batchSize, + invokeBanBadWords((T*) logits.template getPtr(), + outputs.output_ids_ptr.template getPtr(), + beamWidth > 1 ? 
outputs.parent_ids_ptr.template getPtr() : nullptr, batchSlots, batchSize, beamWidth, badWordsPtr, badWordsLens, maxBadWordsLength, vocabSizePadded, outputs.sequence_length->template getPtr(), maxSeqLen, stream); } @@ -617,11 +692,16 @@ void DynamicDecodeLayer::banBadWords(Tensor& logits, OutputParams& outputs, F template void DynamicDecodeLayer::checkStopCriteria(OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream) + int32_t const* batchSlots, SizeType batchSize, SizeType beamWidth, SizeType maxSeqLen, cudaStream_t stream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - checkStopWordsStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, stream); + if (!mDecodingMode.isMedusa()) + { + // Do not support Stop Words for Medusa + checkStopWordsStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, stream); + } + checkMaxLengthStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, stream); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -629,7 +709,7 @@ void DynamicDecodeLayer::checkStopCriteria(OutputParams& outputs, ForwardPara template void DynamicDecodeLayer::checkStopWordsStopCriteria(OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream) + int32_t const* batchSlots, SizeType batchSize, SizeType beamWidth, SizeType maxSeqLen, cudaStream_t stream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto const maxStopWordsLength = params.max_stop_words_len; @@ -648,7 +728,7 @@ void DynamicDecodeLayer::checkStopWordsStopCriteria(OutputParams& outputs, Fo template void DynamicDecodeLayer::checkMaxLengthStopCriteria(OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream) + int32_t const* batchSlots, SizeType batchSize, SizeType beamWidth, SizeType maxSeqLen, cudaStream_t stream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); if (params.sequence_limit_length) @@ -665,16 +745,16 @@ void DynamicDecodeLayer::checkMaxLengthStopCriteria(OutputParams& outputs, Fo template void DynamicDecodeLayer::prepareIdsPtrs( - OutputParams& outputs, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen) + OutputParams& outputs, SizeType const* batchSlots, SizeType batchSize, SizeType beamWidth, SizeType maxSeqLen) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto idsPtrHostSlice = ITensor::slice(mIdsPtrHost, mCyclicStep, 1); - auto idsPtrHost = reinterpret_cast(runtime::bufferCast(*idsPtrHostSlice)); + auto idsPtrHost = reinterpret_cast(runtime::bufferCast(*idsPtrHostSlice)); for (int bi = 0; bi < batchSize; bi++) { auto const batchSlot = batchSlots[bi]; idsPtrHost[batchSlot] - = outputs.output_ids.template getPtrWithOffset(batchSlot * beamWidth * maxSeqLen); + = outputs.output_ids.template getPtrWithOffset(batchSlot * beamWidth * maxSeqLen); } for (int bi = 0; bi < batchSize; bi++) { @@ -682,7 +762,7 @@ void DynamicDecodeLayer::prepareIdsPtrs( if (beamWidth > 1) { idsPtrHost[mMaxBatchSize + batchSlot] - = outputs.parent_ids.value().template getPtrWithOffset(bi * beamWidth * maxSeqLen); + = outputs.parent_ids.value().template getPtrWithOffset(bi * beamWidth * maxSeqLen); } else { @@ -690,23 +770,29 @@ void DynamicDecodeLayer::prepareIdsPtrs( } } - outputs.output_ids_ptr - = Tensor(MEMORY_GPU, 
DataType::TYPE_INT32_PTR, {mMaxBatchSize, beamWidth, maxSeqLen}, idsPtrHost); - outputs.parent_ids_ptr = Tensor( - MEMORY_GPU, DataType::TYPE_INT32_PTR, {mMaxBatchSize, beamWidth, maxSeqLen}, idsPtrHost + mMaxBatchSize); + outputs.output_ids_ptr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR, + {static_cast(mMaxBatchSize), static_cast(beamWidth), static_cast(maxSeqLen)}, + idsPtrHost); + outputs.parent_ids_ptr = Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR, + {static_cast(mMaxBatchSize), static_cast(beamWidth), static_cast(maxSeqLen)}, + idsPtrHost + mMaxBatchSize); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void DynamicDecodeLayer::prepareOutputData(OutputParams& outputs, ForwardParams const& params, - runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, size_t maxBatchSize, - size_t beamWidth, size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream) + runtime::ITensor::SharedPtr const& idsPtrsHost, SizeType const* batchSlots, SizeType batchSize, + SizeType maxBatchSize, SizeType beamWidth, SizeType maxSeqLen, SizeType maxTokensPerStep, SizeType cyclicStep, + cudaStream_t stream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto idsPtrHostSlice = ITensor::slice(idsPtrsHost, cyclicStep, 1); - auto idsPtrHost = reinterpret_cast(runtime::bufferCast(*idsPtrHostSlice)); - invokeCopyNextStepIds(outputs.newTokens.template getPtr(), idsPtrHost, - outputs.sequence_length->template getPtr(), batchSlots, batchSize, beamWidth, maxSeqLen, stream); + auto idsPtrHost = reinterpret_cast(runtime::bufferCast(*idsPtrHostSlice)); + auto const numNewTokens = outputs.acceptedLengths ? outputs.acceptedLengths->template getPtr() + : static_cast(nullptr); + invokeCopyNextStepIds(outputs.newTokens.template getPtr(), idsPtrHost, + outputs.sequence_length->template getPtr(), numNewTokens, batchSlots, batchSize, maxBatchSize, + beamWidth, maxSeqLen, maxTokensPerStep, stream); // Transpose the output log probs from [maxSeqLen, bs, beamWidth] to [batchSize, beamWidth, maxSeqLen] if (outputs.output_log_probs_tiled) @@ -715,7 +801,7 @@ void DynamicDecodeLayer::prepareOutputData(OutputParams& outputs, ForwardPara invokeTransposeLogProbs(outputs.output_log_probs.value().template getPtr(), outputs.output_log_probs_tiled.value().template getPtr(), - outputs.sequence_length->template getPtr(), batchSlots, batchSize, maxBatchSize, beamWidth, + outputs.sequence_length->template getPtr(), batchSlots, batchSize, maxBatchSize, beamWidth, logProbsMaxSeqLen, stream); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); diff --git a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h index af84465bf..e8752eadd 100644 --- a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h +++ b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h @@ -19,6 +19,7 @@ #include "tensorrt_llm/common/tensor.h" #include "tensorrt_llm/kernels/beamSearchTopkKernels.h" #include "tensorrt_llm/layers/baseLayer.h" +#include "tensorrt_llm/layers/medusaDecodingLayer.h" #include "tensorrt_llm/layers/onlineBeamSearchLayer.h" #include "tensorrt_llm/layers/samplingLayer.h" #include "tensorrt_llm/runtime/cudaStream.h" @@ -46,9 +47,11 @@ template class DynamicDecodeLayer : public BaseLayer { public: - DynamicDecodeLayer(runtime::DecodingMode const& mode, size_t max_batch_size, size_t max_beam_width, - size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, std::shared_ptr allocator, - cudaDeviceProp* cuda_device_prop); + DynamicDecodeLayer(runtime::DecodingMode const& mode, 
runtime::SizeType max_batch_size, + runtime::SizeType max_beam_width, runtime::SizeType vocab_size, runtime::SizeType vocab_size_padded, + cudaStream_t stream, std::shared_ptr allocator, cudaDeviceProp* cuda_device_prop, + std::optional maxTokensPerStep = std::nullopt, + std::optional maxNumMedusaHeads = std::nullopt); ~DynamicDecodeLayer() override; DynamicDecodeLayer(DynamicDecodeLayer const& dynamic_decode_layer); @@ -63,24 +66,29 @@ class DynamicDecodeLayer : public BaseLayer std::optional> min_length; // [1] or [batch_size] on cpu // baseSamplingLayer - std::optional> runtime_top_k; // [1] or [batch_size] on cpu - std::optional> runtime_top_p; // [1] or [batch_size] on cpu - std::optional> randomSeed; // [1] or [batch_size] on cpu + std::optional> runtime_top_k; // [1] or [batch_size] on cpu + std::optional> runtime_top_p; // [1] or [batch_size] on cpu + std::optional> randomSeed; // [1] or [batch_size] on cpu // topPSamplingLayer std::optional> top_p_decay; // [batch_size], must between [0, 1] std::optional> top_p_min; // [batch_size], must between [0, 1] std::optional> top_p_reset_ids; // [batch_size] - // omlineBeamSearchLayer + // onlineBeamSearchLayer std::optional> beam_search_diversity_rate; std::optional> length_penalty; std::optional> early_stopping; std::optional normalize_log_probs; + + // Medusa params + std::optional>> topKMedusaHeads; // [batchSize, maxMedusaHeads] + std::optional> tokensPerStep; // [batchSize] }; - void setup(size_t batch_size, size_t beam_width, int const* batch_slots, SetupParams const& setupParams); + void setup(runtime::SizeType batch_size, runtime::SizeType beam_width, int const* batch_slots, + SetupParams const& setupParams); class ForwardParams { @@ -127,8 +135,14 @@ class DynamicDecodeLayer : public BaseLayer std::optional bad_words_lengths; // [batch_size], on gpu std::optional stop_words_ptr; // [batch_size][2, stop_words_length], on gpu std::optional stop_words_lengths; // [batch_size], on gpu - std::optional no_repeat_ngram_size; // [batch_size] - std::optional batch_slots; // [batch_size] in pinned memory + std::optional no_repeat_ngram_size; // [batch_size], on gpu + std::optional batch_slots; // [batch_size], in pinned memory + + // Medusa inputs + std::optional tokensPerStep; // [batch_size], optional, on gpu + std::optional paths; // [batch_size, max_tokens_per_step, max_num_heads + 1], optional, on gpu + std::optional + medusaLogits; // [max_num_heads, batch_size, max_tokens_per_step, vocab_size], optional, on gpu }; class OutputParams @@ -143,21 +157,25 @@ class DynamicDecodeLayer : public BaseLayer tc::Tensor output_ids; // [batch_size, beam_width, max_seq_len] tc::Tensor newTokens; // [batch_size, beam_width] // optional parameters - std::optional finished; // [batch_size * beam_width] - std::optional finished_sum; // [1] in pinned host memory - std::optional cum_log_probs; // [batch_size * beam_width], necessary in beam search - std::optional parent_ids; // [max_seq_len, batch_size * beam_width], necessary in beam search - std::optional sequence_length; // [batch_size * beam_width] - std::optional output_log_probs_tiled; // [request_output_length, batch_size, beam_width] - // must be float* - std::optional output_log_probs; // [batch_size, beam_width, request_output_length] - // must be float* - std::optional tgt_cache_indirection; // [local_batch_size, beam_width, max_seq_len] - // the k/v cache index for beam search + std::optional finished; // [batch_size * beam_width] + std::optional finished_sum; // [1] in pinned host memory + 
std::optional cum_log_probs; // [batch_size * beam_width], necessary in beam search + std::optional parent_ids; // [max_seq_len, batch_size * beam_width], necessary in beam search + std::optional sequence_length; // [batch_size * beam_width] + std::optional + output_log_probs_tiled; // [request_output_length, batch_size, beam_width], must be float* + std::optional output_log_probs; // [batch_size, beam_width, request_output_length], must be float* + std::optional + tgt_cache_indirection; // [local_batch_size, beam_width, max_seq_len], the k/v cache index for beam search std::shared_ptr beamHypotheses; // structure maintains some pointers of beam search tc::Tensor output_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len] tc::Tensor parent_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len] + + // Medusa outputs + std::optional + nextDraftTokens; // [batch_size, max_tokens_per_step], draft tokens predicted by Medusa heads + std::optional acceptedLengths; // [batch_size], lengths of the accepted draft tokens + 1 }; void forward(OutputParams& outputs, ForwardParams const& params); @@ -173,47 +191,56 @@ class DynamicDecodeLayer : public BaseLayer void initialize(); void initializeLayers(); - void setupLayers(size_t batchSize, size_t beamWidth, int32_t const* batchSlots, SetupParams const& setupParams); - void setupPenalties(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams); + void setupLayers(runtime::SizeType batchSize, runtime::SizeType beamWidth, runtime::SizeType const* batchSlots, + SetupParams const& setupParams); + void setupPenalties( + runtime::SizeType batchSize, runtime::SizeType const* batchSlots, SetupParams const& setupParams); void layersForward(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen); + runtime::SizeType const* batchSlots, runtime::SizeType batchSize, runtime::SizeType beamWidth, + runtime::SizeType maxSeqLen); - void applyPenalties(OutputParams& outputs, ForwardParams const& params, int32_t const* batchSlotsHost, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen); + void applyPenalties(OutputParams& outputs, ForwardParams const& params, runtime::SizeType const* batchSlotsHost, + runtime::SizeType const* batchSlots, runtime::SizeType batchSize, runtime::SizeType beamWidth, + runtime::SizeType maxSeqLen); - static void banWords(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded, - cudaStream_t stream); + void banWords(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params, + runtime::SizeType const* batchSlots, runtime::SizeType batchSize, runtime::SizeType beamWidth, + runtime::SizeType maxSeqLen, runtime::SizeType vocabSizePadded, cudaStream_t stream); static void banRepeatNGrams(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded, - cudaStream_t stream); + runtime::SizeType const* batchSlots, runtime::SizeType batchSize, runtime::SizeType beamWidth, + runtime::SizeType maxSeqLen, runtime::SizeType vocabSizePadded, cudaStream_t stream); static void banBadWords(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, 
size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded, - cudaStream_t stream); + runtime::SizeType const* batchSlots, runtime::SizeType batchSize, runtime::SizeType beamWidth, + runtime::SizeType maxSeqLen, runtime::SizeType vocabSizePadded, cudaStream_t stream); - static void checkStopCriteria(OutputParams& outputs, ForwardParams const& params, int32_t const* batchSlots, - size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream); + void checkStopCriteria(OutputParams& outputs, ForwardParams const& params, int32_t const* batchSlots, + runtime::SizeType batchSize, runtime::SizeType beamWidth, runtime::SizeType maxSeqLen, cudaStream_t stream); static void checkMaxLengthStopCriteria(OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream); + runtime::SizeType const* batchSlots, runtime::SizeType batchSize, runtime::SizeType beamWidth, + runtime::SizeType maxSeqLen, cudaStream_t stream); static void checkStopWordsStopCriteria(OutputParams& outputs, ForwardParams const& params, - int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream); + runtime::SizeType const* batchSlots, runtime::SizeType batchSize, runtime::SizeType beamWidth, + runtime::SizeType maxSeqLen, cudaStream_t stream); - void prepareIdsPtrs( - OutputParams& outputs, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen); + void prepareIdsPtrs(OutputParams& outputs, runtime::SizeType const* batchSlots, runtime::SizeType batchSize, + runtime::SizeType beamWidth, runtime::SizeType maxSeqLen); static void prepareOutputData(OutputParams& outputs, ForwardParams const& params, - runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, - size_t maxBatchSize, size_t beamWidth, size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream); + runtime::ITensor::SharedPtr const& idsPtrsHost, runtime::SizeType const* batchSlots, + runtime::SizeType batchSize, runtime::SizeType maxBatchSize, runtime::SizeType beamWidth, + runtime::SizeType maxSeqLen, runtime::SizeType maxTokensPerStep, runtime::SizeType cyclicStep, + cudaStream_t stream); private: std::unique_ptr> mOnlineBeamSearchDecode; std::unique_ptr> mSamplingLayer; + std::unique_ptr> mMedusaDecodingLayer; runtime::DecodingMode mDecodingMode; - size_t mMaxBatchSize; - size_t mMaxBeamWidth; - size_t mVocabSize; - size_t mVocabSizePadded; + runtime::SizeType mMaxBatchSize; + runtime::SizeType mMaxBeamWidth; + runtime::SizeType mVocabSize; + runtime::SizeType mVocabSizePadded; cudaDeviceProp* mCudaDeviceProp; @@ -248,6 +275,9 @@ class DynamicDecodeLayer : public BaseLayer int32_t mCyclicStep = 0; int32_t mRuntimeMaxSeqLen = 0; int32_t mConfiguredBeamWidth = -1; + + runtime::SizeType mMaxTokensPerStep; + runtime::SizeType mMaxNumMedusaHeads; }; } // namespace layers diff --git a/cpp/tensorrt_llm/layers/fillBuffers.h b/cpp/tensorrt_llm/layers/fillBuffers.h index 0e36d7149..1290841be 100644 --- a/cpp/tensorrt_llm/layers/fillBuffers.h +++ b/cpp/tensorrt_llm/layers/fillBuffers.h @@ -26,6 +26,7 @@ #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/kernels/decodingCommon.h" +#include "tensorrt_llm/runtime/common.h" namespace tensorrt_llm { @@ -39,7 +40,7 @@ struct FillBuffers template void operator()(std::optional> const& optParam, T const defaultValue, std::vector& hostBuffer, - T* deviceBuffer, int32_t const* 
batchSlots) const + T* deviceBuffer, runtime::SizeType const* batchSlots) const { using tensorrt_llm::common::cudaAutoCpy; @@ -71,8 +72,8 @@ struct FillBuffers } } - size_t batchSize; - size_t maxBatchSize; + runtime::SizeType batchSize; + runtime::SizeType maxBatchSize; cudaStream_t stream; }; diff --git a/cpp/tensorrt_llm/layers/medusaDecodingLayer.cpp b/cpp/tensorrt_llm/layers/medusaDecodingLayer.cpp new file mode 100644 index 000000000..52df49851 --- /dev/null +++ b/cpp/tensorrt_llm/layers/medusaDecodingLayer.cpp @@ -0,0 +1,415 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/layers/medusaDecodingLayer.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/memoryUtils.h" +#include "tensorrt_llm/kernels/decodingCommon.h" +#include "tensorrt_llm/kernels/decodingKernels.h" +#include "tensorrt_llm/kernels/samplingTopKKernels.h" +#include "tensorrt_llm/runtime/bufferManager.h" + +#include + +using namespace tensorrt_llm::common; +using namespace tensorrt_llm::kernels; +using namespace tensorrt_llm::runtime; + +namespace tensorrt_llm +{ +namespace layers +{ + +template +MedusaDecodingLayer::MedusaDecodingLayer(SizeType maxBatchSize, SizeType vocabSize, SizeType vocabSizePadded, + SizeType maxTokensPerStep, SizeType maxNumHeads, cudaStream_t stream, std::shared_ptr allocator) + : BaseLayer(stream, std::move(allocator), nullptr) + , mMaxBatchSize(maxBatchSize) + , mVocabSize(vocabSize) + , mVocabSizePadded(vocabSizePadded) + , mMaxTokensPerStep(maxTokensPerStep) + , mMaxNumHeads(maxNumHeads) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + allocateBuffer(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +MedusaDecodingLayer::~MedusaDecodingLayer() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + freeBuffer(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void MedusaDecodingLayer::allocateBuffer() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + // Get sampling workspace size + { + auto samplingSizePrimarySampling = getTopKWorkspaceSize(mMaxBatchSize, 1, TOP_K_MAX, mVocabSizePadded); + + auto const maxBatchSizeHeadNums = mMaxBatchSize * mMaxNumHeads; + auto samplingSizeMedusaHeadsSampling + = getTopKWorkspaceSize(maxBatchSizeHeadNums, 1, TOP_K_MAX, mVocabSizePadded); + + mSamplingWorkspaceSize = std::max(samplingSizePrimarySampling, samplingSizeMedusaHeadsSampling); + } + + mDraftIdsPtrHost + = runtime::BufferManager::pinned(ITensor::makeShape({static_cast(mMaxBatchSize), mMaxNumHeads}), + runtime::TRTDataType::value); + mCummulativeTopK.resize(mMaxBatchSize * mMaxNumHeads); + + std::array deviceBufferSizes; + deviceBufferSizes[0] = mMaxBatchSize * sizeof(curandState_t); + deviceBufferSizes[1] = mMaxBatchSize * sizeof(SizeType); + deviceBufferSizes[2] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType); + deviceBufferSizes[3] = 
mSamplingWorkspaceSize; + deviceBufferSizes[4] = mMaxBatchSize * sizeof(SizeType); + deviceBufferSizes[5] = mMaxBatchSize * mMaxTokensPerStep * sizeof(TokenIdType); + deviceBufferSizes[6] = mMaxBatchSize * mMaxNumHeads * sizeof(uint64_t); + deviceBufferSizes[7] = mMaxBatchSize * mMaxNumHeads * sizeof(T*); + deviceBufferSizes[8] = mMaxBatchSize * mMaxNumHeads * sizeof(curandState_t); + deviceBufferSizes[9] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType); + + mCurandStatesDevice = mAllocator->reMalloc(mCurandStatesDevice, deviceBufferSizes[0], false); + mTokensPerStepDevice = mAllocator->reMalloc(mTokensPerStepDevice, deviceBufferSizes[1], false); + mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[2], false); + mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[3], false); + mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[4], false); + mTargetTokensDevice = mAllocator->reMalloc(mTargetTokensDevice, deviceBufferSizes[5], false); + mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[6], false); + mMedusaLogitsPtrsDevice = mAllocator->reMalloc(mMedusaLogitsPtrsDevice, deviceBufferSizes[7], false); + mCurandStatesMedusaLogitsDevice + = mAllocator->reMalloc(mCurandStatesMedusaLogitsDevice, deviceBufferSizes[8], false); + mRuntimeTopKPerRequestPerMedusaHeadDevice + = mAllocator->reMalloc(mRuntimeTopKPerRequestPerMedusaHeadDevice, deviceBufferSizes[9], false); + + mTiledBatchSlotsSetup = BufferManager::pinnedPool( + ITensor::makeShape({static_cast(mMaxBatchSize * mMaxNumHeads)}), nvinfer1::DataType::kINT32); + mTiledBatchSlotsForward = BufferManager::pinnedPool( + ITensor::makeShape({static_cast(mMaxBatchSize * mMaxNumHeads)}), nvinfer1::DataType::kINT32); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void MedusaDecodingLayer::freeBuffer() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + mAllocator->free((void**) (&mCurandStatesDevice)); + mAllocator->free((void**) (&mTokensPerStepDevice)); + mAllocator->free((void**) (&mSetupWorkspaceDevice)); + mAllocator->free((void**) (&mSamplingWorkspaceDevice)); + mAllocator->free((void**) (&mRuntimeTopKDevice)); + mAllocator->free((void**) (&mTargetTokensDevice)); + mAllocator->free((void**) (&mRandomSeedsDevice)); + mAllocator->free((void**) (&mMedusaLogitsPtrsDevice)); + mAllocator->free((void**) (&mCurandStatesMedusaLogitsDevice)); + mAllocator->free((void**) (&mRuntimeTopKPerRequestPerMedusaHeadDevice)); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void MedusaDecodingLayer::setup(SizeType batchSize, SizeType const* batchSlots, MedusaSetupParams const& setupParams) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + // Prepare random seed + auto initCurandStates = [this](std::optional> const& randomSeed, SizeType batchSize, + SizeType const* batchSlots, curandState_t* statesDevice) + { + if (randomSeed) + { + if (randomSeed->size() == 1) + { + invokeCurandInitialize(statesDevice, batchSlots, batchSize, randomSeed->front(), this->mStream); + sync_check_cuda_error(); + } + else + { + TLLM_CHECK_WITH_INFO(randomSeed->size() == batchSize, "Random seed vector size mismatch."); + cudaAutoCpy(this->mRandomSeedsDevice, randomSeed->data(), batchSize, this->mStream); + invokeCurandBatchInitialize( + statesDevice, batchSlots, batchSize, this->mRandomSeedsDevice, this->mStream); + sync_check_cuda_error(); + } + } + else + { + // Initialize curand states using the default 
seed 0. + invokeCurandInitialize(statesDevice, batchSlots, batchSize, 0, this->mStream); + } + }; + + initCurandStates(setupParams.randomSeed, batchSize, batchSlots, mCurandStatesDevice); + + auto batchSizeMaxNumHeads = batchSize * mMaxNumHeads; + auto randomSeed = setupParams.randomSeed.value_or(std::vector(batchSize, uint64_t{0})); + std::vector tiledRandomSeed(batchSizeMaxNumHeads); + if (randomSeed.size() > 1) + { + for (SizeType bi = 0; bi < batchSize; ++bi) + { + for (SizeType hi = 0; hi < mMaxNumHeads; ++hi) + { + tiledRandomSeed[bi * mMaxNumHeads + hi] = randomSeed[bi]; + } + } + } + auto tiledBatchSlots = bufferCast(*mTiledBatchSlotsSetup); + for (SizeType bi = 0; bi < batchSize; ++bi) + { + for (SizeType hi = 0; hi < mMaxNumHeads; ++hi) + { + tiledBatchSlots[bi * mMaxNumHeads + hi] = batchSlots[bi] + hi; + } + } + initCurandStates({tiledRandomSeed}, batchSizeMaxNumHeads, tiledBatchSlots, mCurandStatesMedusaLogitsDevice); + + // Prepare tokens per step + { + auto tokensPerStep = setupParams.tokensPerStep.value_or(std::vector{batchSize, mMaxTokensPerStep}); + TLLM_CHECK_WITH_INFO(tokensPerStep.size() == batchSize, + fmtstr("tokensPerStep.size() (%lu) == batchSize (%d) is not satisfied!", tokensPerStep.size(), batchSize)); + + cudaAutoCpy(reinterpret_cast(mSetupWorkspaceDevice), tokensPerStep.data(), batchSize, mStream); + invokeScatterDecodingParams( + reinterpret_cast(mSetupWorkspaceDevice), mTokensPerStepDevice, batchSlots, batchSize, mStream); + } + + // Prepare runtime top K + auto prepareRuntimeTopK = [this](std::vector const& runtimeTopK, SizeType batchSize, + SizeType const* batchSlots, SizeType* runtimeTopKDevice) + { + TLLM_CHECK_WITH_INFO(runtimeTopK.size() == batchSize, + fmtstr("runtimeTopK.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopK.size(), batchSize)); + + cudaAutoCpy( + reinterpret_cast(this->mSetupWorkspaceDevice), runtimeTopK.data(), batchSize, this->mStream); + invokeScatterDecodingParams(reinterpret_cast(this->mSetupWorkspaceDevice), runtimeTopKDevice, + batchSlots, batchSize, this->mStream); + + // FIXME(nkorobov): monotonically growing + auto const curMaxTopK = *std::max_element(std::begin(runtimeTopK), std::end(runtimeTopK)); + return curMaxTopK; + }; + + auto constexpr defaultTopK = 1u; + { + auto runtimeTopK = setupParams.runtimeTopK.value_or(std::vector{batchSize, defaultTopK}); + auto const curMaxTopK = prepareRuntimeTopK(runtimeTopK, batchSize, batchSlots, mRuntimeTopKDevice); + mRuntimeMaxTopK = std::max(mRuntimeMaxTopK, curMaxTopK); + } + { + auto runtimeHeadsTopK = setupParams.runtimeHeadsTopK; + std::vector runtimeHeadsTopKFlatten; + if (runtimeHeadsTopK.has_value()) + { + for (auto const& sub : runtimeHeadsTopK.value()) + { + runtimeHeadsTopKFlatten.insert(runtimeHeadsTopKFlatten.end(), sub.begin(), sub.end()); + } + } + else + { + runtimeHeadsTopKFlatten = std::vector(batchSizeMaxNumHeads, defaultTopK); + } + + for (SizeType bi = 0; bi < batchSize; ++bi) + { + auto const slot = batchSlots[bi]; + SizeType cummulativeTopK = 0; + for (SizeType hi = 0; hi < mMaxNumHeads; ++hi) + { + mCummulativeTopK[slot * mMaxNumHeads + hi] = cummulativeTopK; + cummulativeTopK += runtimeHeadsTopKFlatten[bi * mMaxNumHeads + hi]; + } + } + + auto tiledBatchSlots = bufferCast(*mTiledBatchSlotsSetup); + for (SizeType bi = 0; bi < batchSize; ++bi) + { + for (SizeType hi = 0; hi < mMaxNumHeads; ++hi) + { + tiledBatchSlots[bi * mMaxNumHeads + hi] = mMaxNumHeads * batchSlots[bi] + hi; + } + } + + auto const curMaxTopK = 
prepareRuntimeTopK(runtimeHeadsTopKFlatten, static_cast(batchSizeMaxNumHeads), + tiledBatchSlots, mRuntimeTopKPerRequestPerMedusaHeadDevice); + mRuntimeMaxTopKPerRequestPerMedusaHead = std::max(mRuntimeMaxTopKPerRequestPerMedusaHead, curMaxTopK); + } + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void MedusaDecodingLayer::forward(DecodingOutputParams& outputs, MedusaForwardParams& inputs) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + samplePrimeHeadTokens(outputs, inputs); + + acceptDraftTokens(outputs, inputs); + + sampleNewDraftTokens(outputs, inputs); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void MedusaDecodingLayer::samplePrimeHeadTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto const batchSize = inputs.logits.shape[0]; + + auto logits = inputs.logits.template getPtr(); + auto batchSlots + = inputs.batch_slots ? inputs.batch_slots->template getPtr() : static_cast(nullptr); + auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr() + : static_cast(nullptr); + + TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding"); + TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding"); + + // Sample multiple tokens per request and store them to separate to be accepted/rejected later + // Sequence length is not modified, endIds is not checked, outputLogProbs are not supported. + // Finished state is not set. + invokeBatchTopKSampling(mSamplingWorkspaceDevice, logits, /* logProbsPtrs */ static_cast(nullptr), + /* outputIdsPtrs */ nullptr, mTargetTokensDevice, sequenceLengths, + /* finishedInput */ nullptr, /* finishedOutput */ nullptr, + /* cumLogProbs */ nullptr, /* outputLogProbs */ nullptr, mCurandStatesDevice, mRuntimeMaxTopK, + mRuntimeTopKDevice, 1.0f, /* runtimeTopPDevice */ nullptr, mVocabSizePadded, /* endIds */ nullptr, batchSlots, + mStream, batchSize, mMaxBatchSize, mTokensPerStepDevice, mMaxTokensPerStep, mMaxTokensPerStep, + /* skipDecode */ nullptr, /* normalizeLogProbs */ false, + /* probsComputed */ false, /* return all Top-K*/ false); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void MedusaDecodingLayer::acceptDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto const batchSize = inputs.logits.shape[0]; + auto const maxSeqLen = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1]; + + auto outputIds = outputs.output_ids.template getPtr(); + auto endIds = inputs.end_ids.template getPtr(); + auto paths = inputs.paths.template getPtr(); + auto medusaLogits = inputs.medusaLogits.template getPtr(); + + auto batchSlots + = inputs.batch_slots ? inputs.batch_slots->template getPtr() : static_cast(nullptr); + auto sequenceLengths = outputs.sequence_length ? outputs.sequence_length->template getPtr() + : static_cast(nullptr); + auto acceptedLengths = outputs.acceptedLengths ? 
outputs.acceptedLengths->template getPtr() + : static_cast(nullptr); + + TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding"); + TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding"); + TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for MedusaDecoding"); + + auto finishedStates + = reinterpret_cast(outputs.finished->template getPtr()); + + // Compare draft tokens from outputIds with sampled target tokens at mTargetTokensDevice using paths. + // Select the longest accepted path, modify outputIds in-place, increment sequenceLengths accordingly. + // Fill mMedusaLogitsPtrsDevice with respective Medusa logits + acceptDraftTokensByIdsWithPaths(outputIds, mTargetTokensDevice, sequenceLengths, acceptedLengths, finishedStates, + batchSlots, paths, endIds, medusaLogits, const_cast(mMedusaLogitsPtrsDevice), batchSize, mVocabSize, + mMaxBatchSize, maxSeqLen, mMaxTokensPerStep, mMaxNumHeads, mMaxTokensPerStep, mStream); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void MedusaDecodingLayer::sampleNewDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto const batchSize = inputs.logits.shape[0]; + auto batchSlots + = inputs.batch_slots ? inputs.batch_slots->template getPtr() : static_cast(nullptr); + auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr() + : static_cast(nullptr); + + TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding"); + TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding"); + + // For each request we sample Head Num times for topK[hi] tokens + auto const batchSizeHeadNums = batchSize * mMaxNumHeads; + auto const maxBatchSizeHeadNums = mMaxBatchSize * mMaxNumHeads; + + auto tiledBatchSlots = bufferCast(*mTiledBatchSlotsForward); + for (SizeType bi = 0; bi < batchSize; ++bi) + { + for (SizeType hi = 0; hi < mMaxNumHeads; ++hi) + { + tiledBatchSlots[bi * mMaxNumHeads + hi] = mMaxNumHeads * batchSlots[bi] + hi; + } + } + + auto draftIdsPtrs = reinterpret_cast(bufferCast(*mDraftIdsPtrHost)); + auto draftIds = (outputs.nextDraftTokens) ? 
outputs.nextDraftTokens->template getPtr() + : static_cast(nullptr); + TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding"); + + for (SizeType bi = 0; bi < batchSize; ++bi) + { + auto slot = batchSlots[bi]; + for (SizeType hi = 0; hi < mMaxNumHeads; ++hi) + { + draftIdsPtrs[slot * mMaxNumHeads + hi] + = draftIds + slot * mMaxTokensPerStep + mCummulativeTopK[slot * mMaxNumHeads + hi]; + } + } + + invokeBatchTopKSampling(mSamplingWorkspaceDevice, + /* logits */ static_cast(nullptr), const_cast(mMedusaLogitsPtrsDevice), draftIdsPtrs, + /* outputIds */ nullptr, /* sequenceLength */ nullptr, + /* finishedInput */ nullptr, /* finishedOutput */ nullptr, + /* cumLogProbs */ nullptr, /* outputLogProbs */ nullptr, mCurandStatesMedusaLogitsDevice, + mRuntimeMaxTopKPerRequestPerMedusaHead, mRuntimeTopKPerRequestPerMedusaHeadDevice, 1.0f, + /* runtimeTopPDevice */ nullptr, mVocabSizePadded, /* endIds */ nullptr, tiledBatchSlots, mStream, + batchSizeHeadNums, maxBatchSizeHeadNums, + /* tokensPerStep */ nullptr, /* maxTokensPerStep */ 1, + /* maxSeqLen (not required as outputIds is nullptr) */ 0, + /* skipDecode */ nullptr, /* normalizeLogProbs */ false, + /* probsComputed */ false, /* return all Top-K*/ true); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template class MedusaDecodingLayer; +template class MedusaDecodingLayer; + +} // namespace layers +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/layers/medusaDecodingLayer.h b/cpp/tensorrt_llm/layers/medusaDecodingLayer.h new file mode 100644 index 000000000..cde062993 --- /dev/null +++ b/cpp/tensorrt_llm/layers/medusaDecodingLayer.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "tensorrt_llm/common/tensor.h" +#include "tensorrt_llm/layers/baseLayer.h" +#include "tensorrt_llm/layers/decodingParams.h" +#include "tensorrt_llm/runtime/common.h" +#include "tensorrt_llm/runtime/decodingMode.h" +#include "tensorrt_llm/runtime/iTensor.h" + +namespace tc = tensorrt_llm::common; + +namespace tensorrt_llm +{ +namespace layers +{ + +//! 
\brief +template +class MedusaDecodingLayer : public BaseLayer +{ +public: + using Base = BaseLayer; + using PathsVec = std::vector>>; + + class MedusaSetupParams : public DecodingSetupParams + { + public: + std::optional> runtimeTopK; // [1] or [batchSize] on cpu + std::optional>> + runtimeHeadsTopK; // [batchSize, maxMedusaHeads] on cpu + std::optional> randomSeed; // [1] or [batchSize] on cpu + std::optional> tokensPerStep; // [1] or [batchSize] on cpu + }; + + class MedusaForwardParams : public DecodingParams + { + public: + MedusaForwardParams(tc::Tensor logits, tc::Tensor endIds) + : DecodingParams{0, 0, std::move(logits), std::move(endIds)} + { + } + + tc::Tensor paths; // [maxBatchSize, maxTokensPerStep, maxNumHeads + 1] on gpu + tc::Tensor medusaLogits; // [maxNumHeads, maxBatchSize, maxTokensPerStep, vocabSize] on gpu + }; + + MedusaDecodingLayer(runtime::SizeType maxBatchSize, runtime::SizeType vocabSize, runtime::SizeType vocabSizePadded, + runtime::SizeType maxTokensPerStep, runtime::SizeType maxNumHeads, cudaStream_t stream, + std::shared_ptr allocator); + + ~MedusaDecodingLayer() override; + + void setup(runtime::SizeType batchSize, runtime::SizeType const* batchSlots, MedusaSetupParams const& setupParams); + + void forward(DecodingOutputParams& outputs, MedusaForwardParams& inputs); + +private: + void allocateBuffer(); + void freeBuffer(); + + void samplePrimeHeadTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs); + void acceptDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs); + void sampleNewDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs); + +private: + using Base::mStream; + using Base::mAllocator; + + runtime::SizeType mMaxBatchSize; + runtime::SizeType mVocabSize; + runtime::SizeType mVocabSizePadded; + + runtime::SizeType mMaxTokensPerStep; + runtime::SizeType mMaxNumHeads; + + size_t mSamplingWorkspaceSize; + runtime::SizeType mRuntimeMaxTopK{0}; + runtime::SizeType mRuntimeMaxTopKPerRequestPerMedusaHead{0}; + + curandState_t* mCurandStatesDevice{nullptr}; + runtime::SizeType* mTokensPerStepDevice{nullptr}; + void* mSetupWorkspaceDevice{nullptr}; + void* mSamplingWorkspaceDevice{nullptr}; + runtime::SizeType* mRuntimeTopKDevice{nullptr}; + runtime::TokenIdType* mTargetTokensDevice{nullptr}; + uint64_t* mRandomSeedsDevice{nullptr}; + T** mMedusaLogitsPtrsDevice{nullptr}; + curandState_t* mCurandStatesMedusaLogitsDevice{nullptr}; + runtime::SizeType* mRuntimeTopKPerRequestPerMedusaHeadDevice{nullptr}; + + runtime::ITensor::UniquePtr mTiledBatchSlotsSetup; + runtime::ITensor::UniquePtr mTiledBatchSlotsForward; + runtime::ITensor::UniquePtr mDraftIdsPtrHost; + + std::vector mCummulativeTopK; +}; + +} // namespace layers +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.cu b/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.cu index 0cabf6e15..6d718852d 100644 --- a/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.cu +++ b/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.cu @@ -31,73 +31,7 @@ static int const SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS = 128; static int const MAX_K = 4; template -__global__ void update_kernel(BeamHypotheses beam_hyps) -{ - int const beam_width{beam_hyps.beam_width}; - int const ite{beam_hyps.ite}; - int const local_batch_size{beam_hyps.local_batch_size}; - int const max_seq_len{beam_hyps.max_seq_len}; - int const vocab_size{beam_hyps.vocab_size}; - int const end_id{beam_hyps.end_ids[blockIdx.x]}; - int* num_beams{beam_hyps.num_beams}; - int* 
sequence_lengths{beam_hyps.sequence_lengths_src}; - int** output_ids_ptr{beam_hyps.output_ids_tgt_ptr}; - int** parent_ids_ptr{beam_hyps.parent_ids_tgt_ptr}; - FinishedState* finished{beam_hyps.finished}; - - extern __shared__ char s_buf[]; // intermediate result - int* s_sequence_lengths = reinterpret_cast(s_buf); - - for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x) - { - auto const batch_beam_idx = blockIdx.x * beam_width + beam_idx; - s_sequence_lengths[beam_idx] = sequence_lengths[batch_beam_idx]; - } - __syncthreads(); - - for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x) - { - auto const batch_beam_idx = blockIdx.x * beam_width + beam_idx; - int const current_step{s_sequence_lengths[beam_idx]}; - - // Increase the seq_len even if the request has finished. - // On the following iteration we check if the sequence has finished before - auto const finish_state = finished[batch_beam_idx]; - if (!finish_state.isFinished()) - { - s_sequence_lengths[beam_idx]++; - } - - int new_word_id{output_ids_ptr[blockIdx.x][beam_idx * max_seq_len + current_step]}; - int new_beam_id{(new_word_id / vocab_size) % beam_width}; - new_word_id = new_word_id % vocab_size; - - sequence_lengths[batch_beam_idx] = s_sequence_lengths[new_beam_id]; - if (new_word_id == end_id) - { - finished[batch_beam_idx].setFinishedEOS(); - } - parent_ids_ptr[blockIdx.x][beam_idx * max_seq_len + current_step] = new_beam_id; - output_ids_ptr[blockIdx.x][beam_idx * max_seq_len + current_step] = new_word_id; - } - if (num_beams != nullptr && num_beams[ite * local_batch_size + blockIdx.x] == beam_width) - { - for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x) - { - finished[blockIdx.x * beam_width + beam_idx].setFinished(); - } - } -} - -void invokeUpdate(BeamHypotheses& beam_hyps, cudaStream_t stream) -{ - dim3 grid(beam_hyps.local_batch_size); - dim3 block(min(beam_hyps.beam_width, 1024)); - update_kernel<<>>(beam_hyps); -} - -template -void OnlineBeamSearchLayer::setup(size_t batch_size, SetupParams const& setupParams) +void OnlineBeamSearchLayer::setup(runtime::SizeType batch_size, SetupParams const& setupParams) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); BaseBeamSearchLayer::setupBase(batch_size, setupParams); @@ -118,45 +52,36 @@ template void OnlineBeamSearchLayer::invokeSoftMax(BeamSearchOutputParams& outputs, SoftmaxParams const& params) { TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__); - - BeamHypotheses beam_hyps; - if (outputs.beamHypotheses) - { - beam_hyps = *outputs.beamHypotheses; - beam_hyps.end_ids = params.end_ids.template getPtr(); - beam_hyps.finished - = reinterpret_cast(outputs.finished->template getPtr()); - beam_hyps.cum_log_probs_src = outputs.cum_log_probs->template getPtr(); - beam_hyps.log_probs_src - = (outputs.output_log_probs) ? 
outputs.output_log_probs->template getPtr() : nullptr; - beam_hyps.sequence_lengths_src = outputs.sequence_length->template getPtr(); - beam_hyps.output_ids_tgt_ptr = outputs.output_ids_ptr.template getPtr(); - beam_hyps.parent_ids_tgt_ptr = outputs.parent_ids_ptr.template getPtr(); - - beam_hyps.diversity_rates = diversity_rates_buf_; - beam_hyps.length_penalties = length_penalties_buf_; - beam_hyps.early_stoppings = early_stoppings_buf_; - - beam_hyps.batch_size = static_cast(outputs.output_ids_ptr.shape[0]); - beam_hyps.beam_width = static_cast(outputs.output_ids_ptr.shape[1]); - beam_hyps.ite = params.ite; - beam_hyps.local_batch_size = params.logits.shape[0]; - beam_hyps.max_seq_len = static_cast(outputs.output_ids_ptr.shape[2]); - beam_hyps.vocab_size = vocab_size_padded_; - } + TLLM_CHECK_WITH_INFO(outputs.beamHypotheses, std::string("Output BeamHypotheses is not set")); + + BeamHypotheses bh{*outputs.beamHypotheses}; + bh.end_ids = params.end_ids.template getPtr(); + bh.finished = reinterpret_cast(outputs.finished->template getPtr()); + bh.cum_log_probs_src = outputs.cum_log_probs->template getPtr(); + bh.log_probs_src = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr() : nullptr; + bh.sequence_lengths_src = outputs.sequence_length->template getPtr(); + bh.output_ids_tgt_ptr = outputs.output_ids_ptr.template getPtr(); + bh.parent_ids_tgt_ptr = outputs.parent_ids_ptr.template getPtr(); + bh.diversity_rates = diversity_rates_buf_; + bh.length_penalties = length_penalties_buf_; + bh.early_stoppings = early_stoppings_buf_; + + bh.batch_size = static_cast(outputs.output_ids_ptr.shape[0]); + bh.beam_width = static_cast(outputs.output_ids_ptr.shape[1]); + bh.ite = params.ite; + bh.local_batch_size = params.logits.shape[0]; + bh.max_seq_len = static_cast(outputs.output_ids_ptr.shape[2]); + bh.vocab_size = vocab_size_padded_; T const* logits = params.logits.template getPtr(); T const* bias = static_cast(nullptr); - invokeTopkSoftMax(logits, bias, topk_softmax_workspace_, topk_softmax_workspace_size_, beam_hyps, mStream); - sync_check_cuda_error(); - - invokeUpdate(beam_hyps, mStream); + invokeTopkSoftMax(logits, bias, topk_softmax_workspace_, topk_softmax_workspace_size_, bh, mStream); sync_check_cuda_error(); } template -void OnlineBeamSearchLayer::allocateBuffer(size_t batch_size) +void OnlineBeamSearchLayer::allocateBuffer(runtime::SizeType batch_size) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); // we need to check 2 * beam_width candidates each time @@ -190,8 +115,8 @@ void OnlineBeamSearchLayer::freeBuffer() } template -OnlineBeamSearchLayer::OnlineBeamSearchLayer( - size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, std::shared_ptr allocator) +OnlineBeamSearchLayer::OnlineBeamSearchLayer(runtime::SizeType vocab_size, runtime::SizeType vocab_size_padded, + cudaStream_t stream, std::shared_ptr allocator) : BaseBeamSearchLayer(vocab_size, vocab_size_padded, stream, std::move(allocator)) { } diff --git a/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.h b/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.h index 3e040d3cf..a0745071a 100644 --- a/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.h +++ b/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.h @@ -44,14 +44,14 @@ class OnlineBeamSearchLayer : public BaseBeamSearchLayer std::optional> early_stopping; // [1] or [batch_size] on cpu }; - OnlineBeamSearchLayer( - size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, std::shared_ptr allocator); + OnlineBeamSearchLayer(runtime::SizeType 
vocab_size, runtime::SizeType vocab_size_padded, cudaStream_t stream, + std::shared_ptr allocator); OnlineBeamSearchLayer(OnlineBeamSearchLayer const& beam_search_layer); ~OnlineBeamSearchLayer() override; - void setup(size_t batch_size, SetupParams const& setupParams); + void setup(runtime::SizeType batch_size, SetupParams const& setupParams); protected: // meta data @@ -78,7 +78,7 @@ class OnlineBeamSearchLayer : public BaseBeamSearchLayer int* early_stoppings_buf_; private: - void allocateBuffer(size_t batch_size); + void allocateBuffer(runtime::SizeType batch_size); void freeBuffer(); }; diff --git a/cpp/tensorrt_llm/layers/samplingLayer.cpp b/cpp/tensorrt_llm/layers/samplingLayer.cpp index db68ba20c..311df89b9 100644 --- a/cpp/tensorrt_llm/layers/samplingLayer.cpp +++ b/cpp/tensorrt_llm/layers/samplingLayer.cpp @@ -32,9 +32,36 @@ namespace tensorrt_llm namespace layers { template -void SamplingLayer::allocateBuffer(size_t batchSize) +SamplingLayer::SamplingLayer(DecodingMode const& mode, SizeType maxBatchSize, SizeType vocabSize, + SizeType vocabSizePadded, cudaStream_t stream, std::shared_ptr allocator, cudaDeviceProp* prop) + : BaseSamplingLayer(maxBatchSize, vocabSize, vocabSizePadded, stream, std::move(allocator), nullptr) + , mDecodingMode(mode) { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + TLLM_CHECK_WITH_INFO(!mDecodingMode.isBeamSearch(), "SamplingLayer does not support Beam search mode"); + TLLM_CHECK_WITH_INFO(mDecodingMode.isTopKorTopP(), "SamplingLayer requires TopK or TopP mode"); + if (mDecodingMode.isTopK()) + { + mTopKDecode + = std::make_unique>(maxBatchSize, vocabSize, vocabSizePadded, mStream, mAllocator); + } + + if (mDecodingMode.isTopP()) + { + mTopPDecode = std::make_unique>( + maxBatchSize, vocabSize, vocabSizePadded, mStream, mAllocator, prop, /* deterministic */ true); + } + + allocateBuffer(maxBatchSize); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void SamplingLayer::allocateBuffer(SizeType batchSize) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); mSamplingWorkspaceSize = 0; if (mDecodingMode.isTopK()) @@ -73,46 +100,28 @@ void SamplingLayer::allocateBuffer(size_t batchSize) // host buffers.
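// mSkipDecodeHost is a plain host-side bool array with one flag per batch slot; it is resized with std::realloc
// here and released with std::free in freeBuffer(), unlike the *Device buffers, which go through mAllocator.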
mSkipDecodeHost = (bool*) std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize); TLLM_CHECK(mSkipDecodeHost != nullptr); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void SamplingLayer::freeBuffer() { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + mAllocator->free((void**) (&mCurandStatesDevice)); mAllocator->free((void**) (&mRandomSeedsDevice)); mAllocator->free((void**) (&mSkipDecodeDevice)); mAllocator->free((void**) (&mSamplingWorkspaceDevice)); std::free(mSkipDecodeHost); -} -template -SamplingLayer::SamplingLayer(DecodingMode const& mode, size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, - cudaStream_t stream, std::shared_ptr allocator, cudaDeviceProp* prop) - : BaseSamplingLayer(maxBatchSize, vocabSize, vocabSizePadded, stream, std::move(allocator), nullptr) - , mDecodingMode(mode) -{ - TLLM_CHECK_WITH_INFO(!mDecodingMode.isBeamSearch(), "Beam search mode has been requested from Sampling Layer"); - TLLM_CHECK_WITH_INFO(mDecodingMode.isTopKorTopP(), "Requested mode is neither TopK nor TopP"); - if (mDecodingMode.isTopK()) - { - mTopKDecode - = std::make_unique>(maxBatchSize, vocabSize, vocabSizePadded, mStream, mAllocator); - } - - if (mDecodingMode.isTopP()) - { - mTopPDecode = std::make_unique>( - maxBatchSize, vocabSize, vocabSizePadded, mStream, mAllocator, prop, /* deterministic */ true); - } - - allocateBuffer(maxBatchSize); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template -void SamplingLayer::setup(const size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) +void SamplingLayer::setup(SizeType batchSize, SizeType const* batchSlots, SetupParams const& setupParams) { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); // If runtime argument has single random seed, using this random seed to // initialize the random table of all sentences. If the argument has @@ -149,6 +158,8 @@ void SamplingLayer::setup(const size_t batchSize, int32_t const* batchSlots, { mTopPDecode->setup(batchSize, batchSlots, setupParams); } + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template diff --git a/cpp/tensorrt_llm/layers/samplingLayer.h b/cpp/tensorrt_llm/layers/samplingLayer.h index e9829d306..3c934aae2 100644 --- a/cpp/tensorrt_llm/layers/samplingLayer.h +++ b/cpp/tensorrt_llm/layers/samplingLayer.h @@ -34,9 +34,11 @@ namespace layers { template -inline bool allOfBatchSlots(int32_t const* batchSlotsHost, T const* data, size_t batchSize, T value) +inline bool allOfBatchSlots( + runtime::SizeType const* batchSlotsHost, T const* data, runtime::SizeType batchSize, T value) { - return std::all_of(batchSlotsHost, batchSlotsHost + batchSize, [&](int32_t b) { return data[b] == value; }); + return std::all_of( + batchSlotsHost, batchSlotsHost + batchSize, [&](runtime::SizeType b) { return data[b] == value; }); }; //! \brief Top class for sampling layers. 
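//! It owns a TopKSamplingLayer and/or a TopPSamplingLayer, chosen by the DecodingMode passed to the constructor,
//! and forwards batched sampling requests to them. A rough usage sketch follows; the variable names (mode, stream,
//! allocator, prop, batchSlots, setupParams, decodeOutputs, decodeInputs) are placeholders assumed to be prepared
//! by the caller and are not part of this diff:
//! \code
//!   SamplingLayer<float> sampling(mode, maxBatchSize, vocabSize, vocabSizePadded, stream, allocator, prop);
//!   sampling.setup(batchSize, batchSlots, setupParams); // per-request top-k / top-p values, random seeds, ...
//!   sampling.forward(decodeOutputs, decodeInputs);      // samples the next token for every active batch slot
//!   // decodeOutputs / decodeInputs are the DecodingOutputParams / ForwardParams used by the class below
//! \endcode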
@@ -49,14 +51,16 @@ class SamplingLayer : public BaseSamplingLayer using SetupParams = typename Base::SetupParams; using ForwardParams = typename Base::ForwardParams; - SamplingLayer(runtime::DecodingMode const& mode, size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, - cudaStream_t stream, std::shared_ptr allocator, cudaDeviceProp* prop); + SamplingLayer(runtime::DecodingMode const& mode, runtime::SizeType maxBatchSize, runtime::SizeType vocabSize, + runtime::SizeType vocabSizePadded, cudaStream_t stream, + std::shared_ptr allocator, cudaDeviceProp* prop); ~SamplingLayer() override = default; void forward(DecodingOutputParams& outputs, ForwardParams& inputs) override; - void setup(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) override; + void setup( + runtime::SizeType batchSize, runtime::SizeType const* batchSlots, SetupParams const& setupParams) override; private: using Base::mMaxBatchSize; @@ -83,7 +87,7 @@ class SamplingLayer : public BaseSamplingLayer std::unique_ptr> mTopPDecode; private: - void allocateBuffer(size_t batchSize); + void allocateBuffer(runtime::SizeType batchSize); void freeBuffer(); }; diff --git a/cpp/tensorrt_llm/layers/topKSamplingLayer.cu b/cpp/tensorrt_llm/layers/topKSamplingLayer.cu index bcbab8366..fe01c10b6 100644 --- a/cpp/tensorrt_llm/layers/topKSamplingLayer.cu +++ b/cpp/tensorrt_llm/layers/topKSamplingLayer.cu @@ -35,16 +35,16 @@ namespace tensorrt_llm namespace layers { -template -__global__ void setupTopKRuntimeArgs(int batchSize, uint32_t topK, uint32_t* topKs, int topKsSize, float topP, - float* topPs, int topPsSize, bool* skipDecode, int const* batchSlots) +template +__global__ void setupTopKRuntimeArgs(SizeType batchSize, SizeType topK, SizeType* topKs, SizeType topKsSize, float topP, + float* topPs, SizeType topPsSize, bool* skipDecode, SizeType const* batchSlots) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - for (int bi = index; bi < batchSize; bi += gridDim.x * blockDim.x) + auto const index = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + for (auto bi = index; bi < batchSize; bi += static_cast(gridDim.x * blockDim.x)) { auto const batchSlot = batchSlots != nullptr ? batchSlots[bi] : bi; - uint32_t k = topKsSize > 1 ? topKs[batchSlot] : topK; - float p = topPsSize > 1 ? topPs[batchSlot] : topP; + auto k = topKsSize > 1 ? topKs[batchSlot] : topK; + auto p = topPsSize > 1 ? 
topPs[batchSlot] : topP; if (k == 0 && p == 0.0f) { // TensorRT-LLM's topp implementation does not support topp = 0.0f, but it @@ -70,13 +70,36 @@ __global__ void setupTopKRuntimeArgs(int batchSize, uint32_t topK, uint32_t* top } template -void TopKSamplingLayer::allocateBuffer(size_t const batchSize) +TopKSamplingLayer::TopKSamplingLayer(SizeType maxBatchSize, SizeType vocabSize, SizeType vocabSizePadded, + cudaStream_t stream, std::shared_ptr allocator) + : BaseSamplingLayer(maxBatchSize, vocabSize, vocabSizePadded, stream, std::move(allocator), nullptr) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + allocateBuffer(mMaxBatchSize); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +TopKSamplingLayer::~TopKSamplingLayer() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + freeBuffer(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void TopKSamplingLayer::allocateBuffer(SizeType const batchSize) { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + mSamplingWorkspaceSize = getTopKWorkspaceSize(batchSize, 1, TOP_K_MAX, mVocabSizePadded); std::array deviceBufferSizes; - deviceBufferSizes[0] = sizeof(uint32_t) * batchSize; + deviceBufferSizes[0] = sizeof(SizeType) * batchSize; deviceBufferSizes[1] = sizeof(float) * batchSize; deviceBufferSizes[2] = sizeof(bool) * batchSize; deviceBufferSizes[3] = std::max(deviceBufferSizes[0], deviceBufferSizes[1]); @@ -86,34 +109,39 @@ void TopKSamplingLayer::allocateBuffer(size_t const batchSize) mSkipDecodeDevice = mAllocator->reMalloc(mSkipDecodeDevice, deviceBufferSizes[2], false); mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[3], false); - mSkipDecodeHost = (bool*) std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize); + mSkipDecodeHost = static_cast(std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize)); mAllocatedSize = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), 0); TLLM_LOG_DEBUG("topKSamplingLayer allocated %lu bytes on GPU", mAllocatedSize); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void TopKSamplingLayer::freeBuffer() { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + mAllocator->free((void**) (&mRuntimeTopKDevice)); mAllocator->free((void**) (&mRuntimeTopPDevice)); mAllocator->free((void**) (&mSkipDecodeDevice)); mAllocator->free((void**) (&mSetupWorkspaceDevice)); std::free(mSkipDecodeHost); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template -void TopKSamplingLayer::setup(size_t const batchSize, int32_t const* batchSlots, SetupParams const& setupParams) +void TopKSamplingLayer::setup(SizeType const batchSize, SizeType const* batchSlots, SetupParams const& setupParams) { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - uint32_t constexpr defaultTopK = 0; - auto runtimeTopK = setupParams.runtime_top_k.value_or(std::vector{defaultTopK}); + SizeType constexpr defaultTopK = 0; + auto runtimeTopK = setupParams.runtime_top_k.value_or(std::vector{defaultTopK}); auto runtimeTopP = setupParams.runtime_top_p.value_or(std::vector{}); - size_t const runtimeTopKSize = runtimeTopK.size(); - size_t const runtimeTopPSize = runtimeTopP.size(); + auto const runtimeTopKSize = runtimeTopK.size(); + auto const runtimeTopPSize = runtimeTopP.size(); mNormalizeLogProbs = setupParams.normalize_log_probs.has_value() && setupParams.normalize_log_probs.value(); for (auto& topP : runtimeTopP) @@ 
-134,42 +162,43 @@ void TopKSamplingLayer::setup(size_t const batchSize, int32_t const* batchSlo } } - uint32_t const topK = *std::max_element(std::begin(runtimeTopK), std::end(runtimeTopK)); - float const topP = (runtimeTopPSize == 0) ? 0.0f : runtimeTopP.front(); + auto const topK = *std::max_element(std::begin(runtimeTopK), std::end(runtimeTopK)); + auto const topP = (runtimeTopPSize == 0) ? 0.0f : runtimeTopP.front(); if (runtimeTopKSize > 1) { TLLM_CHECK_WITH_INFO(runtimeTopK.size() == batchSize, - fmtstr("runtimeTopK.size() (%lu) == batchSize (%lu) is not satisfied!", runtimeTopK.size(), batchSize)); - cudaAutoCpy(reinterpret_cast(mSetupWorkspaceDevice), runtimeTopK.data(), batchSize, mStream); - invokeScatterDecodingParams( - reinterpret_cast(mSetupWorkspaceDevice), mRuntimeTopKDevice, batchSlots, batchSize, mStream); + fmtstr("runtimeTopK.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopK.size(), batchSize)); + cudaAutoCpy( + reinterpret_cast(mSetupWorkspaceDevice), runtimeTopK.data(), batchSize, mStream); + invokeScatterDecodingParams(reinterpret_cast(mSetupWorkspaceDevice), mRuntimeTopKDevice, + batchSlots, batchSize, mStream); } if (runtimeTopPSize > 1) { TLLM_CHECK_WITH_INFO(runtimeTopP.size() == batchSize, - fmtstr("runtimeTopP.size() (%lu) == batchSize (%lu) is not satisfied!", runtimeTopP.size(), batchSize)); + fmtstr("runtimeTopP.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopP.size(), batchSize)); cudaAutoCpy(reinterpret_cast(mSetupWorkspaceDevice), runtimeTopP.data(), batchSize, mStream); invokeScatterDecodingParams( reinterpret_cast(mSetupWorkspaceDevice), mRuntimeTopPDevice, batchSlots, batchSize, mStream); } { - dim3 block(std::min((int) batchSize, 256)); - dim3 grid(divUp((int) batchSize, (int) block.x)); + dim3 block(std::min(static_cast(batchSize), 256u)); + dim3 grid(divUp(static_cast(batchSize), block.x)); // support topK up to TOP_K_MAX. setupTopKRuntimeArgs<<>>(batchSize, topK, mRuntimeTopKDevice, runtimeTopKSize, topP, mRuntimeTopPDevice, runtimeTopPSize, mSkipDecodeDevice, batchSlots); } cudaAutoCpy(mSkipDecodeHost, mSkipDecodeDevice, mMaxBatchSize, mStream); - std::vector runtimeTopKs(mMaxBatchSize); + std::vector runtimeTopKs(mMaxBatchSize); cudaAutoCpy(runtimeTopKs.data(), mRuntimeTopKDevice, mMaxBatchSize, mStream); { - uint32_t maxTopK = 0; - for (size_t bi = 0; bi < batchSize; ++bi) + runtime::SizeType maxTopK = 0; + for (SizeType bi = 0; bi < static_cast(batchSize); ++bi) { - uint32_t bid = bi; + auto bid = bi; if (batchSlots) { bid = batchSlots[bi]; @@ -178,6 +207,8 @@ void TopKSamplingLayer::setup(size_t const batchSize, int32_t const* batchSlo } mRuntimeMaxTopK = std::max(mRuntimeMaxTopK, maxTopK); } + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template @@ -188,8 +219,8 @@ void TopKSamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& auto const batchSize = inputs.logits.shape[0]; auto logits = inputs.logits.template getPtr(); - auto endIds = inputs.end_ids.template getPtr(); - auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr() : nullptr; + auto endIds = inputs.end_ids.template getPtr(); + auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr() : nullptr; auto curandStatesDevice = inputs.curand_states; auto samplingWorkspaceDevice = inputs.sampling_workspace; auto const probsComputed = inputs.probs_computed; @@ -204,32 +235,23 @@ void TopKSamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& ? 
reinterpret_cast(outputs.finished->template getPtr()) : nullptr; - float* cumLogProbs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr() : nullptr; - float* outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr() : nullptr; - int* sequenceLength = (outputs.sequence_length) ? outputs.sequence_length->template getPtr() : nullptr; + auto cumLogProbs + = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr() : static_cast(nullptr); + auto outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr() + : static_cast(nullptr); + auto sequenceLength = (outputs.sequence_length) ? outputs.sequence_length->template getPtr() + : static_cast(nullptr); invokeBatchTopKSampling(samplingWorkspaceDevice, logits, static_cast(nullptr), - outputs.output_ids_ptr.template getPtr(), sequenceLength, finishedInput, finishedOutput, cumLogProbs, - outputLogProbs, curandStatesDevice, static_cast(mRuntimeMaxTopK), - reinterpret_cast(mRuntimeTopKDevice), 1.0f, mRuntimeTopPDevice, mVocabSizePadded, endIds, batchSlots, - mStream, batchSize, mMaxBatchSize, nullptr, 1, mSkipDecodeDevice, mNormalizeLogProbs, probsComputed, + outputs.output_ids_ptr.template getPtr(), /* outputIds */ nullptr, sequenceLength, finishedInput, + finishedOutput, cumLogProbs, outputLogProbs, curandStatesDevice, static_cast(mRuntimeMaxTopK), + static_cast(mRuntimeTopKDevice), 1.0f, mRuntimeTopPDevice, mVocabSizePadded, endIds, batchSlots, + mStream, batchSize, mMaxBatchSize, /* tokens per step */ nullptr, /* max tokens per step */ 1, + /* maxSeqLen ignored as outputIds is nullptr */ 0, mSkipDecodeDevice, mNormalizeLogProbs, probsComputed, /* return all Top-K*/ false); sync_check_cuda_error(); -} - -template -TopKSamplingLayer::TopKSamplingLayer(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, - cudaStream_t stream, std::shared_ptr allocator) - : BaseSamplingLayer(maxBatchSize, vocabSize, vocabSizePadded, stream, std::move(allocator), nullptr) -{ - allocateBuffer(mMaxBatchSize); -} -template -TopKSamplingLayer::~TopKSamplingLayer() -{ - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); - freeBuffer(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template class TopKSamplingLayer; diff --git a/cpp/tensorrt_llm/layers/topKSamplingLayer.h b/cpp/tensorrt_llm/layers/topKSamplingLayer.h index fdcdf02cc..3be0c7f2a 100644 --- a/cpp/tensorrt_llm/layers/topKSamplingLayer.h +++ b/cpp/tensorrt_llm/layers/topKSamplingLayer.h @@ -38,11 +38,12 @@ class TopKSamplingLayer : public BaseSamplingLayer using SetupParams = typename Base::SetupParams; using ForwardParams = typename Base::ForwardParams; - TopKSamplingLayer(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, cudaStream_t stream, - std::shared_ptr allocator); + TopKSamplingLayer(runtime::SizeType maxBatchSize, runtime::SizeType vocabSize, runtime::SizeType vocabSizePadded, + cudaStream_t stream, std::shared_ptr allocator); ~TopKSamplingLayer(); - void setup(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) override; + void setup( + runtime::SizeType batchSize, runtime::SizeType const* batchSlots, SetupParams const& setupParams) override; void forward(DecodingOutputParams& outputs, ForwardParams& inputs) override; bool const* getSkipDecodeHost() const @@ -52,8 +53,8 @@ class TopKSamplingLayer : public BaseSamplingLayer protected: bool mNormalizeLogProbs = true; - uint32_t mRuntimeMaxTopK = 0; - uint32_t* mRuntimeTopKDevice = nullptr; + runtime::SizeType mRuntimeMaxTopK = 0; + 
runtime::SizeType* mRuntimeTopKDevice = nullptr; float* mRuntimeTopPDevice = nullptr; void* mSetupWorkspaceDevice = nullptr; bool* mSkipDecodeDevice = nullptr; @@ -69,10 +70,8 @@ class TopKSamplingLayer : public BaseSamplingLayer using Base::mStream; using Base::mAllocator; - static constexpr uint32_t TOP_K_MAX = 1024; - private: - void allocateBuffer(size_t batchSize); + void allocateBuffer(runtime::SizeType batchSize); void freeBuffer(); }; diff --git a/cpp/tensorrt_llm/layers/topPSamplingLayer.cu b/cpp/tensorrt_llm/layers/topPSamplingLayer.cu index 958b4bd08..cafea18aa 100644 --- a/cpp/tensorrt_llm/layers/topPSamplingLayer.cu +++ b/cpp/tensorrt_llm/layers/topPSamplingLayer.cu @@ -29,26 +29,27 @@ using namespace tensorrt_llm::common; using namespace tensorrt_llm::kernels; +using namespace tensorrt_llm::runtime; namespace tensorrt_llm { namespace layers { -static __global__ void setTopPRuntimeArgs(int batchSize, uint32_t topK, uint32_t* topKs, int topKsSize, float topP, - float* topPs, int topPsSize, bool* skipDecode, int const* batchSlots, float* initialTopPBuf) +static __global__ void setTopPRuntimeArgs(SizeType batchSize, SizeType topK, SizeType* topKs, SizeType topKsSize, + float topP, float* topPs, SizeType topPsSize, bool* skipDecode, SizeType const* batchSlots, float* initialTopPBuf) { /** * @brief Setup the runtime arguments for topp, broadcasting top_p to top_ps and top_k to top_ks. */ - int index = blockIdx.x * blockDim.x + threadIdx.x; - for (int bi = index; bi < batchSize; bi += gridDim.x * blockDim.x) + auto index = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + for (SizeType bi = index; bi < batchSize; bi += static_cast(gridDim.x * blockDim.x)) { auto const batchSlot = batchSlots != nullptr ? batchSlots[bi] : bi; - std::uint32_t k = topKsSize > 1 ? topKs[batchSlot] : topK; - float p = topPsSize > 1 ? topPs[batchSlot] : topP; + auto k = topKsSize > 1 ? topKs[batchSlot] : topK; + auto const p = topPsSize > 1 ? 
topPs[batchSlot] : topP; if (k == 0 && p == 0.0f) { // TensorRT-LLM's topp implementation does not support topp = 0.0f, but it @@ -65,28 +66,54 @@ static __global__ void setTopPRuntimeArgs(int batchSize, uint32_t topK, uint32_t } template -void TopPSamplingLayer::allocateBuffer(size_t batchSize) +TopPSamplingLayer::TopPSamplingLayer(SizeType maxBatchSize, SizeType vocabSize, SizeType vocabSizePadded, + cudaStream_t stream, std::shared_ptr allocator, cudaDeviceProp* prop, bool isDeterministic, + bool isAirTopP) + : BaseSamplingLayer(maxBatchSize, vocabSize, vocabSizePadded, stream, std::move(allocator), prop) + , mIsDeterministic(isDeterministic) + , mIsAirTopP(isAirTopP) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + allocateBuffer(mMaxBatchSize); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +TopPSamplingLayer::~TopPSamplingLayer() { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); - if (mIsDeterministic) + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + freeBuffer(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void TopPSamplingLayer::allocateBuffer(SizeType batchSize) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + if (mIsAirTopP == false) { mSamplingWorkspaceSize = getTopPWorkspaceSize(batchSize, mVocabSizePadded); } else { - mSamplingWorkspaceSize = getAirTopPWorkspaceSize(batchSize, mVocabSizePadded); + mSamplingWorkspaceSize = getAirTopPWorkspaceSize(batchSize, mVocabSizePadded, mIsDeterministic); } std::array deviceBufferSizes; - deviceBufferSizes[0] = sizeof(int32_t) * batchSize * mVocabSizePadded; - deviceBufferSizes[1] = sizeof(int32_t) * (batchSize + 1); - deviceBufferSizes[2] = sizeof(int32_t) * (batchSize + 1); - deviceBufferSizes[3] = sizeof(uint32_t) * batchSize; + deviceBufferSizes[0] = sizeof(TokenIdType) * batchSize * mVocabSizePadded; + deviceBufferSizes[1] = sizeof(SizeType) * (batchSize + 1); + deviceBufferSizes[2] = sizeof(SizeType) * (batchSize + 1); + deviceBufferSizes[3] = sizeof(SizeType) * batchSize; deviceBufferSizes[4] = sizeof(float) * batchSize; deviceBufferSizes[5] = sizeof(float) * batchSize; deviceBufferSizes[6] = sizeof(float) * batchSize; deviceBufferSizes[7] = sizeof(float) * batchSize; - deviceBufferSizes[8] = sizeof(int32_t) * batchSize; + deviceBufferSizes[8] = sizeof(TokenIdType) * batchSize; deviceBufferSizes[9] = sizeof(bool) * batchSize; deviceBufferSizes[10] = *std::max_element(&deviceBufferSizes[3], &deviceBufferSizes[9]); @@ -102,17 +129,20 @@ void TopPSamplingLayer::allocateBuffer(size_t batchSize) mSkipDecodeDevice = mAllocator->reMalloc(mSkipDecodeDevice, deviceBufferSizes[9], false); mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[10], false); - mSkipDecodeHost = (bool*) std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize); + mSkipDecodeHost = static_cast(std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize)); std::fill(mSkipDecodeHost, mSkipDecodeHost + batchSize, true); mAllocatedSize = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), 0); TLLM_LOG_DEBUG("topPSamplingLayer allocated %lu bytes on GPU", mAllocatedSize); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void TopPSamplingLayer::freeBuffer() { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + mAllocator->free((void**) (&mTopPIdValsDevice)); mAllocator->free((void**) (&mTopPOffsetDevice)); mAllocator->free((void**) (&mBeginTopPOffsetDevice)); @@ -125,34 +155,36 @@ void TopPSamplingLayer::freeBuffer() 
mAllocator->free((void**) (&mSkipDecodeDevice)); mAllocator->free((void**) (&mSetupWorkspaceDevice)); std::free(mSkipDecodeHost); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template -void TopPSamplingLayer::setup(size_t const batchSize, int32_t const* batchSlots, SetupParams const& setupParams) +void TopPSamplingLayer::setup(SizeType const batchSize, SizeType const* batchSlots, SetupParams const& setupParams) { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - uint32_t const defaultTopK = 0; - auto runtimeTopK = setupParams.runtime_top_k.value_or(std::vector{defaultTopK}); + SizeType const defaultTopK = 0; + auto runtimeTopK = setupParams.runtime_top_k.value_or(std::vector{defaultTopK}); auto runtimeTopP = setupParams.runtime_top_p.value_or(std::vector{}); - size_t const runtimeTopKSize = runtimeTopK.size(); - size_t const runtimeTopPSize = runtimeTopP.size(); + auto const runtimeTopKSize = runtimeTopK.size(); + auto const runtimeTopPSize = runtimeTopP.size(); - float const defaultTopPDecay{1.0f}; + auto const defaultTopPDecay{1.0f}; auto decayVec = setupParams.top_p_decay.value_or(std::vector(batchSize, defaultTopPDecay)); - float const defaultTopPMin{1e-6f}; // prevent topp becoming 0.0 + auto const defaultTopPMin{1e-6f}; // prevent topp becoming 0.0 auto topPMinVec = setupParams.top_p_min.value_or(std::vector(batchSize, defaultTopPMin)); - int32_t const defaultTopPResetId{-1}; - auto topPResetIdsVec = setupParams.top_p_reset_ids.value_or(std::vector(batchSize, defaultTopPResetId)); + SizeType const defaultTopPResetId{-1}; + auto topPResetIdsVec = setupParams.top_p_reset_ids.value_or(std::vector(batchSize, defaultTopPResetId)); if (runtimeTopPSize == 0) { - for (size_t bi = 0; bi < batchSize; ++bi) + for (SizeType bi = 0; bi < static_cast(batchSize); ++bi) { - int32_t bid = bi; + auto bid = bi; if (batchSlots) { bid = batchSlots[bi]; @@ -190,21 +222,21 @@ void TopPSamplingLayer::setup(size_t const batchSize, int32_t const* batchSlo } } - uint32_t const topK = runtimeTopK.at(0); - float const topP = runtimeTopP.at(0); + auto const topK = runtimeTopK.at(0); + auto const topP = runtimeTopP.at(0); if (runtimeTopKSize > 1) { - TLLM_CHECK_WITH_INFO(runtimeTopK.size() == batchSize, - fmtstr("runtimeTopK.size() (%lu) == batchSize (%lu) is not satisfied!", runtimeTopK.size(), batchSize)); - cudaAutoCpy(reinterpret_cast(mSetupWorkspaceDevice), runtimeTopK.data(), batchSize, mStream); + TLLM_CHECK_WITH_INFO(static_cast(runtimeTopK.size()) == batchSize, + fmtstr("runtimeTopK.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopK.size(), batchSize)); + cudaAutoCpy(reinterpret_cast(mSetupWorkspaceDevice), runtimeTopK.data(), batchSize, mStream); invokeScatterDecodingParams( - reinterpret_cast(mSetupWorkspaceDevice), mRuntimeTopKDevice, batchSlots, batchSize, mStream); + reinterpret_cast(mSetupWorkspaceDevice), mRuntimeTopKDevice, batchSlots, batchSize, mStream); } if (runtimeTopPSize > 1) { - TLLM_CHECK_WITH_INFO(runtimeTopP.size() == batchSize, - fmtstr("runtime_top_p.size() (%lu) == batchSize (%lu) is not satisfied!", runtimeTopP.size(), batchSize)); + TLLM_CHECK_WITH_INFO(static_cast(runtimeTopP.size()) == batchSize, + fmtstr("runtime_top_p.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopP.size(), batchSize)); cudaAutoCpy(reinterpret_cast(mSetupWorkspaceDevice), runtimeTopP.data(), batchSize, mStream); invokeScatterDecodingParams( reinterpret_cast(mSetupWorkspaceDevice), mRuntimeTopPDevice, batchSlots, batchSize, mStream); @@ 
-213,8 +245,8 @@ void TopPSamplingLayer::setup(size_t const batchSize, int32_t const* batchSlo auto fillBuffers = [this, &batchSize, &batchSlots](std::string name, auto const& vector, auto deviceTmpBuffer, auto deviceBuffer) { - TLLM_CHECK_WITH_INFO(vector.size() == batchSize, - fmtstr("%s.size() (%lu) == batchSize (%lu) is not satisfied!", name.c_str(), vector.size(), batchSize)); + TLLM_CHECK_WITH_INFO(static_cast(vector.size()) == batchSize, + fmtstr("%s.size() (%lu) == batchSize (%d) is not satisfied!", name.c_str(), vector.size(), batchSize)); cudaAutoCpy(deviceTmpBuffer, vector.data(), batchSize, mStream); invokeScatterDecodingParams(deviceTmpBuffer, deviceBuffer, batchSlots, batchSize, mStream); }; @@ -224,11 +256,11 @@ void TopPSamplingLayer::setup(size_t const batchSize, int32_t const* batchSlo fillBuffers("top_p_min", topPMinVec, reinterpret_cast(mSetupWorkspaceDevice), mTopPMinDevice); fillBuffers( - "top_p_reset_ids", topPResetIdsVec, reinterpret_cast(mSetupWorkspaceDevice), mTopPResetIdsDevice); + "top_p_reset_ids", topPResetIdsVec, reinterpret_cast(mSetupWorkspaceDevice), mTopPResetIdsDevice); { - dim3 block(std::min((int) batchSize, 256)); - dim3 grid(divUp((int) batchSize, (int) block.x)); + dim3 block(std::min(static_cast(batchSize), 256)); + dim3 grid(divUp(static_cast(batchSize), static_cast(block.x))); setTopPRuntimeArgs<<>>(batchSize, topK, mRuntimeTopKDevice, runtimeTopKSize, topP, mRuntimeTopPDevice, runtimeTopPSize, mSkipDecodeDevice, batchSlots, mInitialTopPDevice); sync_check_cuda_error(); @@ -238,44 +270,54 @@ void TopPSamplingLayer::setup(size_t const batchSize, int32_t const* batchSlo std::vector runtimeTopPs(mMaxBatchSize); cudaAutoCpy(runtimeTopPs.data(), mRuntimeTopPDevice, mMaxBatchSize, mStream); { - float maxTopP = 0.f; - for (size_t bi = 0; bi < batchSize; ++bi) + auto maxTopP = 0.f; + for (SizeType bi = 0; bi < static_cast(batchSize); ++bi) { - int32_t bid = bi; - if (batchSlots) - { - bid = batchSlots[bi]; - } + auto const bid = batchSlots ? batchSlots[bi] : bi; maxTopP = std::max(maxTopP, runtimeTopPs[bid]); } mRuntimeMaxTopP = std::max(mRuntimeMaxTopP, maxTopP); } - if (!mIsDeterministic) + if (mIsAirTopP == true) { - int smCnt = mCudaDeviceProp->multiProcessorCount; - mAirTopPBlockNum = calcAirTopPBlockNum(batchSize, (int) mVocabSizePadded, smCnt); + int smCnt = 0; + if (mCudaDeviceProp) + { + smCnt = mCudaDeviceProp->multiProcessorCount; + } + if (smCnt <= 0) + { + int deviceId; + check_cuda_error(cudaGetDevice(&deviceId)); // Get the correct device id + cudaDeviceProp prop; + check_cuda_error(cudaGetDeviceProperties(&prop, deviceId)); + smCnt = prop.multiProcessorCount; + } + mAirTopPBlockNum = calcAirTopPBlockNum(batchSize, (int) mVocabSizePadded, smCnt, mIsDeterministic); } + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void TopPSamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& inputs) { - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto const batchSize = inputs.logits.shape[0]; // Probabilities must be already computed instead of logits auto probs = inputs.logits.template getPtr(); - auto endIds = inputs.end_ids.template getPtr(); - auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr() : nullptr; + auto endIds = inputs.end_ids.template getPtr(); + auto batchSlots = inputs.batch_slots ? 
inputs.batch_slots->template getPtr() : nullptr; auto curandStatesDevice = inputs.curand_states; auto samplingWorkspaceDevice = inputs.sampling_workspace; TLLM_CHECK_WITH_INFO(curandStatesDevice, "No curand states provided"); TLLM_CHECK_WITH_INFO(samplingWorkspaceDevice, "No sampling workspace provided"); - if (mIsDeterministic) + if (mIsAirTopP == false) { invokeTopPInitialize( mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, batchSize, mVocabSizePadded, mStream); @@ -289,46 +331,34 @@ void TopPSamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& ? reinterpret_cast(outputs.finished->template getPtr()) : nullptr; - float* cumLogProbs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr() : nullptr; - float* outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr() : nullptr; - int* sequenceLength = (outputs.sequence_length) ? outputs.sequence_length->template getPtr() : nullptr; + auto cumLogProbs + = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr() : static_cast(nullptr); + auto outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr() + : static_cast(nullptr); + auto sequenceLength = (outputs.sequence_length) ? outputs.sequence_length->template getPtr() + : static_cast(nullptr); - if (mIsDeterministic) + if (mIsAirTopP == false) { invokeBatchTopPSampling(samplingWorkspaceDevice, outputs.output_ids_ptr.template getPtr(), sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, probs, mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, curandStatesDevice, batchSize, mMaxBatchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, mSkipDecodeDevice, batchSlots); - sync_check_cuda_error(); - invokeComputeToppDecay(mRuntimeTopPDevice, mInitialTopPDevice, - outputs.output_ids_ptr.template getPtr(), mTopPDecayDevice, mTopPMinDevice, mTopPResetIdsDevice, - sequenceLength, batchSlots, batchSize, mStream); - sync_check_cuda_error(); } else { invokeBatchAirTopPSampling(samplingWorkspaceDevice, outputs.output_ids_ptr.template getPtr(), sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, probs, curandStatesDevice, batchSize, mMaxBatchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, - mAirTopPBlockNum, mSkipDecodeDevice, batchSlots); - sync_check_cuda_error(); + mAirTopPBlockNum, mSkipDecodeDevice, batchSlots, mIsDeterministic); } -} -template -TopPSamplingLayer::TopPSamplingLayer(std::size_t maxBatchSize, std::size_t vocabSize, std::size_t vocabSizePadded, - cudaStream_t stream, std::shared_ptr allocator, cudaDeviceProp* prop, bool isDeterministic) - : BaseSamplingLayer(maxBatchSize, vocabSize, vocabSizePadded, stream, std::move(allocator), prop) - , mIsDeterministic(isDeterministic) -{ - allocateBuffer(mMaxBatchSize); -} - -template -TopPSamplingLayer::~TopPSamplingLayer() -{ - TLLM_LOG_TRACE(__PRETTY_FUNCTION__); - freeBuffer(); + sync_check_cuda_error(); + invokeComputeToppDecay(mRuntimeTopPDevice, mInitialTopPDevice, + outputs.output_ids_ptr.template getPtr(), mTopPDecayDevice, mTopPMinDevice, + mTopPResetIdsDevice, sequenceLength, batchSlots, batchSize, mStream); + sync_check_cuda_error(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template class TopPSamplingLayer; diff --git a/cpp/tensorrt_llm/layers/topPSamplingLayer.h b/cpp/tensorrt_llm/layers/topPSamplingLayer.h index dd485d218..97d9b2e41 100644 --- a/cpp/tensorrt_llm/layers/topPSamplingLayer.h +++ 
b/cpp/tensorrt_llm/layers/topPSamplingLayer.h
@@ -38,11 +38,13 @@ class TopPSamplingLayer : public BaseSamplingLayer<T>
     using SetupParams = typename Base::SetupParams;
     using ForwardParams = typename Base::ForwardParams;
 
-    TopPSamplingLayer(std::size_t maxBatchSize, std::size_t vocabSize, std::size_t vocabSizePadded, cudaStream_t stream,
-        std::shared_ptr allocator, cudaDeviceProp* prop, bool isDeterministic = true);
+    TopPSamplingLayer(runtime::SizeType maxBatchSize, runtime::SizeType vocabSize, runtime::SizeType vocabSizePadded,
+        cudaStream_t stream, std::shared_ptr allocator, cudaDeviceProp* prop,
+        bool isDeterministic = true, bool isAirTopP = true);
     ~TopPSamplingLayer();
 
-    void setup(std::size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) override;
+    void setup(
+        runtime::SizeType batchSize, runtime::SizeType const* batchSlots, SetupParams const& setupParams) override;
     void forward(DecodingOutputParams& outputs, ForwardParams& inputs) override;
 
     bool const* getSkipDecodeHost() const
@@ -51,22 +53,23 @@ class TopPSamplingLayer : public BaseSamplingLayer<T>
     }
 
 protected:
-    uint32_t* mRuntimeTopKDevice = nullptr;
+    runtime::SizeType* mRuntimeTopKDevice = nullptr;
     float* mRuntimeTopPDevice = nullptr;
     float mRuntimeMaxTopP{0.f};
     float* mInitialTopPDevice = nullptr;
     float* mTopPDecayDevice = nullptr;
     float* mTopPMinDevice = nullptr;
-    int32_t* mTopPResetIdsDevice = nullptr;
+    runtime::TokenIdType* mTopPResetIdsDevice = nullptr;
     void* mSetupWorkspaceDevice = nullptr;
-    int32_t* mTopPIdValsDevice = nullptr;
-    int32_t* mTopPOffsetDevice = nullptr;
-    int32_t* mBeginTopPOffsetDevice = nullptr;
+    runtime::TokenIdType* mTopPIdValsDevice = nullptr;
+    runtime::SizeType* mTopPOffsetDevice = nullptr;
+    runtime::SizeType* mBeginTopPOffsetDevice = nullptr;
     bool* mSkipDecodeDevice = nullptr;
     bool* mSkipDecodeHost = nullptr;
     bool mIsDeterministic = true;
-    int mAirTopPBlockNum;
+    runtime::SizeType mAirTopPBlockNum;
+    bool mIsAirTopP = false;
 
     using Base::mMaxBatchSize;
     using Base::mVocabSize;
@@ -80,7 +83,7 @@ class TopPSamplingLayer : public BaseSamplingLayer<T>
     using Base::mCudaDeviceProp;
 
 private:
-    void allocateBuffer(std::size_t batchSize);
+    void allocateBuffer(runtime::SizeType batchSize);
     void freeBuffer();
 };
 
diff --git a/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp b/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp
index 69ac51a15..d97a7dcbe 100644
--- a/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
*/ -#include "tllmPlugin.h" +#include "tensorrt_llm/plugins/api/tllmPlugin.h" #include "tensorrt_llm/common/stringUtils.h" #include "tensorrt_llm/runtime/tllmLogger.h" diff --git a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h index b0fbdde8a..1ed6555d3 100644 --- a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h +++ b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h @@ -122,11 +122,14 @@ class GemmIdCublas : public GemmIdCore public: bool transA{}; bool transB{}; + nvinfer1::DataType outputDtype; - GemmIdCublas(int n_, int k_, nvinfer1::DataType const& dtype_, bool transA_, bool transB_) + GemmIdCublas(int n_, int k_, nvinfer1::DataType const& dtype_, bool transA_, bool transB_, + nvinfer1::DataType const& output_dtype_) : GemmIdCore(n_, k_, dtype_) , transA(transA_) , transB(transB_) + , outputDtype(output_dtype_) { } @@ -134,7 +137,7 @@ class GemmIdCublas : public GemmIdCore bool operator==(GemmIdCublas const& id) const { - return isEqual(id) && transA == id.transA && transB == id.transB; + return isEqual(id) && transA == id.transA && transB == id.transB && outputDtype == id.outputDtype; } friend std::ostream& operator<<(std::ostream& out, GemmIdCublas const& id) @@ -143,6 +146,7 @@ class GemmIdCublas : public GemmIdCore out << " type=" << static_cast(id.dtype); out << " transA=" << id.transA; out << " transB=" << id.transB; + out << " outputDtype=" << static_cast(id.outputDtype); return out; } }; @@ -157,7 +161,8 @@ struct GemmIdCublasHash auto h3 = std::hash{}(static_cast(id.dtype)); auto h4 = std::hash{}(id.transA); auto h5 = std::hash{}(id.transB); - return h1 ^ h2 ^ h3 ^ h4 ^ h5; + auto h6 = std::hash{}(static_cast(id.outputDtype)); + return h1 ^ h2 ^ h3 ^ h4 ^ h5 ^ h6; } }; diff --git a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp index 4a94bd0a7..db6bb3a7a 100644 --- a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp @@ -170,7 +170,7 @@ void GemmPlugin::init() mPluginProfiler->setTranspose(mTransA, mTransB); mPluginProfiler->setOutputType(mOutputType); - mGemmId = GemmIdCublas(mDims.n, mDims.k, mType, mTransA, mTransB); + mGemmId = GemmIdCublas(mDims.n, mDims.k, mType, mTransA, mTransB, mOutputType); } void GemmPlugin::setGemmConfig() diff --git a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp index 9d81c0ec9..6d72cc336 100644 --- a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp @@ -131,7 +131,7 @@ void LoraPlugin::init() mPluginProfiler->setTranspose(mTransA, mTransB); - mGemmId = GemmIdCublas(mDims.n, mDims.k, mType, mTransA, mTransB); + mGemmId = GemmIdCublas(mDims.n, mDims.k, mType, mTransA, mTransB, mType); } void LoraPlugin::setGemmConfig() diff --git a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp index 9db457af4..640980a75 100644 --- a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp @@ -84,7 +84,7 @@ nvinfer1::DimsExprs SelectiveScanPlugin::getOutputDimensions( bool SelectiveScanPlugin::supportsFormatCombination( int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { - if (pos == getHostRequestTypesIdx()) + if (pos == getHostRequestTypesIdx() || pos 
== getLastTokenIdsIdx()) { return inOut[pos].type == nvinfer1::DataType::kINT32; } @@ -112,7 +112,7 @@ size_t SelectiveScanPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* i void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch, const size_t dim, const size_t seqLen, const size_t dstate, bool const isVariableB, bool const isVariableC, void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A, void const* B, void const* C, void const* D, void const* z, - void* out, bool deltaSoftplus) + int const* lastTokenIds, void* out, bool deltaSoftplus) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -138,6 +138,7 @@ void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch params.out_ptr = out; params.x_ptr = statePtr; params.z_ptr = const_cast(z); + params.last_token_ids_ptr = lastTokenIds; } template @@ -156,6 +157,7 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc // 7. D [dim] // 8. z [batch_size, seq_len, dim] // 9. host_request_types [batch_size] int32. 0: context; 1: generation. + // 10. last_token_ids [batch_size] int32 // outputs // 0. output_tensor [batch_size, seq_len, dim] // 1. state [batch_size, dstate, dim] @@ -170,7 +172,8 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc setSSMParams(ssm_params, batch_size, mDim, seq_len, mDState, mIsVariableB, mIsVariableC, outputs[1], inputs[getInputTensorIdx()], inputs[getDeltaIdx()], inputs[getDeltaBiasIdx()], inputs[getAIdx()], - inputs[getBIdx()], inputs[getCIdx()], inputs[getDIdx()], inputs[getZIdx()], outputs[0], mDeltaSoftplus); + inputs[getBIdx()], inputs[getCIdx()], inputs[getDIdx()], inputs[getZIdx()], + static_cast(inputs[getLastTokenIdsIdx()]), outputs[0], mDeltaSoftplus); if (reqTypes[0] == RequestType::kCONTEXT) { diff --git a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h index ad9062fe3..ac393db80 100644 --- a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h +++ b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h @@ -39,6 +39,7 @@ namespace tensorrt_llm::plugins // 7. D [dim] // 8. z [batch_size, seq_len, dim] // 9. host_request_types [batch_size] int32. 0: context; 1: generation; 2: none. +// 10. last_token_ids [batch_size] int32 // outputs // 0. output_tensor [batch_size, seq_len, dim] // 1. 
state [batch_size, dstate, dim] @@ -142,13 +143,18 @@ class SelectiveScanPlugin : public BasePlugin return 9; }; + IndexType getLastTokenIdsIdx() const + { + return 10; + }; + void setSSMParams(tensorrt_llm::kernels::SSMParamsBase& params, // sizes const size_t batch, const size_t dim, const size_t seqLen, const size_t dstate, bool const isVariableB, bool const isVariableC, // device pointers void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A, void const* B, - void const* C, void const* D, void const* z, void* out, bool deltaSoftplus); + void const* C, void const* D, void const* z, int const* lastTokenIds, void* out, bool deltaSoftplus); private: int mDim; diff --git a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp index eb33af803..bcd3584e4 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp @@ -15,8 +15,6 @@ * limitations under the License. */ #include "weightOnlyGroupwiseQuantMatmulPlugin.h" -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h" -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.h" using namespace nvinfer1; using namespace tensorrt_llm::common; @@ -131,6 +129,7 @@ WeightOnlyGroupwiseQuantMatmulPlugin::WeightOnlyGroupwiseQuantMatmulPlugin( void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int quant_algo, int group_size) { + mArch = tensorrt_llm::common::getSMVersion(); mType = type; mQuantAlgo = quant_algo; mGroupSize = group_size; @@ -148,7 +147,7 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua if (quant_algo & FP8_ALPHA) { // Hopper style kernels - if (getSMVersion() < 90) + if (mArch < 90) { TLLM_THROW("W4A(fp)8 kernel is unsupported on pre-Hopper architectures!"); } @@ -184,6 +183,9 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>(); } } + mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported( + mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise); + mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise; } #if defined(ENABLE_BF16) else if (mType == nvinfer1::DataType::kBF16) @@ -191,7 +193,7 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua if (quant_algo & FP8_ALPHA) { // Hopper style kernels - if (getSMVersion() < 90) + if (mArch < 90) { TLLM_THROW("FP8 is unsupported on pre-Hopper architectures!"); } @@ -214,14 +216,15 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>(); } } + mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported( + mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise); + mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise; } #endif else { TLLM_THROW("Unsupported data type"); } - mCudaKernelEnabled - = tensorrt_llm::kernels::isWeightOnlyBatchedGemvEnabled(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b); mPluginProfiler->setQuantAlgo(mQuantAlgo); mPluginProfiler->setGroupSize(mGroupSize); @@ -361,13 +364,7 @@ int 
WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(nvinfer1::PluginTensorDesc con int const n = inputDesc[mWeightInputIdx].dims.d[1]; int const k = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]; - int smVersion = getSMVersion(); bool use_cuda_kernel = m < SMALL_M_FAST_PATH && mCudaKernelEnabled; -#if defined(ENABLE_BF16) - // CUDA kernels assume FP16 activations for Hopper - bool force_disable_cuda_kernel = smVersion == 90 && mType == nvinfer1::DataType::kBF16; - use_cuda_kernel = use_cuda_kernel && !force_disable_cuda_kernel; -#endif bool use_pre_quant_scale = mQuantAlgo & PRE_QUANT_SCALE; half const* zeros_ptr = (mQuantAlgo & ZERO) ? reinterpret_cast(inputs[mZerosInputIdx]) : nullptr; @@ -424,95 +421,44 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(nvinfer1::PluginTensorDesc con TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF, "No valid weightOnlyGropwiseQuantMatmul configuration"); #endif - tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type; // Quantized weights are packed in FP16 format (INT4*4 -> FP16) int real_n = n * FP16_INT4_RATIO; - if (mType == nvinfer1::DataType::kHALF) - { - weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16; - } - else if (mType == nvinfer1::DataType::kBF16) + if (use_cuda_kernel) { - weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::BF16; - } - - if (smVersion == 90) - { - // Hopper style kernels - if (use_cuda_kernel) - { - // Use CUDA kernels for small batch size - // The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel - // when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights. - void const* pre_quant_scale_ptr = nullptr; - if (use_pre_quant_scale) - pre_quant_scale_ptr = inputs[mPreQuantScaleInputIdx]; - void* cuda_kernel_act_ptr = const_cast(reinterpret_cast(inputs[0])); - void* cuda_kernel_act_scale_ptr = const_cast(reinterpret_cast(pre_quant_scale_ptr)); - void* cuda_kernel_weight_ptr = const_cast(reinterpret_cast(inputs[mWeightInputIdx])); - void* cuda_kernel_scales_ptr = const_cast(reinterpret_cast(inputs[mScalesInputIdx])); - void* cuda_kernel_zeros_ptr = const_cast(reinterpret_cast(zeros_ptr)); - void* cuda_kernel_bias_ptr = const_cast(reinterpret_cast(biases_ptr)); - void* cuda_kernel_out_ptr = const_cast(reinterpret_cast(outputs[0])); - - tensorrt_llm::kernels::weight_only::Params params{cuda_kernel_act_ptr, cuda_kernel_act_scale_ptr, - cuda_kernel_weight_ptr, cuda_kernel_scales_ptr, cuda_kernel_zeros_ptr, cuda_kernel_bias_ptr, - cuda_kernel_out_ptr, alpha, m, real_n, k, mGroupSize, - tensorrt_llm::kernels::weight_only::KernelType::W4A16}; - tensorrt_llm::kernels::weight_only::kernel_launcher(params, stream); - } - else - { - // Use cutlass kernels for large batch size - int const ws_bytes = m_weightOnlyGroupwiseGemmRunner->getWorkspaceSize(m, n, k); - - int32_t* weight_ptr = const_cast(reinterpret_cast(inputs[mWeightInputIdx])); - - auto const& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); - TLLM_CHECK_WITH_INFO(bestTactic, "No valid weight only groupwise GEMM tactic"); - m_weightOnlyGroupwiseGemmRunner->gemm(act_ptr, weight_ptr, inputs[mScalesInputIdx], zeros_ptr, biases_ptr, - alpha, outputs[0], m, real_n, k, mGroupSize, *bestTactic, - reinterpret_cast(workspace) + m * k * sizeof(half), ws_bytes, stream); - } + void const* pre_quant_scale_ptr = nullptr; + if (use_pre_quant_scale) + pre_quant_scale_ptr = inputs[mPreQuantScaleInputIdx]; + void const* cuda_kernel_act_ptr = 
inputs[0]; + void const* cuda_kernel_act_scale_ptr = pre_quant_scale_ptr; + void const* cuda_kernel_weight_ptr = inputs[mWeightInputIdx]; + void const* cuda_kernel_scales_ptr = inputs[mScalesInputIdx]; + void const* cuda_kernel_zeros_ptr = zeros_ptr; + void const* cuda_kernel_bias_ptr = biases_ptr; + void* cuda_kernel_out_ptr = outputs[0]; + + tensorrt_llm::kernels::weight_only::Params params{cuda_kernel_act_ptr, cuda_kernel_act_scale_ptr, + cuda_kernel_weight_ptr, cuda_kernel_scales_ptr, cuda_kernel_zeros_ptr, cuda_kernel_bias_ptr, + cuda_kernel_out_ptr, alpha, m, real_n, k, mGroupSize, mCudaKernelType, + static_cast(mQuantAlgo & FP8_ALPHA)}; + tensorrt_llm::kernels::weight_only::kernel_launcher(mArch, params, stream); } else { - // Pre-Hopper architectures - if (use_cuda_kernel) - { - // Use CUDA kernels for small batch size - // The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel - // when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights. - void const* pre_quant_scale = nullptr; - if (use_pre_quant_scale) - pre_quant_scale = inputs[mPreQuantScaleInputIdx]; - tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast(inputs[mWeightInputIdx]), - inputs[mScalesInputIdx], zeros_ptr, act_ptr, pre_quant_scale, biases_ptr, outputs[0], m, real_n, k, - mGroupSize, tensorrt_llm::kernels::WeightOnlyQuantType::Int4b, - tensorrt_llm::kernels::WeightOnlyType::GroupWise, - tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type}; - tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream); - } - else - { - // Use cutlass kernels for large batch size - int const ws_bytes = m_weightOnlyGroupwiseGemmRunner->getWorkspaceSize(m, n, k); - - int32_t* weight_ptr = const_cast(reinterpret_cast(inputs[mWeightInputIdx])); - - auto const& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); - TLLM_CHECK_WITH_INFO(bestTactic, - "No valid weight only groupwise GEMM tactic(It is usually caused by the failure to execute all " - "candidate " - "configurations of the CUTLASS kernel, please pay attention to the warning information when building " - "the " - "engine.)"); - m_weightOnlyGroupwiseGemmRunner->gemm(act_ptr, weight_ptr, inputs[mScalesInputIdx], zeros_ptr, biases_ptr, - outputs[0], m, real_n, k, mGroupSize, *bestTactic, - reinterpret_cast(workspace) + m * k * sizeof(half), ws_bytes, stream); - } + int const ws_bytes = m_weightOnlyGroupwiseGemmRunner->getWorkspaceSize(m, n, k); + + int32_t* weight_ptr = const_cast(reinterpret_cast(inputs[mWeightInputIdx])); + + auto const& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); + TLLM_CHECK_WITH_INFO(bestTactic, + "No valid weight only groupwise GEMM tactic(It is usually caused by the failure to execute all " + "candidate " + "configurations of the CUTLASS kernel, please pay attention to the warning information when building " + "the " + "engine.)"); + m_weightOnlyGroupwiseGemmRunner->gemm(act_ptr, weight_ptr, inputs[mScalesInputIdx], zeros_ptr, biases_ptr, + alpha, outputs[0], m, real_n, k, mGroupSize, *bestTactic, + reinterpret_cast(workspace) + m * k * sizeof(half), ws_bytes, stream); } - return 0; } diff --git a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h index 5c664f5fb..5d4a8c90e 100644 --- 
a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h +++ b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h @@ -126,6 +126,8 @@ class WeightOnlyGroupwiseQuantMatmulPlugin : public BasePlugin size_t m_workspaceMaxSize; nvinfer1::DataType mType; bool mCudaKernelEnabled; + tensorrt_llm::kernels::weight_only::KernelType mCudaKernelType; + int mArch; // When M is smaller than this value, we trigger a fast path // I.e. a tailored kernel instead of cutlass. diff --git a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp index e4f57f31e..3aeb11198 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp @@ -15,7 +15,6 @@ * limitations under the License. */ #include "weightOnlyQuantMatmulPlugin.h" -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h" using namespace nvinfer1; using namespace tensorrt_llm::common; @@ -110,29 +109,34 @@ WeightOnlyQuantMatmulPlugin::WeightOnlyQuantMatmulPlugin( void WeightOnlyQuantMatmulPlugin::init(nvinfer1::DataType type, WeightTypeId weightTypeId) { + mArch = tensorrt_llm::common::getSMVersion(); mType = type; mWeightTypeId = weightTypeId; + if (mWeightTypeId == WeightTypeId::INT8) { if (mType == nvinfer1::DataType::kHALF) { m_weightOnlyGemmRunner = std::make_shared< CutlassFpAIntBGemmRunner>(); + mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported( + mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int8PerChannel); + mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int8PerChannel; } #if defined(ENABLE_BF16) else if (mType == nvinfer1::DataType::kBF16) { m_weightOnlyGemmRunner = std::make_shared< CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>>(); + mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported( + mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int8PerChannel); + mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int8PerChannel; } #endif else { TLLM_CHECK(false); } - - mCudaKernelEnabled - = tensorrt_llm::kernels::isWeightOnlyBatchedGemvEnabled(tensorrt_llm::kernels::WeightOnlyQuantType::Int8b); } else if (mWeightTypeId == WeightTypeId::INT4) { @@ -140,20 +144,24 @@ void WeightOnlyQuantMatmulPlugin::init(nvinfer1::DataType type, WeightTypeId wei { m_weightOnlyGemmRunner = std::make_shared< CutlassFpAIntBGemmRunner>(); + mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported( + mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int4PerChannel); + mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int4PerChannel; } #if defined(ENABLE_BF16) else if (mType == nvinfer1::DataType::kBF16) { m_weightOnlyGemmRunner = std::make_shared>(); + mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported( + mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int4PerChannel); + mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int4PerChannel; } #endif else { TLLM_CHECK(false); } - mCudaKernelEnabled - = tensorrt_llm::kernels::isWeightOnlyBatchedGemvEnabled(tensorrt_llm::kernels::WeightOnlyQuantType::Int4b); } else { @@ -296,38 +304,16 @@ int 
WeightOnlyQuantMatmulPlugin::enqueue(nvinfer1::PluginTensorDesc const* input #else TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF, "No valid weightOnlyQuantMatmul configuration"); #endif - - tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type; - tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type; - int real_n; - if (mType == nvinfer1::DataType::kHALF) - { - weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16; - } - else if (mType == nvinfer1::DataType::kBF16) - { - weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::BF16; - } - if (mWeightTypeId == WeightTypeId::INT8) - { - weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int8b; - real_n = n; - } - else if (mWeightTypeId == WeightTypeId::INT4) - { - weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int4b; - real_n = n * INT8_INT4_RATIO; - } - if (use_cuda_kernel && getSMVersion() < 90) + int real_n = mWeightTypeId == WeightTypeId::INT4 ? n * INT8_INT4_RATIO : n; + if (use_cuda_kernel) { - // Use CUDA kernels for small batch size - // The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass - // kernel when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights. - tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast(inputs[1]), inputs[2], nullptr, - inputs[0], nullptr, nullptr, outputs[0], m, real_n, k, 0, weight_only_quant_type, - tensorrt_llm::kernels::WeightOnlyType::PerChannel, - tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type}; - tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream); + void const* cuda_kernel_act_ptr = inputs[0]; + void const* cuda_kernel_weight_ptr = inputs[1]; + void const* cuda_kernel_scales_ptr = inputs[2]; + void* cuda_kernel_out_ptr = outputs[0]; + tensorrt_llm::kernels::weight_only::Params params(cuda_kernel_act_ptr, nullptr, cuda_kernel_weight_ptr, + cuda_kernel_scales_ptr, nullptr, nullptr, cuda_kernel_out_ptr, 1.f, m, real_n, k, 0, mCudaKernelType); + tensorrt_llm::kernels::weight_only::kernel_launcher(mArch, params, stream); } else { diff --git a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h index aac4cd6e0..2683846b3 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h +++ b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h @@ -129,6 +129,8 @@ class WeightOnlyQuantMatmulPlugin : public BasePlugin nvinfer1::DataType mType; WeightTypeId mWeightTypeId; bool mCudaKernelEnabled; + tensorrt_llm::kernels::weight_only::KernelType mCudaKernelType; + int mArch; // When M is smaller than this value, we trigger a fast path // I.e. a tailored kernel instead of cutlass. 
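
Editor's note, for readers of the two weight-only plugin diffs above: both plugins now follow the same dispatch pattern, i.e. query `weight_only::is_supported(arch, kernelType)` once at init time, then at enqueue route small-M GEMMs to the fused CUDA kernel via `weight_only::Params` / `weight_only::kernel_launcher` and everything else to the cutlass runner. The sketch below restates that pattern outside the plugin classes, under stated assumptions: the free function `runWeightOnlyGemm`, the header path, and the value of `SMALL_M_FAST_PATH` are illustrative guesses; only the `Params` argument order, `kernel_launcher`, `is_supported`, and `getSMVersion` are taken from the diff itself.

```
// Minimal sketch, not part of this patch. Assumes the TensorRT-LLM headers that
// declare weight_only::Params, weight_only::kernel_launcher, weight_only::is_supported
// and common::getSMVersion are available at the (assumed) paths below.
#include <cuda_runtime_api.h>
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h" // assumed header path

namespace wo = tensorrt_llm::kernels::weight_only;

// Hypothetical helper, mirroring WeightOnlyQuantMatmulPlugin::enqueue() above.
void runWeightOnlyGemm(void const* act, void const* weight, void const* scales, void* out,
    int m, int n, int k, cudaStream_t stream)
{
    // Decided once, as the plugins do in init().
    static int const arch = tensorrt_llm::common::getSMVersion();
    static bool const cudaKernelEnabled = wo::is_supported(arch, wo::KernelType::FP16Int4PerChannel);

    constexpr int SMALL_M_FAST_PATH = 5; // assumption: the real threshold is defined in the plugin header

    if (m < SMALL_M_FAST_PATH && cudaKernelEnabled)
    {
        // Small-M fast path: fused CUDA kernel. Argument order follows the Params
        // construction in the diff: act, act_scale, weight, scales, zeros, bias,
        // out, alpha, m, n, k, group_size, kernel type.
        wo::Params params(act, nullptr, weight, scales, nullptr, nullptr, out, 1.f, m, n, k,
            /* group_size */ 0, wo::KernelType::FP16Int4PerChannel);
        wo::kernel_launcher(arch, params, stream);
    }
    else
    {
        // Large-M path: the plugin falls back to the cutlass fpA-intB GEMM runner
        // (m_weightOnlyGemmRunner->gemm(...)), omitted here.
    }
}
```
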
diff --git a/cpp/tensorrt_llm/pybind/CMakeLists.txt b/cpp/tensorrt_llm/pybind/CMakeLists.txt index 9f3fd620d..c91fda7b5 100644 --- a/cpp/tensorrt_llm/pybind/CMakeLists.txt +++ b/cpp/tensorrt_llm/pybind/CMakeLists.txt @@ -30,6 +30,8 @@ set(SRCS batch_manager/llmRequest.cpp batch_manager/inferenceRequest.cpp batch_manager/namedTensor.cpp + executor/bindings.cpp + executor/executor.cpp runtime/generationInput.cpp runtime/generationOutput.cpp) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp b/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp index 96bd90fff..f8abd8780 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp @@ -142,6 +142,7 @@ void InferenceRequest::initBindings(py::module_& m) &InferenceRequest::setPromptEmbeddingTable) .def_property( "prompt_vocab_size", &InferenceRequest::getPromptVocabSizeUnchecked, &InferenceRequest::setPromptVocabSize) + .def_property("lora_task_id", &InferenceRequest::getLoraTaskId, &InferenceRequest::setLoraTaskId) .def_property("lora_weights", &InferenceRequest::getLoraWeightsUnchecked, &InferenceRequest::setLoraWeights) .def_property("lora_config", &InferenceRequest::getLoraConfigUnchecked, &InferenceRequest::setLoraConfig) .def_property("is_streaming", &InferenceRequest::isStreaming, &InferenceRequest::setIsStreaming) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.h b/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.h index 1f82528c5..98ae79b34 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.h +++ b/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.h @@ -24,6 +24,7 @@ #include #include +#include #include #include diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp index 9b7566252..7a2fcc282 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp @@ -79,8 +79,8 @@ std::shared_ptr LlmRequest::toTrtLlm() const return std::make_shared(mRequestId, mMaxNewTokens, std::make_shared>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId, - embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, mPromptVocabSize, loraWeights, loraConfig, - mReturnLogProbs, mReturnContextLogits, mReturnGenerationLogits, mDraftTokens, draftLogits, + embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, mPromptVocabSize, mLoraTaskId, loraWeights, + loraConfig, mReturnLogProbs, mReturnContextLogits, mReturnGenerationLogits, mDraftTokens, draftLogits, mExcludeInputFromOutput, callbackAdapter(mLogitsPostProcessor)); } @@ -91,18 +91,19 @@ void LlmRequest::initBindings(py::module_& m) std::optional, std::optional, std::optional, std::optional, std::optional, std::optional, - std::optional, std::optional, + std::optional, std::optional, std::optional, std::optional, bool, bool, bool, std::optional, std::optional, bool, std::optional>(), py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"), py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt, py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt, py::arg("stop_words_list") = std::nullopt, py::arg("prompt_embedding_table") = std::nullopt, - py::arg("prompt_vocab_size") = std::nullopt, py::arg("lora_weights") = std::nullopt, - py::arg("lora_config") = std::nullopt, py::arg("return_log_probs") = 
false, - py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false, - py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt, - py::arg("exclude_input_from_output") = false, py::arg("logits_post_processor") = std::nullopt) + py::arg("prompt_vocab_size") = std::nullopt, py::arg("lora_task_id") = std::nullopt, + py::arg("lora_weights") = std::nullopt, py::arg("lora_config") = std::nullopt, + py::arg("return_log_probs") = false, py::arg("return_context_logits") = false, + py::arg("return_generation_logits") = false, py::arg("draft_tokens") = std::nullopt, + py::arg("draft_logits") = std::nullopt, py::arg("exclude_input_from_output") = false, + py::arg("logits_post_processor") = std::nullopt) .def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam")) .def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens) .def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos")) @@ -116,6 +117,9 @@ void LlmRequest::initBindings(py::module_& m) .def_property("max_sent_token_pos", &LlmRequest::getMaxSentTokenPos, &LlmRequest::setMaxSentTokenPos) .def_property_readonly("prompt_embedding_table", &LlmRequest::getPromptEmbeddingTable) .def_property_readonly("prompt_vocab_size", &LlmRequest::getPromptVocabSize) + .def_property_readonly("lora_task_id", &LlmRequest::getLoraTaskId) + .def_property_readonly("lora_weights", &LlmRequest::getLoraWeights) + .def_property_readonly("lora_config", &LlmRequest::getLoraConfig) .def_property_readonly("embedding_bias", &LlmRequest::getEmbeddingBias) .def_property_readonly("bad_words_list", &LlmRequest::getBadWordsList) .def_property_readonly("stop_words_list", &LlmRequest::getStopWordsList) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h index ccd9c545b..9628c854c 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h @@ -42,6 +42,7 @@ class LlmRequest : public tb::GenericLlmRequest using SizeType = Base::SizeType; using TokenIdType = Base::TokenIdType; using RequestIdType = Base::RequestIdType; + using LoraTaskIdType = Base::LoraTaskIdType; using VecLogProbs = Base::VecLogProbs; using BeamTokens = Base::BeamTokens; using VecTokens = Base::VecTokens; @@ -52,14 +53,15 @@ class LlmRequest : public tb::GenericLlmRequest std::optional padId = std::nullopt, std::optional embeddingBias = std::nullopt, std::optional badWordsList = std::nullopt, std::optional stopWordsList = std::nullopt, std::optional promptEmbeddingTable = std::nullopt, - std::optional promptVocabSize = std::nullopt, std::optional loraWeights = std::nullopt, - std::optional loraConfig = std::nullopt, bool returnLogProbs = false, - bool returnContextLogits = false, bool returnGenerationLogits = false, + std::optional promptVocabSize = std::nullopt, std::optional loraTaskId = std::nullopt, + std::optional loraWeights = std::nullopt, std::optional loraConfig = std::nullopt, + bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, std::optional draftTokens = std::nullopt, std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false, std::optional logitsPostProcessor = std::nullopt) : Base(requestId, maxNewTokens, std::make_shared>(std::move(inputTokens)), samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, - promptVocabSize, loraWeights, loraConfig, returnLogProbs, 
returnContextLogits, returnGenerationLogits, + promptVocabSize, loraTaskId, loraWeights, loraConfig, returnLogProbs, returnContextLogits, + returnGenerationLogits, draftTokens.has_value() ? std::make_shared(std::move(draftTokens.value())) : std::make_shared(), draftLogits, excludeInputFromOutput, logitsPostProcessor) diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index b5a9c8c29..063eda461 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -27,6 +27,7 @@ #include "tensorrt_llm/pybind/batch_manager/inferenceRequest.h" #include "tensorrt_llm/pybind/batch_manager/llmRequest.h" #include "tensorrt_llm/pybind/batch_manager/namedTensor.h" +#include "tensorrt_llm/pybind/executor/bindings.h" #include "tensorrt_llm/pybind/runtime/generationInput.h" #include "tensorrt_llm/pybind/runtime/generationOutput.h" #include "tensorrt_llm/pybind/utils/pathCaster.h" @@ -383,4 +384,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_property_readonly("cpu", &tr::MemoryCounters::getCpu) .def_property_readonly("pinned", &tr::MemoryCounters::getPinned) .def_property_readonly("uvm", &tr::MemoryCounters::getUVM); + + // Create submodule for executor bindings. + py::module_ executor_submodule = m.def_submodule("executor", "Executor bindings"); + tensorrt_llm::pybind::executor::InitBindings(executor_submodule); } diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp new file mode 100644 index 000000000..1f96c457e --- /dev/null +++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp @@ -0,0 +1,297 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include "bindings.h" +#include "executor.h" +#include "tensorCaster.h" + +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/tensor.h" +#include "tensorrt_llm/executor/types.h" + +namespace py = pybind11; +namespace tle = tensorrt_llm::executor; +using Tensor = tle::Tensor; +using SizeType = tle::SizeType; +using FloatType = tle::FloatType; +using VecTokens = tle::VecTokens; +using IdType = tle::IdType; + +namespace tensorrt_llm::pybind::executor +{ + +void InitBindings(pybind11::module_& m) +{ + py::enum_(m, "ModelType").value("DECODER_ONLY", tle::ModelType::kDECODER_ONLY); + + py::enum_(m, "BatchingType") + .value("STATIC", tle::BatchingType::kSTATIC) + .value("INFLIGHT", tle::BatchingType::kINFLIGHT); + + py::enum_(m, "SchedulerPolicy") + .value("MAX_UTILIZATION", tle::SchedulerPolicy::kMAX_UTILIZATION) + .value("GUARANTEED_NO_EVICT", tle::SchedulerPolicy::kGUARANTEED_NO_EVICT); + + py::enum_(m, "CommunicationType").value("MPI", tle::CommunicationType::kMPI); + + py::enum_(m, "CommunicationMode").value("LEADER", tle::CommunicationMode::kLEADER); + + py::class_(m, "KvCacheStats") + .def(py::init<>()) + .def_readwrite("max_num_blocks", &tle::KvCacheStats::maxNumBlocks) + .def_readwrite("free_num_blocks", &tle::KvCacheStats::freeNumBlocks) + .def_readwrite("used_num_blocks", &tle::KvCacheStats::usedNumBlocks) + .def_readwrite("tokens_per_block", &tle::KvCacheStats::tokensPerBlock); + + py::class_(m, "StaticBatchingStats") + .def(py::init<>()) + .def_readwrite("num_scheduled_requests", &tle::StaticBatchingStats::numScheduledRequests) + .def_readwrite("num_context_requests", &tle::StaticBatchingStats::numContextRequests) + .def_readwrite("num_ctx_tokens", &tle::StaticBatchingStats::numCtxTokens) + .def_readwrite("num_gen_tokens", &tle::StaticBatchingStats::numGenTokens) + .def_readwrite("empty_gen_slots", &tle::StaticBatchingStats::emptyGenSlots); + + py::class_(m, "InflightBatchingStats") + .def(py::init<>()) + .def_readwrite("num_scheduled_requests", &tle::InflightBatchingStats::numScheduledRequests) + .def_readwrite("num_context_requests", &tle::InflightBatchingStats::numContextRequests) + .def_readwrite("num_gen_requests", &tle::InflightBatchingStats::numGenRequests) + .def_readwrite("num_paused_requests", &tle::InflightBatchingStats::numPausedRequests) + .def_readwrite("num_ctx_tokens", &tle::InflightBatchingStats::numCtxTokens) + .def_readwrite("micro_batch_id", &tle::InflightBatchingStats::microBatchId); + + py::class_(m, "IterationStats") + .def(py::init<>()) + .def_readwrite("timestamp", &tle::IterationStats::timestamp) + .def_readwrite("iter", &tle::IterationStats::iter) + .def_readwrite("num_active_requests", &tle::IterationStats::numActiveRequests) + .def_readwrite("max_num_active_requests", &tle::IterationStats::maxNumActiveRequests) + .def_readwrite("gpu_mem_usage", &tle::IterationStats::gpuMemUsage) + .def_readwrite("cpu_mem_usage", &tle::IterationStats::cpuMemUsage) + .def_readwrite("pinned_mem_usage", &tle::IterationStats::pinnedMemUsage) + .def_readwrite("kv_cache_stats", &tle::IterationStats::kvCacheStats) + .def_readwrite("static_batching_stats", &tle::IterationStats::staticBatchingStats) + .def_readwrite("inflight_batching_stats", &tle::IterationStats::inflightBatchingStats); + + py::enum_(m, "RequestStage") + .value("QUEUED", tle::RequestStage::kQUEUED) + .value("CONTEXT_IN_PROGRESS", tle::RequestStage::kCONTEXT_IN_PROGRESS) + .value("GENERATION_IN_PROGRESS", tle::RequestStage::kGENERATION_IN_PROGRESS) + 
.value("GENERATION_COMPLETE", tle::RequestStage::kGENERATION_COMPLETE); + + py::class_(m, "RequestStats") + .def(py::init<>()) + .def_readwrite("id", &tle::RequestStats::id) + .def_readwrite("stage", &tle::RequestStats::stage) + .def_readwrite("context_prefill_position", &tle::RequestStats::contextPrefillPosition) + .def_readwrite("num_generated_tokens", &tle::RequestStats::numGeneratedTokens) + .def_readwrite("scheduled", &tle::RequestStats::scheduled) + .def_readwrite("paused", &tle::RequestStats::paused); + + py::class_(m, "RequestStatsPerIteration") + .def(py::init<>()) + .def_readwrite("iter", &tle::RequestStatsPerIteration::iter) + .def_readwrite("request_stats", &tle::RequestStatsPerIteration::requestStats); + + py::class_(m, "SamplingConfig") + .def(py::init, std::optional, std::optional, + std::optional, std::optional, std::optional, + std::optional, std::optional, std::optional, std::optional, + std::optional, std::optional, std::optional, + std::optional>(), + py::arg("beam_width") = 1, py::arg("top_k") = py::none(), py::arg("top_p") = py::none(), + py::arg("top_p_min") = py::none(), py::arg("top_p_reset_ids") = py::none(), + py::arg("top_p_decay") = py::none(), py::arg("random_seed") = py::none(), + py::arg("temperature") = py::none(), py::arg("min_length") = py::none(), + py::arg("beam_search_diversity_rate") = py::none(), py::arg("repetition_penalty") = py::none(), + py::arg("presence_penalty") = py::none(), py::arg("frequency_penalty") = py::none(), + py::arg("length_penalty") = py::none(), py::arg("early_stopping") = py::none()) + .def_property_readonly("beam_width", &tle::SamplingConfig::getBeamWidth) + .def_property_readonly("top_k", &tle::SamplingConfig::getTopK) + .def_property_readonly("top_p", &tle::SamplingConfig::getTopP) + .def_property_readonly("top_p_min", &tle::SamplingConfig::getTopPMin) + .def_property_readonly("top_p_reset_ids", &tle::SamplingConfig::getTopPResetIds) + .def_property_readonly("top_p_decay", &tle::SamplingConfig::getTopPDecay) + .def_property_readonly("random_seed", &tle::SamplingConfig::getRandomSeed) + .def_property_readonly("temperature", &tle::SamplingConfig::getTemperature) + .def_property_readonly("min_length", &tle::SamplingConfig::getMinLength) + .def_property_readonly("beam_search_diversity_rate", &tle::SamplingConfig::getBeamSearchDiversityRate) + .def_property_readonly("repetition_penalty", &tle::SamplingConfig::getRepetitionPenalty) + .def_property_readonly("presence_penalty", &tle::SamplingConfig::getPresencePenalty) + .def_property_readonly("frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty) + .def_property_readonly("length_penalty", &tle::SamplingConfig::getLengthPenalty) + .def_property_readonly("early_stopping", &tle::SamplingConfig::getEarlyStopping); + + py::class_(m, "OutputConfig") + .def(py::init(), py::arg("return_log_probs") = false, + py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false, + py::arg("exclude_input_from_output") = false) + .def_readwrite("return_log_probs", &tle::OutputConfig::returnLogProbs) + .def_readwrite("return_context_logits", &tle::OutputConfig::returnContextLogits) + .def_readwrite("return_generation_logits", &tle::OutputConfig::returnGenerationLogits) + .def_readwrite("exclude_input_from_output", &tle::OutputConfig::excludeInputFromOutput); + + py::class_(m, "SpeculativeDecodingConfig") + .def(py::init, std::optional>(), py::arg("tokens"), + py::arg("logits") = py::none(), py::arg("acceptance_threshold") = py::none()) + .def_property_readonly("tokens", 
&tle::SpeculativeDecodingConfig::getTokens) + .def_property_readonly("logits", &tle::SpeculativeDecodingConfig::getLogits) + .def_property_readonly("acceptance_threshold", &tle::SpeculativeDecodingConfig::getAcceptanceThreshold); + + py::class_(m, "PromptTuningConfig") + .def(py::init(), py::arg("embedding_table")) + .def_property_readonly("embedding_table", &tle::PromptTuningConfig::getEmbeddingTable); + + py::class_(m, "LoraConfig") + .def(py::init, std::optional>(), py::arg("task_id"), + py::arg("weights") = py::none(), py::arg("config") = py::none()) + .def_property_readonly("task_id", &tle::LoraConfig::getTaskId) + .def_property_readonly("weights", &tle::LoraConfig::getWeights) + .def_property_readonly("config", &tle::LoraConfig::getConfig); + + py::class_(m, "Request") + .def(py::init, + std::optional, std::optional>, std::optional>, + std::optional, std::optional, + std::optional, std::optional>(), + py::arg("input_token_ids"), py::arg("max_new_tokens"), py::arg("streaming") = false, + py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"), + py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"), py::arg("end_id") = py::none(), + py::arg("pad_id") = py::none(), py::arg("bad_words") = py::none(), py::arg("stop_words") = py::none(), + py::arg("embedding_bias") = py::none(), py::arg("speculative_decoding_config") = py::none(), + py::arg("prompt_tuning_config") = py::none(), py::arg("lora_config") = py::none()) + .def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds) + .def_property_readonly("max_new_tokens", &tle::Request::getMaxNewTokens) + .def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) + .def_property("sampling_config", &tle::Request::getSamplingConfig, &tle::Request::setSamplingConfig) + .def_property("output_config", &tle::Request::getOutputConfig, &tle::Request::setOutputConfig) + .def_property("end_id", &tle::Request::getEndId, &tle::Request::setEndId) + .def_property("pad_id", &tle::Request::getPadId, &tle::Request::setPadId) + .def_property("bad_words", &tle::Request::getBadWords, &tle::Request::setBadWords) + .def_property("stop_words", &tle::Request::getStopWords, &tle::Request::setStopWords) + .def_property("embedding_bias", &tle::Request::getEmbeddingBias, &tle::Request::setEmbeddingBias) + .def_property("speculative_decoding_config", &tle::Request::getSpeculativeDecodingConfig, + &tle::Request::setSpeculativeDecodingConfig) + .def_property( + "prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig) + .def_property("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig); + + py::class_(m, "Result") + .def(py::init<>()) + .def_readwrite("is_final", &tle::Result::isFinal) + .def_readwrite("output_token_ids", &tle::Result::outputTokenIds) + .def_readwrite("cum_log_probs", &tle::Result::cumLogProbs) + .def_readwrite("log_probs", &tle::Result::logProbs) + .def_readwrite("context_logits", &tle::Result::contextLogits) + .def_readwrite("generation_logits", &tle::Result::generationLogits); + + py::class_(m, "Response") + .def(py::init(), py::arg("request_id"), py::arg("error_msg")) + .def(py::init(), py::arg("request_id"), py::arg("result")) + .def_property_readonly("request_id", &tle::Response::getRequestId) + .def("has_error", &tle::Response::hasError) + .def_property_readonly("error_msg", &tle::Response::getErrorMsg) + .def_property_readonly("result", &tle::Response::getResult); + + py::class_(m, "SchedulerConfig") + .def(py::init(), 
+ py::arg_v("policy", tle::SchedulerPolicy::kGUARANTEED_NO_EVICT, "SchedulerPolicy.GUARANTEED_NO_EVICT")) + .def_property_readonly("policy", &tle::SchedulerConfig::getPolicy); + + py::class_(m, "KvCacheConfig") + .def(py::init, std::optional, std::optional, + std::optional>(), + py::arg("enable_block_reuse") = false, py::arg("max_tokens") = py::none(), + py::arg("max_attention_window") = py::none(), py::arg("sink_token_length") = py::none(), + py::arg("free_gpu_memory_fraction") = py::none()) + .def_property_readonly("enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse) + .def_property_readonly("max_tokens", &tle::KvCacheConfig::getMaxTokens) + .def_property_readonly("max_attention_window", &tle::KvCacheConfig::getMaxAttentionWindow) + .def_property_readonly("sink_token_length", &tle::KvCacheConfig::getSinkTokenLength) + .def_property_readonly("free_gpu_memory_fraction", &tle::KvCacheConfig::getFreeGpuMemoryFraction); + + py::class_(m, "ParallelConfig") + .def(py::init>, + std::optional>>(), + py::arg_v("communication_type", tle::CommunicationType::kMPI, "CommunicationType.MPI"), + py::arg_v("communication_mode", tle::CommunicationMode::kLEADER, "CommunicationMode.LEADER"), + py::arg("device_ids") = py::none(), py::arg("participant_ids") = py::none()) + .def_property("communication_type", &tle::ParallelConfig::getCommunicationType, + &tle::ParallelConfig::setCommunicationType) + .def_property("communication_mode", &tle::ParallelConfig::getCommunicationMode, + &tle::ParallelConfig::setCommunicationMode) + .def_property("device_ids", &tle::ParallelConfig::getDeviceIds, &tle::ParallelConfig::setDeviceIds) + .def_property( + "participant_ids", &tle::ParallelConfig::getParticipantIds, &tle::ParallelConfig::setParticipantIds); + + py::class_(m, "PeftCacheConfig") + .def(py::init, std::optional>(), + py::arg("num_host_module_layer") = 0, py::arg("num_device_module_layer") = 0, + py::arg("optimal_adapter_size") = 8, py::arg("max_adapter_size") = 64, py::arg("num_put_workers") = 1, + py::arg("num_ensure_workers") = 1, py::arg("num_copy_streams") = 1, + py::arg("max_pages_per_block_host") = 24, py::arg("max_pages_per_block_device") = 8, + py::arg("device_cache_percent") = py::none(), py::arg("host_cache_size") = py::none()) + .def_property_readonly("num_host_module_layer", &tle::PeftCacheConfig::getNumHostModuleLayer) + .def_property_readonly("num_device_module_layer", &tle::PeftCacheConfig::getNumDeviceModuleLayer) + .def_property_readonly("optimal_adapter_size", &tle::PeftCacheConfig::getOptimalAdapterSize) + .def_property_readonly("max_adapter_size", &tle::PeftCacheConfig::getMaxAdapterSize) + .def_property_readonly("num_put_workers", &tle::PeftCacheConfig::getNumPutWorkers) + .def_property_readonly("num_ensure_workers", &tle::PeftCacheConfig::getNumEnsureWorkers) + .def_property_readonly("num_copy_streams", &tle::PeftCacheConfig::getNumCopyStreams) + .def_property_readonly("max_pages_per_block_host", &tle::PeftCacheConfig::getMaxPagesPerBlockHost) + .def_property_readonly("max_pages_per_block_device", &tle::PeftCacheConfig::getMaxPagesPerBlockDevice) + .def_property_readonly("device_cache_percent", &tle::PeftCacheConfig::getDeviceCachePercent) + .def_property_readonly("host_cache_size", &tle::PeftCacheConfig::getHostCacheSize); + + py::class_(m, "ExecutorConfig") + .def(py::init, tle::PeftCacheConfig>(), + py::arg("max_beam_width") = 1, py::arg_v("scheduler_config", tle::SchedulerConfig(), "SchedulerConfig()"), + py::arg_v("kv_cache_config", tle::KvCacheConfig(), "KvCacheConfig()"), + 
py::arg("enable_chunked_context") = false, py::arg("normalize_log_probs") = true, + py::arg("iter_stats_max_iterations") = tle::kDefaultIterStatsMaxIterations, + py::arg("request_stats_max_iterations") = tle::kDefaultRequestStatsMaxIterations, + py::arg_v("batching_type", tle::BatchingType::kINFLIGHT, "BatchingType.INFLIGHT"), + py::arg("parallel_config") = py::none(), + py::arg_v("peft_cache_config", tle::PeftCacheConfig(), "PeftCacheConfig()")) + .def_property("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth) + .def_property( + "scheduler_config", &tle::ExecutorConfig::getSchedulerConfig, &tle::ExecutorConfig::setSchedulerConfig) + .def_property("kv_cache_config", &tle::ExecutorConfig::getKvCacheConfig, &tle::ExecutorConfig::setKvCacheConfig) + .def_property("enable_chunked_context", &tle::ExecutorConfig::getEnableChunkedContext, + &tle::ExecutorConfig::setEnableChunkedContext) + .def_property("normalize_log_probs", &tle::ExecutorConfig::getNormalizeLogProbs, + &tle::ExecutorConfig::setNormalizeLogProbs) + .def_property("iter_stats_max_iterations", &tle::ExecutorConfig::getIterStatsMaxIterations, + &tle::ExecutorConfig::setIterStatsMaxIterations) + .def_property("request_stats_max_iterations", &tle::ExecutorConfig::getRequestStatsMaxIterations, + &tle::ExecutorConfig::setRequestStatsMaxIterations) + .def_property("batching_type", &tle::ExecutorConfig::getBatchingType, &tle::ExecutorConfig::setBatchingType) + .def_property( + "parallel_config", &tle::ExecutorConfig::getParallelConfig, &tle::ExecutorConfig::setParallelConfig) + .def_property( + "peft_cache_config", &tle::ExecutorConfig::getPeftCacheConfig, &tle::ExecutorConfig::setPeftCacheConfig); + + tensorrt_llm::pybind::executor::Executor::initBindings(m); +} + +} // namespace tensorrt_llm::pybind::executor diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.h b/cpp/tensorrt_llm/pybind/executor/bindings.h new file mode 100644 index 000000000..7a686b19b --- /dev/null +++ b/cpp/tensorrt_llm/pybind/executor/bindings.h @@ -0,0 +1,27 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace tensorrt_llm::pybind::executor +{ + +// Register bindings for executor API. +void InitBindings(pybind11::module_& m); + +} // namespace tensorrt_llm::pybind::executor diff --git a/cpp/tensorrt_llm/pybind/executor/executor.cpp b/cpp/tensorrt_llm/pybind/executor/executor.cpp new file mode 100644 index 000000000..ccde0f6fe --- /dev/null +++ b/cpp/tensorrt_llm/pybind/executor/executor.cpp @@ -0,0 +1,89 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "executor.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/pybind/utils/pathCaster.h" + +#include +#include +#include + +namespace py = pybind11; +namespace tle = tensorrt_llm::executor; + +namespace tensorrt_llm::pybind::executor +{ + +Executor::Executor( + std::filesystem::path const& modelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig) +{ + mExecutor = std::make_unique(modelPath, modelType, executorConfig); +} + +Executor::Executor(std::string const& engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType, + tle::ExecutorConfig const& executorConfig) +{ + mExecutor = std::make_unique( + std::vector(engineBuffer.begin(), engineBuffer.end()), jsonConfigStr, modelType, executorConfig); +} + +py::object Executor::enter() +{ + TLLM_CHECK(static_cast(mExecutor)); + return py::cast(this); +} + +void Executor::exit( + [[maybe_unused]] py::handle type, [[maybe_unused]] py::handle value, [[maybe_unused]] py::handle traceback) +{ + shutdown(); +} + +void Executor::shutdown() +{ + // NOTE: we must release the GIL here. Executor has spawned a thread for the execution loop. That thread must be + // able to do forward progress for the shutdown process to succeed. It takes the GIL during its callbacks, so + // we release it now. Note that we shouldn't do anything related to python objects after that. + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + py::gil_scoped_release release; + mExecutor->shutdown(); + mExecutor = nullptr; + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); +} + +void Executor::initBindings(py::module_& m) +{ + py::class_(m, "Executor") + .def(py::init(), py::arg("model_path"), + py::arg("model_type"), py::arg("executor_config")) + .def(py::init(), + py::arg("engine_buffer"), py::arg("json_config_str"), py::arg("model_type"), py::arg("executor_config")) + .def("shutdown", &Executor::shutdown) + .def("__enter__", &Executor::enter) + .def("__exit__", &Executor::exit) + .def("enqueue_request", &Executor::enqueueRequest, py::arg("request")) + .def("enqueue_requests", &Executor::enqueueRequests, py::arg("requests")) + .def("await_responses", &Executor::awaitResponses, py::arg("id") = py::none(), py::arg("timeout") = py::none()) + .def("get_num_responses_ready", &Executor::getNumResponsesReady, py::arg("id") = py::none()) + .def("cancel_request", &Executor::cancelRequest, py::arg("id") = py::none()) + .def("get_latest_iteration_stats", &Executor::getLatestIterationStats) + .def("get_latest_request_stats", &Executor::getLatestRequestStats); +} + +} // namespace tensorrt_llm::pybind::executor diff --git a/cpp/tensorrt_llm/pybind/executor/executor.h b/cpp/tensorrt_llm/pybind/executor/executor.h new file mode 100644 index 000000000..ff749a9b9 --- /dev/null +++ b/cpp/tensorrt_llm/pybind/executor/executor.h @@ -0,0 +1,85 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/types.h" +#include + +namespace tle = tensorrt_llm::executor; + +namespace tensorrt_llm::pybind::executor +{ + +class Executor +{ +public: + Executor( + std::filesystem::path const& modelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig); + + Executor(std::string const& engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType, + tle::ExecutorConfig const& executorConfig); + + pybind11::object enter(); + void exit([[maybe_unused]] pybind11::handle type, [[maybe_unused]] pybind11::handle value, + [[maybe_unused]] pybind11::handle traceback); + void shutdown(); + + tle::IdType enqueueRequest(tle::Request request) + { + return mExecutor->enqueueRequest(std::move(request)); + } + + std::vector enqueueRequests(std::vector requests) + { + return mExecutor->enqueueRequests(std::move(requests)); + } + + std::vector awaitResponses( + std::optional id = std::nullopt, std::optional timeout = std::nullopt) + { + + return mExecutor->awaitResponses(id, timeout); + } + + tle::SizeType getNumResponsesReady(std::optional id = std::nullopt) + { + return mExecutor->getNumResponsesReady(id); + } + + void cancelRequest(tle::IdType id) + { + mExecutor->cancelRequest(id); + } + + std::deque getLatestIterationStats() + { + return mExecutor->getLatestIterationStats(); + } + + std::deque getLatestRequestStats() + { + return mExecutor->getLatestRequestStats(); + } + + static void initBindings(pybind11::module_& m); + +private: + std::unique_ptr mExecutor; +}; + +} // namespace tensorrt_llm::pybind::executor diff --git a/cpp/tensorrt_llm/pybind/executor/tensorCaster.h b/cpp/tensorrt_llm/pybind/executor/tensorCaster.h new file mode 100644 index 000000000..894e0af30 --- /dev/null +++ b/cpp/tensorrt_llm/pybind/executor/tensorCaster.h @@ -0,0 +1,59 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +#include "tensorrt_llm/executor/tensor.h" +#include "tensorrt_llm/runtime/torch.h" +#include "tensorrt_llm/runtime/torchView.h" +#include + +namespace PYBIND11_NAMESPACE +{ + +namespace detail +{ +template <> +struct type_caster +{ +public: + PYBIND11_TYPE_CASTER(tensorrt_llm::executor::Tensor, _("torch.Tensor")); + + // Convert PyObject(torch.Tensor) -> tensorrt_llm::executor::Tensor + bool load(handle src, bool) + { + PyObject* obj = src.ptr(); + if (THPVariable_Check(obj)) + { + at::Tensor const& t = THPVariable_Unpack(obj); + value = tensorrt_llm::executor::detail::ofITensor(tensorrt_llm::runtime::TorchView::of(t)); + return true; + } + return false; + } + + // Convert tensorrt_llm::executor::Tensor -> PyObject(torch.Tensor) + static handle cast(tensorrt_llm::executor::Tensor const& src, return_value_policy /* policy */, handle /* parent */) + { + return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor(tensorrt_llm::executor::detail::toITensor(src))); + } +}; + +} // namespace detail +} // namespace PYBIND11_NAMESPACE diff --git a/cpp/tensorrt_llm/runtime/CMakeLists.txt b/cpp/tensorrt_llm/runtime/CMakeLists.txt index 7a90909ba..965d2accc 100644 --- a/cpp/tensorrt_llm/runtime/CMakeLists.txt +++ b/cpp/tensorrt_llm/runtime/CMakeLists.txt @@ -22,6 +22,7 @@ set(SRCS loraManager.cpp loraUtils.cpp loraModule.cpp + loraCache.cpp decodingOutput.cpp gptDecoder.cpp gptDecoderBatch.cpp diff --git a/cpp/tensorrt_llm/runtime/bufferManager.cpp b/cpp/tensorrt_llm/runtime/bufferManager.cpp index fbdc40f3c..3fa9c9d8d 100644 --- a/cpp/tensorrt_llm/runtime/bufferManager.cpp +++ b/cpp/tensorrt_llm/runtime/bufferManager.cpp @@ -27,8 +27,9 @@ using namespace tensorrt_llm::runtime; namespace tc = tensorrt_llm::common; -BufferManager::BufferManager(CudaStreamPtr stream) +BufferManager::BufferManager(CudaStreamPtr stream, bool trimPool) : mStream{std::move(stream)} + , mTrimPool{trimPool} { TLLM_CHECK_WITH_INFO(static_cast(mStream), "Undefined CUDA stream"); thread_local static std::unordered_set initializedDevices(8); diff --git a/cpp/tensorrt_llm/runtime/gptDecoder.cpp b/cpp/tensorrt_llm/runtime/gptDecoder.cpp index b3052f4d8..b6f1bf9a2 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoder.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoder.cpp @@ -37,13 +37,12 @@ GptDecoder::GptDecoder(DecodingMode const& mode, size_t maxBatchSize, size_t size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream) : mManager{stream} { - cudaDeviceProp prop; - tc::check_cuda_error(cudaGetDeviceProperties(&prop, 0)); - + int deviceId; + tc::check_cuda_error(cudaGetDevice(&deviceId)); // Get the correct device id + tc::check_cuda_error(cudaGetDeviceProperties(&mProp, deviceId)); auto allocator = std::make_shared(mManager); - mDynamicDecodeLayer = std::make_shared>( - mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream->get(), std::move(allocator), &prop); + mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream->get(), std::move(allocator), &mProp); auto constexpr nvFloatType = TRTDataType::value; mLogProbsTiled = mManager.gpu(ITensor::makeShape({static_cast(maxSequenceLength), @@ -73,7 +72,7 @@ void GptDecoder::setup(SamplingConfig const& samplingConfig, size_t batchSize if (samplingConfig.topK) { auto const& topK = samplingConfig.topK.value(); - setupParams.runtime_top_k = std::vector(std::begin(topK), std::end(topK)); + setupParams.runtime_top_k = std::vector(std::begin(topK), std::end(topK)); } setupParams.runtime_top_p = samplingConfig.topP; 
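> Note (illustrative, not part of the patch): the executor pybind bindings introduced in the preceding hunks (`cpp/tensorrt_llm/pybind/executor/bindings.cpp` and `executor.cpp`) expose the C++ `tensorrt_llm::executor` API to Python as an `executor` submodule. The sketch below shows one way those bindings could be driven from Python, assuming the compiled module is importable as `tensorrt_llm.bindings.executor`; the import path and the engine directory are assumptions for illustration only.

```python
# Minimal sketch of the new executor bindings added in this change.
# Assumptions (not guaranteed by the diff): the pybind module is importable as
# tensorrt_llm.bindings.executor, and an engine directory already exists on disk.
from tensorrt_llm.bindings import executor as trtllm_executor

config = trtllm_executor.ExecutorConfig(max_beam_width=1)

# Construct the executor from a prebuilt engine directory (hypothetical path).
executor = trtllm_executor.Executor(
    "/path/to/engine_dir",
    trtllm_executor.ModelType.DECODER_ONLY,
    config,
)

# Enqueue a single request and poll for its responses.
request = trtllm_executor.Request(input_token_ids=[1, 2, 3, 4], max_new_tokens=16)
request_id = executor.enqueue_request(request)

done = False
while not done:
    for response in executor.await_responses(request_id):
        if response.has_error():
            raise RuntimeError(response.error_msg)
        result = response.result
        done = result.is_final
        print(result.output_token_ids)

# shutdown() releases the GIL while the background execution loop winds down,
# as noted in the comment inside Executor::shutdown() above.
executor.shutdown()
```

The same object also supports use as a context manager (`__enter__`/`__exit__` are bound), so `with trtllm_executor.Executor(...) as executor:` is an alternative to calling `shutdown()` explicitly.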
diff --git a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp index ba3816660..7e552b004 100644 --- a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp +++ b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp @@ -153,7 +153,7 @@ void parsePluginConfig(GptModelConfig& modelConfig, Json const& pluginConfig) void parseLora(GptModelConfig& modelConfig, Json const& json, Json const& pluginConfig, bool engineVersionNone, SizeType tensorParallelism) { - auto const& config = engineVersionNone ? json.at("builder_config") : json.at("pretrained_config"); + auto const& config = engineVersionNone ? json.at("builder_config") : json.at("build_config").at("lora_config"); auto const loraMaxRank = parseJsonFieldOr(config, "max_lora_rank", SizeType{0}); auto const loraTargetModules = parseJsonFieldOptional>(config, "lora_target_modules"); diff --git a/cpp/tensorrt_llm/runtime/loraCache.cpp b/cpp/tensorrt_llm/runtime/loraCache.cpp new file mode 100644 index 000000000..48c7ec0ec --- /dev/null +++ b/cpp/tensorrt_llm/runtime/loraCache.cpp @@ -0,0 +1,904 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/runtime/loraCache.h" +#include "bufferManager.h" +#include "cudaEvent.h" +#include "cudaStream.h" +#include "iBuffer.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/memoryUtils.h" +#include "tensorrt_llm/runtime/loraUtils.h" +#include +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::runtime +{ + +LoraCachePageManager::LoraCachePageManager(LoraCachePageManagerConfig const& config, BufferManager const& bufferManager) + : mConfig(config) +{ + initialize(bufferManager); +} + +void LoraCachePageManager::initialize(BufferManager const& bufferManager) +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + + TLLM_LOG_DEBUG("pageConfig: " + to_string(mConfig)); + + std::size_t pageIdx = 0; + while (pageIdx < static_cast(mConfig.getTotalNumPages())) + { + auto const numLocalPages = std::min( + mConfig.getTotalNumPages() - static_cast(pageIdx), mConfig.getMaxPagesPerBlock()); + auto const blockShape = ITensor::makeShape({numLocalPages, mConfig.getSlotsPerPage(), mConfig.getPageWidth()}); + TensorPtr block = bufferManager.allocate(mConfig.getMemoryType(), blockShape, mConfig.getDataType()); + mPageBlocks.push_back(block); + for (SizeType i = 0; i < numLocalPages; ++i) + { + mFreePageIds.push_back(pageIdx); + ++pageIdx; + } + } + mIsPageFree.assign(pageIdx, 1); + + TLLM_LOG_DEBUG("%s allocated %d blocks and %d pages", __PRETTY_FUNCTION__, mPageBlocks.size(), pageIdx); + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); +} + +std::optional> LoraCachePageManager::claimPages(SizeType numPages) +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + if (numPages <= static_cast(mFreePageIds.size())) + { + std::vector outputPages{}; + 
outputPages.reserve(numPages); + for (auto it = mFreePageIds.begin(); + outputPages.size() < static_cast(numPages) && it != mFreePageIds.end(); + it = mFreePageIds.erase(it)) + { + mIsPageFree.at(*it) = 0; + outputPages.push_back(*it); + } + return std::make_optional(std::move(outputPages)); + } + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); + return std::nullopt; +} + +SizeType LoraCachePageManager::numAvailablePages() const +{ + return static_cast(mFreePageIds.size()); +} + +void LoraCachePageManager::releasePages(std::vector const& pageIds) +{ + for (auto pageId : pageIds) + { + if (pageId >= mIsPageFree.size() || mIsPageFree[pageId]) + { + TLLM_LOG_WARNING("Attempted to release already free lora cache page"); + } + else + { + mFreePageIds.push_front(pageId); + mIsPageFree.at(pageId) = 1; + } + } +} + +ITensor::SharedConstPtr LoraCachePageManager::blockPtr(SizeType blockIdx) const +{ + return mPageBlocks.at(blockIdx); +} + +ITensor::SharedConstPtr LoraCachePageManager::pagePtr(std::size_t pageIdx) const +{ + auto blockIdx = pageIdx / mConfig.getMaxPagesPerBlock(); + auto blockPageIdx = pageIdx % mConfig.getMaxPagesPerBlock(); + + return ITensor::view(ITensor::slice(mPageBlocks.at(blockIdx), blockPageIdx, 1), + ITensor::makeShape({mConfig.getSlotsPerPage(), mConfig.getPageWidth()})); +} + +ITensor::SharedPtr LoraCachePageManager::mutablePagePtr(std::size_t pageIdx) +{ + auto blockIdx = pageIdx / mConfig.getMaxPagesPerBlock(); + auto blockPageIdx = pageIdx % mConfig.getMaxPagesPerBlock(); + + return ITensor::view(ITensor::slice(mPageBlocks.at(blockIdx), blockPageIdx, 1), + ITensor::makeShape({mConfig.getSlotsPerPage(), mConfig.getPageWidth()})); +} + +void LoraCache::put(TaskIdType taskId, TensorPtr sourceWeights, TensorPtr sourceConfig, bool load) +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + + auto taskValuePtr = [&]() -> std::optional + { + std::lock_guard cacheLock(mCacheMutex); + if (kVALUE_STATUS_MISSING != getStatus(taskId)) + { + bumpTaskInProgress(taskId); + return std::nullopt; + } + + mInProgressTasks.push_front(taskId); + TaskValuePtr cacheV = std::make_shared(std::vector{}, TaskLayerModuleConfigListPtr(), + mInProgressTasks.begin(), true, false, false, true); + mCacheMap.try_emplace(taskId, std::move(cacheV)); + return mCacheMap.at(taskId); + }(); + if (!taskValuePtr) + { + return; + } + auto taskValue = taskValuePtr.value(); + + TensorPtr config = sourceConfig->getShape().nbDims == 2 + ? sourceConfig + : ITensor::view( + sourceConfig, ITensor::makeShape({sourceConfig->getShape().d[1], sourceConfig->getShape().d[2]})); + + TensorPtr weights = sourceWeights->getShape().nbDims == 2 + ? 
sourceWeights + : ITensor::view( + sourceWeights, ITensor::makeShape({sourceWeights->getShape().d[1], sourceWeights->getShape().d[2]})); + + auto neededPages = determineNumPages(config); + std::vector pageIds{}; + try + { + pageIds = claimPagesWithEvict(neededPages); + } + catch (std::runtime_error& e) + { + std::lock_guard lk(mCacheMutex); + mInProgressTasks.erase(taskValue->it); + mCacheMap.erase(taskId); + throw e; + } + + taskValue->pageIds = std::move(pageIds); + { + std::lock_guard lk(mCacheMutex); + taskValue->loadInProgress = false; + } + + if (load) + { + loadWeights(*taskValue, weights, config); + } + + bool isDone; + { + std::lock_guard lk(mCacheMutex); + isDone = taskValue->done; + } + if (isDone) + { + markTaskDone(taskId); + } + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); +} + +void LoraCache::loadWeights(TaskIdType taskId, TensorPtr sourceWeights, TensorPtr sourceConfig) +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + auto taskValuePtr = [&]() -> std::optional + { + std::lock_guard cacheLock(mCacheMutex); + auto taskStatus = getStatus(taskId); + if (kVALUE_STATUS_MISSING == taskStatus) + { + throw std::runtime_error("task " + std::to_string(taskId) + " has not been added to cache. call put first"); + } + else if (kVALUE_STATUS_LOADED == taskStatus) + { + return std::nullopt; + } + + auto taskValue = mCacheMap.at(taskId); + if (taskValue->loadInProgress) + { + return std::nullopt; + } + taskValue->loadInProgress = true; + return taskValue; + }(); + if (!taskValuePtr) + { + return; + } + auto taskValue = taskValuePtr.value(); + + TensorPtr config = sourceConfig->getShape().nbDims == 2 + ? sourceConfig + : ITensor::view( + sourceConfig, ITensor::makeShape({sourceConfig->getShape().d[1], sourceConfig->getShape().d[2]})); + + TensorPtr weights = sourceWeights->getShape().nbDims == 2 + ? 
sourceWeights + : ITensor::view( + sourceWeights, ITensor::makeShape({sourceWeights->getShape().d[1], sourceWeights->getShape().d[2]})); + + loadWeights(*taskValue, weights, config); + + bool isDone; + { + std::lock_guard lk(mCacheMutex); + isDone = taskValue->done; + } + if (isDone) + { + markTaskDone(taskId); + } + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); +} + +void LoraCache::loadWeights(TaskValue& taskValue, TensorPtr weights, TensorPtr config) +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + std::vector pagePtrs{}; + pagePtrs.reserve(taskValue.pageIds.size()); + for (auto id : taskValue.pageIds) + { + pagePtrs.push_back(mCachePageManager->mutablePagePtr(id)); + } + + taskValue.configs = std::make_shared>(copyToPages( + weights, config, mModelConfig, mWorldConfig, mModuleIdToModule, *mBufferManager, pagePtrs, taskValue.pageIds)); + { + std::lock_guard lk(mCacheMutex); + taskValue.loadInProgress = false; + taskValue.loaded = true; + } + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); +} + +std::vector LoraCache::claimPagesWithEvict(SizeType numPages) +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + TLLM_LOG_DEBUG("trying to claim " + std::to_string(numPages)); + std::lock_guard pageLock(mPagesMutex); + auto const availablePages = mCachePageManager->numAvailablePages(); + if (numPages <= availablePages) + { + auto pageIds = mCachePageManager->claimPages(numPages); + TLLM_CHECK(pageIds.has_value()); + return pageIds.value(); + } + + std::lock_guard cacheLock(mCacheMutex); + std::vector pageIdsToEvict; + std::vector taskIdsToEvict; + auto neededPages = numPages - availablePages; + auto it = mDoneTasks.rbegin(); + for (auto it = mDoneTasks.rbegin(); it != mDoneTasks.rend() && neededPages > 0; it = std::next(it)) + { + auto const taskId = *it; + taskIdsToEvict.push_back(taskId); + auto const& taskValue = *(mCacheMap.at(taskId)); + pageIdsToEvict.insert(pageIdsToEvict.end(), taskValue.pageIds.begin(), taskValue.pageIds.end()); + neededPages -= taskValue.pageIds.size(); + } + if (it == mDoneTasks.rend()) + { + TLLM_THROW("Cache is full. 
There are no done tasks to evict"); + } + + TLLM_LOG_DEBUG("evicting " + std::to_string(taskIdsToEvict.size())); + for (size_t i = 0; i < taskIdsToEvict.size(); ++i) + { + + TLLM_LOG_DEBUG("evicting taskId" + std::to_string(taskIdsToEvict.at(i))); + mDoneTasks.pop_back(); + mCacheMap.erase(taskIdsToEvict.at(i)); + } + mCachePageManager->releasePages(pageIdsToEvict); + auto pageIds = mCachePageManager->claimPages(numPages); + TLLM_CHECK(pageIds.has_value()); + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); + return pageIds.value(); +} + +void LoraCache::markTaskDone(TaskIdType taskId) +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + TLLM_LOG_DEBUG("markTaskDone " + std::to_string(taskId)); + std::lock_guard lock(mCacheMutex); + if (mCacheMap.find(taskId) == mCacheMap.end()) + { + return; + } + auto& taskValue = *(mCacheMap.at(taskId)); + bool inProgress = taskValue.inProgress; + bool loaded = taskValue.loaded; + if (inProgress) + { + if (loaded) + { + mInProgressTasks.erase(taskValue.it); + mDoneTasks.push_front(taskId); + taskValue.it = mDoneTasks.begin(); + taskValue.inProgress = false; + } + } + taskValue.done = true; + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); +} + +void LoraCache::markAllDone() +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + std::lock_guard lock(mCacheMutex); + for (auto it = mInProgressTasks.rbegin(), nit = it; it != mInProgressTasks.rend(); it = nit) + { + nit = std::next(it); + auto taskId = *it; + auto& taskValue = *(mCacheMap.at(*it)); + bool inProgress = taskValue.inProgress; + bool loaded = taskValue.loaded; + if (inProgress && loaded) + { + nit = decltype(it){mInProgressTasks.erase(taskValue.it)}; + mDoneTasks.push_front(taskId); + taskValue.it = mDoneTasks.begin(); + taskValue.inProgress = false; + } + taskValue.done = true; + } + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); +} + +std::shared_ptr> LoraCache::get(TaskIdType taskId) +{ + std::lock_guard lock(mCacheMutex); + if (kVALUE_STATUS_LOADED != getStatus(taskId)) + { + throw std::runtime_error("taskid not loaded"); + } + + bumpTaskInProgress(taskId); + return mCacheMap.at(taskId)->configs; +} + +void LoraCache::bump(TaskIdType taskId) +{ + std::lock_guard lk(mCacheMutex); + bumpTaskInProgress(taskId); +} + +void LoraCache::bumpTaskInProgress(TaskIdType taskId) +{ + auto it = mCacheMap.find(taskId); + if (it != mCacheMap.end()) + { + auto& taskValue = *(it->second); + if (taskValue.inProgress) + { + mInProgressTasks.erase(taskValue.it); + } + else + { + mDoneTasks.erase(taskValue.it); + } + mInProgressTasks.push_front(taskId); + taskValue.it = mInProgressTasks.begin(); + taskValue.inProgress = true; + taskValue.done = false; + } +} + +LoraCache::ValueStatus LoraCache::getStatus(TaskIdType taskId) const +{ + auto it = mCacheMap.find(taskId); + if (it != mCacheMap.end()) + { + return it->second->loaded ? 
kVALUE_STATUS_LOADED : kVALUE_STATUS_PROCESSING; + } + return kVALUE_STATUS_MISSING; +} + +SizeType LoraCache::determineNumPages(TaskIdType taskId) const +{ + std::lock_guard lk(mCacheMutex); + if (kVALUE_STATUS_MISSING == getStatus(taskId)) + { + throw std::runtime_error("task " + std::to_string(taskId) + " not found in cache call put first"); + } + + return mCacheMap.at(taskId)->pageIds.size(); +} + +SizeType LoraCache::determineNumPages(TensorPtr loraConfig) const +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + auto const localNumLayers = mModelConfig.getNbLayers(mWorldConfig.getPipelineParallelism()); + auto const firstLayerId = mWorldConfig.getPipelineParallelRank() * localNumLayers; + auto const lastLayerId = firstLayerId + localNumLayers; + + SizeType currPage = 0; + SizeType currSlot = 0; + SizeType const slotsPerPage = mPageManagerConfig.getSlotsPerPage(); + SizeType const pageWidth = mPageManagerConfig.getPageWidth(); + for (SizeType row = 0; row < loraConfig->getShape().d[0]; ++row) + { + auto const rowPtr = bufferCast(*ITensor::slice(loraConfig, row, 1)); + auto const layerId = rowPtr[lora::kLORA_CONFIG_LAYER_OFF]; + if (layerId >= firstLayerId && layerId < lastLayerId) + { + auto const adapterSize = rowPtr[lora::kLORA_CONFIG_ADAPTER_SIZE_OFF]; + auto const& module = mModuleIdToModule.at(rowPtr[lora::kLORA_CONFIG_MODULE_OFF]); + auto const localSize = module.localInOutSize(adapterSize, mWorldConfig.getTensorParallelism()); + auto const numSlots = common::ceilDiv(localSize, pageWidth); + if (numSlots + currSlot > slotsPerPage) + { + currSlot = 0; + ++currPage; + } + + currSlot += numSlots; + } + } + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + return currPage + 1; +} + +LoraCache::LoraCache(LoraCachePageManagerConfig const& pageManagerConfig, GptModelConfig const& modelConfig, + WorldConfig const& worldConfig, BufferManager const& bufferManager) + : mPageManagerConfig(pageManagerConfig) + , mModelConfig(modelConfig) + , mWorldConfig(worldConfig) +{ + mCachePageManager = std::make_unique(mPageManagerConfig, bufferManager); + + auto modules = modelConfig.getLoraModules(); + for (auto const& m : modules) + { + mModuleIdToModule[m.value()] = m; + } + + mBufferManager = std::make_unique(std::make_shared()); + + for (size_t i = 0; i < static_cast(mPageManagerConfig.getNumCopyStreams()); ++i) + { + mDeviceBufferManagers.push_back(std::make_unique(std::make_shared())); + } +} + +template +void LoraCache::splitTransposeCpuInner(ITensor& output, ITensor const& input, SizeType tpSize, SizeType tpRank) +{ + auto const adapterSize = input.getShape().d[0]; + auto const hiddenSize = input.getShape().d[1]; + auto const splitHiddenSize = hiddenSize / tpSize; + + auto outputPtr = bufferCast(output); + auto const inputPtr = bufferCast(input); + + for (SizeType adapterIdx = 0; adapterIdx < adapterSize; ++adapterIdx) + { + for (SizeType hiddenIdx = 0; hiddenIdx < splitHiddenSize; ++hiddenIdx) + { + auto outputIdx = common::flat_index2(adapterIdx, hiddenIdx, splitHiddenSize); + auto inputIdx = common::flat_index2(adapterIdx, hiddenIdx + tpRank * splitHiddenSize, hiddenSize); + outputPtr[outputIdx] = inputPtr[inputIdx]; + } + } +} + +void LoraCache::splitTransposeCpu(ITensor& output, ITensor const& input, SizeType tpSize, SizeType tpRank) +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + + switch (input.getDataType()) + { + case nvinfer1::DataType::kINT32: splitTransposeCpuInner(output, input, tpSize, tpRank); break; + case nvinfer1::DataType::kFLOAT: splitTransposeCpuInner(output, 
input, tpSize, tpRank); break; + case nvinfer1::DataType::kHALF: splitTransposeCpuInner(output, input, tpSize, tpRank); break; + case nvinfer1::DataType::kINT8: splitTransposeCpuInner(output, input, tpSize, tpRank); break; +#ifdef ENABLE_FP8 + case nvinfer1::DataType::kFP8: splitTransposeCpuInner<__nv_fp8_e4m3>(output, input, tpSize, tpRank); break; +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 + case nvinfer1::DataType::kBF16: splitTransposeCpuInner<__nv_bfloat16>(output, input, tpSize, tpRank); break; +#endif // ENABLE_BF16 + default: TLLM_CHECK_WITH_INFO(false, "data type not supported"); + } + + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); +} + +std::vector LoraCache::copyToPages(TensorPtr sourceWeights, TensorPtr sourceConfig, + GptModelConfig const& modelConfig, WorldConfig const& worldConfig, + std::unordered_map moduleIdToModule, BufferManager const& manager, + std::vector const& pages, std::vector const& pageIds) +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + + TLLM_CHECK_WITH_INFO(!pages.empty(), "empty pages"); + + TensorPtr weights = sourceWeights->getShape().nbDims == 2 + ? sourceWeights + : ITensor::view( + sourceWeights, ITensor::makeShape({sourceWeights->getShape().d[1], sourceWeights->getShape().d[2]})); + + TensorPtr config = sourceConfig->getShape().nbDims == 2 + ? sourceConfig + : ITensor::view( + sourceConfig, ITensor::makeShape({sourceConfig->getShape().d[1], sourceConfig->getShape().d[2]})); + + TLLM_CHECK(pages[0]->getShape().nbDims == 2); + auto const slotsPerPage = pages[0]->getShape().d[0]; + auto const pageWidth = pages[0]->getShape().d[1]; + + auto const tpSize = worldConfig.getTensorParallelism(); + auto const tpRank = worldConfig.getTensorParallelRank(); + auto const ppSize = worldConfig.getPipelineParallelism(); + auto const ppRank = worldConfig.getPipelineParallelRank(); + auto const localNumLayers = modelConfig.getNbLayers(ppSize); + auto const firstLayerId = ppRank * localNumLayers; + auto const lastLayerId = firstLayerId + localNumLayers; + + SizeType currPage = 0; + SizeType currSlot = 0; + + std::vector rowPage; + std::vector rowSlot; + std::vector rowIndices; + + auto const numRows = config->getShape().d[0]; + for (SizeType row = 0; row < numRows; ++row) + { + auto const configPtr = bufferCast(*ITensor::slice(config, row, 1)); + auto const layerId = configPtr[lora::kLORA_CONFIG_LAYER_OFF]; + if (layerId >= firstLayerId && layerId < lastLayerId) + { + auto const adapterSize = configPtr[lora::kLORA_CONFIG_ADAPTER_SIZE_OFF]; + + auto const modId = configPtr[lora::kLORA_CONFIG_MODULE_OFF]; + auto const& module = moduleIdToModule.at(modId); + auto const localInOutSize = module.localInOutSize(adapterSize, tpSize); + auto const rowSlots = common::ceilDiv(localInOutSize, pageWidth); + if (currSlot + rowSlots > slotsPerPage) + { + currSlot = 0; + ++currPage; + } + + rowIndices.push_back(row); + rowSlot.push_back(currSlot); + rowPage.push_back(currPage); + currSlot += rowSlots; + } + } + + std::vector pageLocations(rowIndices.size()); + for (SizeType i = 0; i < static_cast(rowIndices.size()); ++i) + { + auto copyFn = [i = i, &rowIndices, &rowPage, &rowSlot, &pageLocations, weights, config, &pages, + &moduleIdToModule, &manager, pageWidth, tpSize, tpRank, pageIds]() + { + auto const row = rowIndices[i]; + auto const currPage = rowPage[i]; + auto const currSlot = rowSlot[i]; + auto const configPtr = bufferCast(*ITensor::slice(config, row, 1)); + auto const layerId = configPtr[lora::kLORA_CONFIG_LAYER_OFF]; + + auto const adapterSize = 
configPtr[lora::kLORA_CONFIG_ADAPTER_SIZE_OFF]; + + auto const modId = configPtr[lora::kLORA_CONFIG_MODULE_OFF]; + auto const& module = moduleIdToModule.at(modId); + auto const localInOutSize = module.localInOutSize(adapterSize, tpSize); + auto const rowSlots = common::ceilDiv(localInOutSize, pageWidth); + + auto const inDim = module.inDim(); + auto const outDim = module.outDim(); + auto const localOutDim = module.localOutDim(tpSize); + auto const inSize = module.inSize(adapterSize); + auto const outSize = module.outSize(adapterSize); + auto const localInSize = module.localInSize(adapterSize, tpSize); + auto const localOutSize = module.localOutSize(adapterSize, tpSize); + + TLLM_CHECK(module.inDimFirst() == false); + TLLM_CHECK(module.outDimFirst() == true); + TLLM_CHECK(module.inTpSplitDim() == 1 || module.inTpSplitDim() == -1); + TLLM_CHECK(module.outTpSplitDim() == 0 || module.outTpSplitDim() == -1); + + auto const splitIn = module.inTpSplitDim() == 1; + auto const splitOut = module.outTpSplitDim() == 0; + + TensorPtr rowWeights + = ITensor::view(ITensor::slice(weights, row, 1), ITensor::makeShape({inSize + outSize})); + TensorPtr weightsIn + = ITensor::view(ITensor::slice(rowWeights, 0, inSize), ITensor::makeShape({adapterSize, inDim})); + TensorPtr weightsOut + = ITensor::view(ITensor::slice(rowWeights, inSize, outSize), ITensor::makeShape({outDim, adapterSize})); + + TensorPtr pageSlice = ITensor::slice(pages.at(currPage), currSlot, rowSlots); + SizeType pageSliceSize = ITensor::volume(pageSlice->getShape()); + TensorPtr pageFlatView = ITensor::view(pageSlice, ITensor::makeShape({pageSliceSize})); + TensorPtr targetWeightsIn = ITensor::slice(pageFlatView, 0, localInSize); + TensorPtr targetWeightsOut = ITensor::slice(pageFlatView, localInSize, localOutSize); + + if (!splitIn) + { + manager.copy(*weightsIn, *targetWeightsIn); + } + else + { + splitTransposeCpu(*targetWeightsIn, *weightsIn, tpSize, tpRank); + } + + if (!splitOut) + { + manager.copy(*weightsOut, *targetWeightsOut); + } + else + { + TensorPtr source = ITensor::view( + ITensor::slice( + ITensor::view(weightsOut, ITensor::makeShape({tpSize, localOutDim, adapterSize})), tpRank, 1), + ITensor::makeShape({localOutDim, adapterSize})); + manager.copy(*source, *targetWeightsOut); + } + + pageLocations[i] + = LoraCache::TaskLayerModuleConfig{pageIds.at(currPage), currSlot, localInSize, localOutSize, modId, + layerId, adapterSize, rowSlots, reinterpret_cast(targetWeightsIn->data()), + reinterpret_cast(targetWeightsOut->data())}; + }; + copyFn(); + } + + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); + return pageLocations; +} + +std::map> LoraCache::copyTaskMapPages(TaskValue& targetTaskValue, + TaskValue const& sourceTaskValue, std::vector const& targetPageIds, LoraCache const& targetCache) +{ + auto const& pageIds = sourceTaskValue.pageIds; + + // collect mapping from oldPageId to (newPageId, num used slots in page) + std::map> oldToNewPageIds{}; + for (size_t i = 0; i < pageIds.size(); ++i) + { + oldToNewPageIds.insert_or_assign(pageIds[i], std::make_pair(targetPageIds[i], 0)); + } + + targetTaskValue.configs = std::make_shared>(*sourceTaskValue.configs); + targetTaskValue.pageIds = targetPageIds; + for (size_t i = 0; i < sourceTaskValue.configs->size(); ++i) + { + auto const& sourceConfigs = *(sourceTaskValue.configs); + auto& targetConfigs = *(targetTaskValue.configs); + auto& newPagePair = oldToNewPageIds.at(sourceConfigs[i].pageId); + newPagePair.second += sourceConfigs[i].numSlots; + targetConfigs[i].pageId = 
newPagePair.first; + auto page = targetCache.mCachePageManager->mutablePagePtr(targetConfigs[i].pageId); + auto const slotId = targetConfigs[i].slotIdx; + auto const numSlots = targetConfigs[i].numSlots; + auto const inSize = targetConfigs[i].inSize; + auto const outSize = targetConfigs[i].outSize; + TensorPtr slot = ITensor::view(ITensor::slice(page, slotId, numSlots), + ITensor::makeShape({numSlots * targetCache.mPageManagerConfig.getPageWidth()})); + targetConfigs[i].weightsInPointer = reinterpret_cast( + ITensor::view(ITensor::slice(slot, 0, inSize), ITensor::makeShape({inSize}))->data()); + targetConfigs[i].weightsOutPointer = reinterpret_cast( + ITensor::view(ITensor::slice(slot, inSize, outSize), ITensor::makeShape({outSize}))->data()); + } + + return oldToNewPageIds; +} + +void LoraCache::copyTask(TaskIdType taskId, LoraCache& deviceCache, bool markDone) +{ + TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + TLLM_LOG_DEBUG("copyTask " + std::to_string(taskId)); + + TLLM_CHECK_WITH_INFO(deviceCache.mPageManagerConfig.getMemoryType() == runtime::MemoryType::kGPU + && !deviceCache.mDeviceBufferManagers.empty(), + "The deviceCache must hold GPU memory and have at least one bufferManager / copy stream"); + + // First get the taskValue from this cache + // TaskValue& taskValue = copyTaskGetThisTaskValue(taskId); + TaskValuePtr taskValue = [&]() -> TaskValuePtr + { + std::lock_guard cacheLock(mCacheMutex); + auto status = getStatus(taskId); + if (kVALUE_STATUS_PROCESSING == status) + { + throw std::runtime_error("can't move a processing task taskId=" + std::to_string(taskId)); + } + else if (status == kVALUE_STATUS_MISSING) + { + throw std::runtime_error("can't move a missing task" + std::to_string(taskId)); + } + auto taskValue = mCacheMap.at(taskId); + // mark task unloaded so we can evict the task while the copy in in progress + taskValue->loaded = false; + bumpTaskInProgress(taskId); + return taskValue; + }(); + + auto& pageIds = taskValue->pageIds; + auto neededPages = pageIds.size(); + + // Now create put the task in the target cache + // TaskValue* otherTaskValuePtr = copyTaskGetOtherTaskValue(taskId, taskValue, deviceCache, markDone); + std::optional optOtherTaskValuePtr = [&]() -> std::optional + { + std::lock_guard deviceCacheLock(deviceCache.mCacheMutex); + auto otherStatus = deviceCache.getStatus(taskId); + if (kVALUE_STATUS_MISSING != otherStatus) + { + deviceCache.bumpTaskInProgress(taskId); + taskValue->loaded = true; + return std::nullopt; + } + + deviceCache.mInProgressTasks.push_front(taskId); + auto cacheV = std::make_shared(std::vector{}, TaskLayerModuleConfigListPtr(), + deviceCache.mInProgressTasks.begin(), true, false, markDone, true); + deviceCache.mCacheMap.try_emplace(taskId, std::move(cacheV)); + auto otherTaskValue = deviceCache.mCacheMap.at(taskId); + // TODO (grclark) return shared_ptr + return otherTaskValue; + }(); + if (!optOtherTaskValuePtr) + { + return; + } + TaskValuePtr otherTaskValue = optOtherTaskValuePtr.value(); + + std::vector newPageIds{}; + try + { + newPageIds = deviceCache.claimPagesWithEvict(neededPages); + } + catch (std::runtime_error& e) + { + { + std::lock_guard lk(deviceCache.mCacheMutex); + deviceCache.mInProgressTasks.erase(otherTaskValue->it); + deviceCache.mCacheMap.erase(taskId); + taskValue->loaded = true; + throw std::runtime_error("Couldn't claim pages during copyTask -- " + std::string(e.what())); + } + } + + auto oldToNewPageIds = copyTaskMapPages(*otherTaskValue, *taskValue, newPageIds, deviceCache); + + auto const 
flatPageShape + = ITensor::makeShape({mPageManagerConfig.getPageWidth() * mPageManagerConfig.getSlotsPerPage()}); + size_t bufferManagerOffset = taskId % deviceCache.mDeviceBufferManagers.size(); + std::vector copyEvents(otherTaskValue->pageIds.size()); + size_t eventIdx = 0; + for (auto const& [oldPageId, newPagePair] : oldToNewPageIds) + { + auto const newPageId = newPagePair.first; + auto const copySize = newPagePair.second * mPageManagerConfig.getPageWidth(); + auto const copyShape = ITensor::makeShape({copySize}); + TLLM_LOG_DEBUG("copy page (task " + std::to_string(taskId) + ") " + std::to_string(oldPageId) + " -> " + + std::to_string(newPageId) + " size: " + std::to_string(copySize)); + TensorPtr oldPagePtr = mCachePageManager->mutablePagePtr(oldPageId); + TensorPtr newPagePtr = deviceCache.mCachePageManager->mutablePagePtr(newPageId); + TensorPtr source + = ITensor::view(ITensor::slice(ITensor::view(oldPagePtr, flatPageShape), 0, copySize), copyShape); + TensorPtr dest + = ITensor::view(ITensor::slice(ITensor::view(newPagePtr, flatPageShape), 0, copySize), copyShape); + deviceCache.mDeviceBufferManagers[bufferManagerOffset]->copy(*source, *dest); + deviceCache.mDeviceBufferManagers[bufferManagerOffset]->getStream().record(copyEvents[eventIdx++]); + bufferManagerOffset = (bufferManagerOffset + 1) % deviceCache.mDeviceBufferManagers.size(); + } + for (auto const& event : copyEvents) + { + event.synchronize(); + } + + bool otherIsDone; + { + std::lock_guard lk(mCacheMutex); + otherIsDone = otherTaskValue->done; + otherTaskValue->loadInProgress = false; + otherTaskValue->loaded = true; + } + if (otherIsDone) + { + deviceCache.markTaskDone(taskId); + } + + bool isDone; + { + std::lock_guard lk(mCacheMutex); + isDone = taskValue->done; + taskValue->loaded = true; + } + if (isDone) + { + markTaskDone(taskId); + } + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); +} + +ITensor::SharedConstPtr LoraCache::getPagePtr(size_t pageId) const +{ + return mCachePageManager->pagePtr(pageId); +} + +SizeType LoraCache::getNumPages() const +{ + return mPageManagerConfig.getTotalNumPages(); +} + +bool LoraCache::fits(TensorPtr config) const +{ + auto const neededPages = determineNumPages(config); + SizeType availablePages; + { + std::lock_guard lk(mPagesMutex); + availablePages = mCachePageManager->numAvailablePages(); + } + return neededPages < availablePages; +} + +std::string to_string(LoraCache::TaskLayerModuleConfig const& v) +{ + std::stringstream sstream; + sstream << "{pageIdx=" << v.pageId << "; " + << "slotIdx=" << v.slotIdx << "; " + << "inSize=" << v.inSize << "; " + << "outSize=" << v.outSize << "; " + << "moduleId=" << v.moduleId << "; " + << "layerId=" << v.layerId << "; " + << "adapterSize=" << v.adapterSize << "; " + << "numSlots=" << v.numSlots << "}"; + return sstream.str(); +} + +std::ostream& operator<<(std::ostream& os, LoraCache::TaskLayerModuleConfig const& v) +{ + os << to_string(v); + return os; +} + +bool LoraCache::TaskLayerModuleConfig::operator==(LoraCache::TaskLayerModuleConfig const& o) const +{ + return (pageId == o.pageId && slotIdx == o.slotIdx && inSize == o.inSize && outSize == o.outSize + && moduleId == o.moduleId && layerId == o.layerId && adapterSize == o.adapterSize && numSlots == o.numSlots); +} + +bool LoraCache::isDone(TaskIdType taskId) const +{ + std::lock_guard lk(mCacheMutex); + if (mCacheMap.count(taskId)) + { + auto const taskValue = mCacheMap.at(taskId); + return !taskValue->inProgress; + } + return false; +} +} // namespace tensorrt_llm::runtime diff 
--git a/cpp/tensorrt_llm/runtime/loraManager.cpp b/cpp/tensorrt_llm/runtime/loraManager.cpp index f5f07f8d8..bc0e741b2 100644 --- a/cpp/tensorrt_llm/runtime/loraManager.cpp +++ b/cpp/tensorrt_llm/runtime/loraManager.cpp @@ -24,27 +24,12 @@ #include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/loraUtils.h" -#include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/utils/sessionUtils.h" #include "tensorrt_llm/runtime/worldConfig.h" #include namespace tensorrt_llm::runtime { -void LoraManager::addTask(TaskIdType reqId, TensorPtr weights, TensorPtr config) -{ - if (mLoras.find(reqId) != mLoras.end()) - { - return; - } - - mLoras[reqId] = std::make_tuple(weights, config); -} - -LoraManager::LoraReqTensors& LoraManager::getTask(TaskIdType reqId) -{ - return mLoras.at(reqId); -} void LoraManager::create( GptModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& manager) @@ -56,110 +41,68 @@ void LoraManager::create( for (auto const& m : modules) { mModuleIdToModule[m.value()] = m; - mModuleOffest[m.value()] = modOff++; + mModuleOffset[m.value()] = modOff++; } - // TODO set this size from max adapter size - mWorkspace = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType()); - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, ReqIdsVec const& reqIds, - std::vector const& reqBeamWidth, std::vector const& loraEnabled, SizeType numContextRequests, - GptModelConfig const& modelConfig, WorldConfig const& worldConfig) +void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, PeftTable const& peftTable, + ReqIdsVec const& reqIds, std::vector const& reqBeamWidth, GptModelConfig const& modelConfig, + WorldConfig const& worldConfig) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism()); - auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers; - auto lastLayerId = firstLayerId + localNbLayers; - auto tpSize = worldConfig.getTensorParallelism(); - auto tpRank = worldConfig.getTensorParallelRank(); - auto batchSize = static_cast(reqIds.size()); for (SizeType bid = 0; bid < batchSize; ++bid) { - if (!loraEnabled[bid]) + auto it = peftTable.find(reqIds[bid]); + if (it == peftTable.end()) + { continue; - - fillInputTensors( - weightsPtrs, adapterSizes, bid, reqIds[bid], reqBeamWidth[bid], firstLayerId, lastLayerId, tpSize, tpRank); + } + auto peftValues = it->second; + fillInputTensors(weightsPtrs, adapterSizes, peftValues, bid, reqBeamWidth[bid], modelConfig, worldConfig); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, SizeType batchIdx, TaskIdType taskId, - SizeType beamWidth, SizeType firstLayerId, SizeType lastLayerId, SizeType tpSize, SizeType tpRank) +void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, PeftValues const peftValues, + SizeType batchIdx, SizeType beamWidth, GptModelConfig const& modelConfig, WorldConfig const& worldConfig) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + auto const ppSize = worldConfig.getPipelineParallelism(); + auto const ppRank = worldConfig.getPipelineParallelRank(); + auto const localNumLayers = modelConfig.getNbLayers(ppSize); + auto const firstLayerId = ppRank * localNumLayers; + auto weightsPointersPtr = 
bufferCast(*weightsPtrs); auto adapterSizesPtr = bufferCast(*adapterSizes); - auto [reqWeights, reqKeys] = getTask(taskId); - auto reqKeysPtr = bufferCast(*reqKeys); - auto numRows = reqKeys->getShape().d[0]; - if (reqKeys->getShape().d[1] != lora::kLORA_CONFIG_ROW_SIZE) - { - throw std::runtime_error( - "Expected request lora_keys tor have row size of " + std::to_string(lora::kLORA_CONFIG_ROW_SIZE)); - } + TLLM_CHECK(!peftValues->empty()); + + auto const numRows = static_cast(peftValues->size()); for (SizeType row = 0; row < numRows; ++row) { - auto layerIdx = reqKeysPtr[row * lora::kLORA_CONFIG_ROW_SIZE + lora::kLORA_CONFIG_LAYER_OFF]; - if (layerIdx < firstLayerId || layerIdx >= lastLayerId) - continue; - - auto moduleId = reqKeysPtr[row * lora::kLORA_CONFIG_ROW_SIZE + lora::kLORA_CONFIG_MODULE_OFF]; - auto adapterSize = reqKeysPtr[row * lora::kLORA_CONFIG_ROW_SIZE + lora::kLORA_CONFIG_ADAPTER_SIZE_OFF]; - - auto modOff = mModuleOffest.at(moduleId); - auto& module = mModuleIdToModule.at(moduleId); - - auto inDim = (module.inDimFirst() && module.inTpSplitDim() == 0) - || (!module.inDimFirst() && module.inTpSplitDim() == 1) - ? module.inDim() / tpSize - : module.inDim(); - auto inTpSize = module.inTpSplitDim() == -1 ? 1 : tpSize; - auto inTpRank = module.inTpSplitDim() == -1 ? 0 : tpRank; - - auto outDim = (module.outDimFirst() && module.outTpSplitDim() == 0) - || (!module.outDimFirst() && module.outTpSplitDim() == 1) - ? module.outDim() / tpSize - : module.outDim(); - auto outTpSize = module.outTpSplitDim() == -1 ? 1 : tpSize; - auto outTpRank = module.outTpSplitDim() == -1 ? 0 : tpRank; - - auto inWeightsShape = module.inDimFirst() ? ITensor::makeShape({inTpSize, inDim, adapterSize}) - : ITensor::makeShape({inTpSize, adapterSize, inDim}); - auto outWeightsShape = module.outDimFirst() ? 
ITensor::makeShape({outTpSize, outDim, adapterSize}) - : ITensor::makeShape({outTpSize, adapterSize, outDim}); - - TensorPtr reqRowWeights = ITensor::slice(reqWeights, row, 1); - reqRowWeights->squeeze(0); - TensorPtr allInWeights = ITensor::view(reqRowWeights, inWeightsShape); - - TensorPtr allOutWeights - = ITensor::slice(reqRowWeights, allInWeights->getSize(), ITensor::volume(outWeightsShape)); - allOutWeights->reshape(outWeightsShape); + auto const& peftValue = peftValues->at(row); + auto const moduleId = peftValue.moduleId; + auto const adapterSize = peftValue.adapterSize; + auto const modOff = mModuleOffset.at(moduleId); + auto const layerIdx = peftValue.layerId; - auto inWeightsPtr = reinterpret_cast(ITensor::slice(allInWeights, inTpRank, 1)->data()); - auto outWeightsPtr = reinterpret_cast(ITensor::slice(allOutWeights, outTpRank, 1)->data()); + auto const inWeightsPtr = peftValue.weightsInPointer; + auto const outWeightsPtr = peftValue.weightsOutPointer; auto weightsPointersPtrOffset = common::flat_index4(modOff, layerIdx - firstLayerId, batchIdx, 0, weightsPtrs->getShape().d[1], weightsPtrs->getShape().d[2], weightsPtrs->getShape().d[3]); auto adapterSizesPtrOffset = common::flat_index3( modOff, layerIdx - firstLayerId, batchIdx, adapterSizes->getShape().d[1], adapterSizes->getShape().d[2]); - if (static_cast(weightsPtrs->getSize()) - < weightsPointersPtrOffset + lora::kLORA_NUM_WEIGHTS_POINTERS * beamWidth) - { - throw std::runtime_error("Coding error attempting to write lora ptrs outside range of buffer"); - } - if (static_cast(adapterSizes->getSize()) < adapterSizesPtrOffset + beamWidth) - { - throw std::runtime_error("Coding error attempting to write lora low ranks outside range of buffer"); - } + TLLM_CHECK_WITH_INFO(static_cast(weightsPtrs->getSize()) + >= weightsPointersPtrOffset + lora::kLORA_NUM_WEIGHTS_POINTERS * beamWidth, + "Coding error attempting to write lora ptrs outside range of buffer"); + TLLM_CHECK_WITH_INFO(static_cast(adapterSizes->getSize()) >= adapterSizesPtrOffset + beamWidth, + "Coding error attempting to write lora low ranks outside range of buffer"); auto const writeWeightsPtr = weightsPointersPtr + weightsPointersPtrOffset; auto const writeAdapterSizesPtr = adapterSizesPtr + adapterSizesPtrOffset; @@ -184,7 +127,7 @@ void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsP for (auto const& [modId, mod] : mModuleIdToModule) { - auto modOff = mModuleOffest.at(modId); + auto modOff = mModuleOffset.at(modId); TensorPtr weightPtrsModSlice = ITensor::slice(weightsPtrs, modOff, 1); weightPtrsModSlice->squeeze(0); @@ -212,61 +155,4 @@ void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsP } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } - -void LoraManager::formatTaskTensors(LoraWeightsTensorPtr weights, LoraConfigTensorPtr config, - GptModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& manager) -{ - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - weights->squeeze(0); - config->squeeze(0); - - auto tpSize = worldConfig.getTensorParallelism(); - - SizeType nbRows = config->getShape().d[0]; - for (SizeType row = 0; row < nbRows; ++row) - { - auto rowPtr = bufferCast(*ITensor::slice(config, row, 1)); - auto modId = rowPtr[lora::kLORA_CONFIG_MODULE_OFF]; - auto adapterSize = rowPtr[lora::kLORA_CONFIG_ADAPTER_SIZE_OFF]; - - auto& module = mModuleIdToModule.at(modId); - TLLM_CHECK_WITH_INFO(!module.inDimFirst() && module.outDimFirst(), "unsupported module"); - if 
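// Illustrative sketch (not part of the patch): fillInputTensors above linearizes
// (module, layer, batch, pointer) coordinates with common::flat_index4 / flat_index3.
// The constexpr helpers below are a plain row-major re-derivation of that indexing, shown only
// to make the offset arithmetic explicit; they are assumptions about, not copies of, the real
// helpers in tensorrt_llm/common.
#include <cstddef>

// Row-major offset into an array of shape [d0, d1, d2]; d0 is not needed for the offset.
constexpr std::size_t flatIndex3(std::size_t i0, std::size_t i1, std::size_t i2, std::size_t d1, std::size_t d2)
{
    return (i0 * d1 + i1) * d2 + i2;
}

// Row-major offset into an array of shape [d0, d1, d2, d3].
constexpr std::size_t flatIndex4(
    std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3, std::size_t d1, std::size_t d2, std::size_t d3)
{
    return ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
}

// For example, the weights-pointer offset above corresponds to
// flatIndex4(modOff, layerIdx - firstLayerId, batchIdx, 0, shape[1], shape[2], shape[3]).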
(module.inTpSplitDim() == 1) - { - TensorPtr inWeights = ITensor::slice(weights, row, 1); - inWeights->squeeze(0); - inWeights->reshape(ITensor::makeShape({adapterSize, module.inDim()})); - if (mWorkspace->getSize() < inWeights->getSize()) - { - mWorkspace = manager.gpu(inWeights->getShape(), inWeights->getDataType()); - } - mWorkspace->reshape(ITensor::makeShape({tpSize, adapterSize, module.inDim() / tpSize})); - kernels::splitTransposed(*mWorkspace, *inWeights, tpSize, manager.getStream()); - manager.copy(*mWorkspace, *inWeights); - } - if (module.outTpSplitDim() == 1) - { - TensorPtr rowWeights = ITensor::slice(weights, row, 1); - rowWeights->squeeze(0); - TensorPtr weightsOut - = ITensor::slice(rowWeights, adapterSize * module.inDim(), adapterSize * module.outDim()); - weightsOut->squeeze(0); - weightsOut->reshape(ITensor::makeShape({module.outDim(), adapterSize})); - if (mWorkspace->getSize() < weightsOut->getSize()) - { - mWorkspace = manager.gpu(weightsOut->getShape(), weightsOut->getDataType()); - } - mWorkspace->reshape(weightsOut->getShape()); - kernels::splitTransposed(*mWorkspace, *weightsOut, tpSize, manager.getStream()); - manager.copy(*mWorkspace, *weightsOut); - } - } - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); -} - -void LoraManager::reset() -{ - mLoras.clear(); -} - } // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/loraManager.h b/cpp/tensorrt_llm/runtime/loraManager.h index 58f4e1252..47130f959 100644 --- a/cpp/tensorrt_llm/runtime/loraManager.h +++ b/cpp/tensorrt_llm/runtime/loraManager.h @@ -19,6 +19,7 @@ #include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/gptModelConfig.h" +#include "tensorrt_llm/runtime/loraCache.h" #include "tensorrt_llm/runtime/loraModule.h" #include "tensorrt_llm/runtime/worldConfig.h" #include @@ -40,6 +41,8 @@ class LoraManager using LoraConfigTensorPtr = TensorPtr; using LoraReqTensors = std::tuple; using TaskIdType = std::int64_t; + using PeftValues = std::shared_ptr> const; + using PeftTable = std::map>>; explicit LoraManager() {} @@ -51,43 +54,12 @@ class LoraManager */ void create(GptModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& manager); - /** - * \brief Add Task (LoRA tensor to manager) - * \details weights and config are assumed to be in the proper format - * and to have been formatted with formatTaskTensors - * \param[in] taskId: id associated with these lora weights - * \param[in] weights: LoRA weights tensor [num_modules_layers, D x Hi + Ho x D]. - * Each row contains the flattened in / out LoRA weights for a single module / layer. - * D=adapter size (R value); Hi=hidden dim of in weights; Ho=hidden dim of out weights - * \param[in] config: LoRA config tensor [num_modules_layers, 3] - * each row contains 3 values (module_id, layer_idx, D) - * See LoraModule::ModelType for module_id details - */ - void addTask(TaskIdType taskId, LoraWeightsTensorPtr weights, LoraConfigTensorPtr config); - - /** - * \brief getTask by taskId - * \param[in] taskId: task id - */ - LoraReqTensors& getTask(TaskIdType taskId); - - /** - * \brief format tensors for addTask. See addTask for details on expected format - * \param[out] weights: LoRA weights tensor. See addTask for details - * \param[out] config: LoRA config tensor. 
See addTask for details - * \param[in] modelConfig: A GptModelConfig - * \param[in] worldConfig: A WorldConfig - * \param[in]: manager: A BufferManager - */ - void formatTaskTensors(LoraWeightsTensorPtr weights, LoraConfigTensorPtr config, GptModelConfig const& modelConfig, - WorldConfig const& worldConfig, BufferManager const& manager); - /** * \brief same as fillInputTensors but for an entire batch */ - void fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, ReqIdsVec const& reqIds, - std::vector const& reqBeamWidth, std::vector const& loraEnabled, SizeType numContextRequests, - GptModelConfig const& modelConfig, WorldConfig const& worldConfig); + void fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, PeftTable const& peftTable, + ReqIdsVec const& reqIds, std::vector const& reqBeamWidth, GptModelConfig const& modelConfig, + WorldConfig const& worldConfig); /** * \brief fill batch input tensors for LoRA. This method fills on batch slot. @@ -95,16 +67,14 @@ class LoraManager * (ie for `*_lora_weights_pointers_*` fields) * \param[out] adapterSizes: the adapter sizes tensor to fill * (ie for `*lora_low_rank_*` fields) + * \param[in] peftTable: reqId to LoraCache::Values * \param[in] batchIdx: the request batch index - * \param[in] taskId: the LoRA task id to use * \param[in] beamWidth: the request beam width - * \param[in] firstLayerId: firstLaterId in this rank for pipeline parallel models - * \param[in] lastLayerId: firstLayerId in this rank for pipeline parallel models - * \param[in] tpSize: tensor parallel size - * \param[in] tpRank: tensor parallel rank + * \param[in] modelConfig: a GptModelConfig + * \param[in] worldConfig: a WorldConfig */ - void fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, SizeType batchIdx, TaskIdType taskId, - SizeType beamWidth, SizeType firstLayerId, SizeType lastLayerId, SizeType tpSize, SizeType tpRank); + void fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes, PeftValues const peftValues, SizeType batchIdx, + SizeType beamWidth, GptModelConfig const& modelConfig, WorldConfig const& worldConfig); /** * \brief fill tensor map for trt engine context @@ -117,12 +87,8 @@ class LoraManager void insertInputTensors(TensorMap& inputTensors, TensorPtr weightsPtrs, TensorPtr adapterSizes, GptModelConfig const& modelConfig, WorldConfig const& worldConfig) const; - void reset(); - private: - TensorPtr mWorkspace; - std::unordered_map mLoras; std::unordered_map mModuleIdToModule; - std::unordered_map mModuleOffest; + std::unordered_map mModuleOffset; }; } // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/loraUtils.cpp b/cpp/tensorrt_llm/runtime/loraUtils.cpp index 7baf8652b..2448e8821 100644 --- a/cpp/tensorrt_llm/runtime/loraUtils.cpp +++ b/cpp/tensorrt_llm/runtime/loraUtils.cpp @@ -16,6 +16,7 @@ #include "tensorrt_llm/runtime/gptModelConfig.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/worldConfig.h" +#include namespace tensorrt_llm::runtime::lora { @@ -27,7 +28,6 @@ void loraValidateRequestTensorDims(std::optional const& optR "Request for LoRA inference must have both lora_weights and lora_keys"); SizeType constexpr expectedBatchSize = 1; - SizeType constexpr expectedLoraConfigValues = kLORA_CONFIG_ROW_SIZE; SizeType constexpr expectedWeightsDims = 3; SizeType constexpr expectedKeysDims = 3; @@ -46,41 +46,44 @@ void loraValidateRequestTensorDims(std::optional const& optR TLLM_CHECK_WITH_INFO(keys->getShape().d[1] == weights->getShape().d[1], "Expected dim1 
lora_weights and lora_keys to have the same size"); - TLLM_CHECK_WITH_INFO( - keys->getShape().d[2] == expectedLoraConfigValues, "Expected dim2 of lora_keys to have a size of 3"); + TLLM_CHECK_WITH_INFO(keys->getShape().d[2] == kLORA_CONFIG_ROW_SIZE, + "Expected dim2 of lora_keys to have a size of " + std::to_string(kLORA_CONFIG_ROW_SIZE)); } -void loraValidateRequestTensors(std::optional const& optReqLoraWeights, +void loraValidateRequestTensors(std::optional const& optTaskId, + std::optional const& optReqLoraWeights, std::optional const& optReqLoraConfig, runtime::GptModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) { - SizeType constexpr expectedLoraConfigValues = 3; - - loraValidateRequestTensorDims(optReqLoraWeights, optReqLoraConfig); + TLLM_CHECK_WITH_INFO(optTaskId.has_value(), "lora_task_id must be set for LoRA inference"); + if (optReqLoraWeights.has_value() || optReqLoraConfig.has_value()) + { + loraValidateRequestTensorDims(optReqLoraWeights, optReqLoraConfig); - auto weights = optReqLoraWeights.value(); - auto keys = optReqLoraConfig.value(); - SizeType nbModelLayers = modelConfig.getNbLayers(); - TLLM_CHECK_WITH_INFO(weights->getDataType() == modelConfig.getDataType(), - "Expected lora weights to be the same data type as base model"); + auto weights = optReqLoraWeights.value(); + auto config = optReqLoraConfig.value(); + SizeType nbModelLayers = modelConfig.getNbLayers(); + TLLM_CHECK_WITH_INFO(weights->getDataType() == modelConfig.getDataType(), + "Expected lora weights to be the same data type as base model"); - auto loraModules = modelConfig.getLoraModules(); - auto keysPtr = bufferCast(*keys); - for (SizeType row = 0; row < keys->getShape().d[1]; ++row) - { - auto modId = keysPtr[row * expectedLoraConfigValues]; - auto layerId = keysPtr[row * expectedLoraConfigValues + 1]; - auto adapterSize = keysPtr[row * expectedLoraConfigValues + 2]; + auto loraModules = modelConfig.getLoraModules(); + auto configPtr = bufferCast(*config); + for (SizeType row = 0; row < config->getShape().d[1]; ++row) + { + auto modId = configPtr[row * kLORA_CONFIG_ROW_SIZE + kLORA_CONFIG_MODULE_OFF]; + auto layerId = configPtr[row * kLORA_CONFIG_ROW_SIZE + kLORA_CONFIG_LAYER_OFF]; + auto adapterSize = configPtr[row * kLORA_CONFIG_ROW_SIZE + kLORA_CONFIG_ADAPTER_SIZE_OFF]; - TLLM_CHECK_WITH_INFO( - layerId >= 0 && layerId < nbModelLayers, "Expected layerId to be in the range [0, numModelLayers)"); - TLLM_CHECK_WITH_INFO(adapterSize > 0, "Expected adapterSize to be > 0"); - auto it = std::find_if( - loraModules.begin(), loraModules.end(), [modId](LoraModule const& m) { return m.value() == modId; }); - std::string moduleName(LoraModule::toModuleName(modId)); - TLLM_CHECK_WITH_INFO(it != loraModules.end(), "lora module " + moduleName + " not enabled for this model"); - TLLM_CHECK_WITH_INFO(it->flattenedInOutSize(adapterSize) <= weights->getShape().d[2], - "lora_weights has to few values for " + moduleName); + TLLM_CHECK_WITH_INFO( + layerId >= 0 && layerId < nbModelLayers, "Expected layerId to be in the range [0, numModelLayers)"); + TLLM_CHECK_WITH_INFO(adapterSize > 0, "Expected adapterSize to be > 0"); + auto it = std::find_if( + loraModules.begin(), loraModules.end(), [modId](LoraModule const& m) { return m.value() == modId; }); + std::string moduleName(LoraModule::toModuleName(modId)); + TLLM_CHECK_WITH_INFO(it != loraModules.end(), "lora module " + moduleName + " not enabled for this model"); + TLLM_CHECK_WITH_INFO(it->flattenedInOutSize(adapterSize) <= 
weights->getShape().d[2], + "lora_weights has to few values for " + moduleName); + } } } } // namespace tensorrt_llm::runtime::lora diff --git a/cpp/tensorrt_llm/runtime/loraUtils.h b/cpp/tensorrt_llm/runtime/loraUtils.h index 039b1d269..a3167dee0 100644 --- a/cpp/tensorrt_llm/runtime/loraUtils.h +++ b/cpp/tensorrt_llm/runtime/loraUtils.h @@ -30,7 +30,8 @@ SizeType constexpr kLORA_NUM_WEIGHTS_POINTERS = 2; void loraValidateRequestTensorDims(std::optional const& optReqLoraWeights, std::optional const& optReqLoraConfig); -void loraValidateRequestTensors(std::optional const& optReqLoraWeights, +void loraValidateRequestTensors(std::optional const& optTaskId, + std::optional const& optReqLoraWeights, std::optional const& optReqLoraConfig, runtime::GptModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig); } // namespace tensorrt_llm::runtime::lora diff --git a/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp b/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp index 17b230fbf..c69e88656 100644 --- a/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp +++ b/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp @@ -583,6 +583,7 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c auto& stream = manager.getStream(); SizeType const batchSize = generationConfig.batchSize; SizeType const maxInputLength = generationConfig.maxInputLength; + auto const& inputShape = inputIds->getShape(); // use context lengths only in context step sequenceLengths = contextLengthsDevice; @@ -604,7 +605,6 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c bufferCast(*sinkTokenLengths)[0] = generationConfig.sinkTokenLength; - auto const& inputShape = inputIds->getShape(); auto const contextLengthsHostPtr = bufferCast(*contextLengthsHost); auto const modelVariant = modelConfig.getModelVariant(); @@ -647,16 +647,6 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c pastKeyValueLengthsPtr[i] = contextLengthsHostPtr[i]; } - if (worldConfig.isPipelineParallel()) - { - auto const hiddenSize - = hiddenStates->getShape().nbDims == 2 ? hiddenStates->getShape().d[1] : hiddenStates->getShape().d[2]; - auto const hiddenStatesShape = modelConfig.usePackedInput() - ? ITensor::makeShape({inputShape.d[0], hiddenSize}) - : ITensor::makeShape({inputShape.d[0], inputShape.d[1], hiddenSize}); - hiddenStates->reshape(hiddenStatesShape); - } - if (modelConfig.usePromptTuning()) { std::vector reqBeamWidths(batchSize, 1); @@ -693,6 +683,15 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c positionIds = manager.copyFrom(positionIdsVec, attentionMask->getShape(), MemoryType::kGPU); } + if (worldConfig.isPipelineParallel()) + { + auto const hiddenSize = hiddenStates->getShape().d[hiddenStates->getShape().nbDims - 1]; + auto const hiddenStatesShape = modelConfig.usePackedInput() + ? 
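// Illustrative sketch (not part of the patch): the validation loop in loraUtils.cpp above walks a
// [numRows, kLORA_CONFIG_ROW_SIZE] int32 lora_config tensor whose rows hold
// (module_id, layer_id, adapter_size). A client-side helper that flattens such rows might look
// like this; LoraConfigRow and buildLoraConfig are hypothetical names, and the fixed row order is
// an assumption based on the kLORA_CONFIG_* offsets used above.
#include <cstdint>
#include <vector>

struct LoraConfigRow
{
    std::int32_t moduleId;    // which LoRA module (module_id)
    std::int32_t layerId;     // transformer layer the adapter applies to
    std::int32_t adapterSize; // LoRA rank (D)
};

std::vector<std::int32_t> buildLoraConfig(std::vector<LoraConfigRow> const& rows)
{
    std::vector<std::int32_t> config;
    config.reserve(rows.size() * 3);
    for (auto const& r : rows)
    {
        config.push_back(r.moduleId);
        config.push_back(r.layerId);
        config.push_back(r.adapterSize);
    }
    return config; // reshape to [numRows, 3] before attaching it to the request
}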
ITensor::makeShape({inputShape.d[0], hiddenSize}) + : ITensor::makeShape({inputShape.d[0], inputShape.d[1], hiddenSize}); + hiddenStates->reshape(hiddenStatesShape); + } + if (modelConfig.useGptAttentionPlugin() && modelConfig.usePagedKvCache()) { auto constexpr contextBeamWidth = 1; @@ -722,17 +721,20 @@ RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, B SizeType const batchSize = generationConfig.batchSize; SizeType const beamWidth = generationConfig.beamWidth; - nvinfer1::Dims inputShape; - if (modelConfig.usePackedInput()) + auto const inputShape = [&modelConfig, batchSize, beamWidth]() { - // batch in last dim - inputShape = ITensor::makeShape({batchSize * beamWidth}); - } - else - { - // batch in first dim - inputShape = ITensor::makeShape({batchSize * beamWidth, 1}); - } + if (modelConfig.usePackedInput()) + { + // batch in last dim + return ITensor::makeShape({batchSize * beamWidth}); + } + else + { + // batch in first dim + return ITensor::makeShape({batchSize * beamWidth, 1}); + } + }(); + auto nextInputIds = newTokens ? ITensor::view(newTokens, inputShape) : TensorPtr{}; if (modelConfig.useGptAttentionPlugin()) @@ -774,16 +776,6 @@ RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, B { TLLM_THROW("Unsupported model variant"); } - - if (worldConfig.isPipelineParallel()) - { - auto const hiddenSize - = hiddenStates->getShape().nbDims == 2 ? hiddenStates->getShape().d[1] : hiddenStates->getShape().d[2]; - auto const hiddenStatesShape = modelConfig.usePackedInput() - ? ITensor::makeShape({inputShape.d[0], hiddenSize}) - : ITensor::makeShape({inputShape.d[0], inputShape.d[1], hiddenSize}); - hiddenStates->reshape(hiddenStatesShape); - } } else { @@ -817,6 +809,15 @@ RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, B positionIds = manager.copyFrom(positionIdsEndVec, ITensor::makeShape({nbInputs, 1}), MemoryType::kGPU); } + if (worldConfig.isPipelineParallel()) + { + auto const hiddenSize = hiddenStates->getShape().d[hiddenStates->getShape().nbDims - 1]; + auto const hiddenStatesShape = modelConfig.usePackedInput() + ? 
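// Illustrative sketch (not part of the patch): prepareNextStep above now computes inputShape with
// an immediately-invoked lambda so the result can be declared const instead of being assigned in
// two branches. The idiom in isolation, with hypothetical names:
#include <string>

std::string describeBatchLayout(bool packedInput)
{
    auto const layout = [packedInput]() -> std::string
    {
        if (packedInput)
        {
            return "tokens packed into a single dimension";
        }
        return "explicit [batch, 1] dimensions";
    }(); // the trailing () runs the lambda immediately
    return layout;
}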
ITensor::makeShape({inputShape.d[0], hiddenSize}) + : ITensor::makeShape({inputShape.d[0], inputShape.d[1], hiddenSize}); + hiddenStates->reshape(hiddenStatesShape); + } + if (modelConfig.usePagedKvCache()) { for (auto batchIdx = firstBatchSlotIdx; batchIdx < firstBatchSlotIdx + batchSize; ++batchIdx) @@ -909,10 +910,10 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu auto kvCacheShape = presentKeysValsAlt.at(0)->getShape(); kvCacheShape.d[3] = 0; - for (SizeType i = firstLayerId; i < firstLayerId + localNbLayers; ++i) + for (SizeType i = 0; i < localNbLayers; ++i) { - std::string name = "past_key_value_" + std::to_string(i); - TensorPtr tmp = ITensor::view(presentKeysValsAlt[i], kvCacheShape); + std::string name = "past_key_value_" + std::to_string(firstLayerId + i); + TensorPtr tmp = ITensor::view(presentKeysValsAlt.at(i), kvCacheShape); inputBuffers.insert_or_assign(name, std::move(tmp)); } } diff --git a/cpp/tensorrt_llm/runtime/tllmRuntime.cpp b/cpp/tensorrt_llm/runtime/tllmRuntime.cpp index a4feed6d2..09261697c 100644 --- a/cpp/tensorrt_llm/runtime/tllmRuntime.cpp +++ b/cpp/tensorrt_llm/runtime/tllmRuntime.cpp @@ -61,7 +61,7 @@ tensorrt_llm::runtime::TllmLogger defaultLogger{}; TllmRuntime::TllmRuntime(void const* engineData, std::size_t engineSize, nvinfer1::ILogger& logger) : mStream(std::make_shared()) - , mBufferManager{mStream} + , mBufferManager{mStream, true} // Ensure to trim the memory pool on destruction. , mRuntime{nvinfer1::createInferRuntime(logger)} , mEngine{mRuntime->deserializeCudaEngine(engineData, engineSize)} { diff --git a/cpp/tensorrt_llm/runtime/workerPool.h b/cpp/tensorrt_llm/runtime/workerPool.h new file mode 100644 index 000000000..aa70fe6b4 --- /dev/null +++ b/cpp/tensorrt_llm/runtime/workerPool.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::runtime +{ + +class WorkerPool +{ +public: + explicit WorkerPool(std::size_t numWorkers = 1, int device = -1) + : mNumWorkers(numWorkers) + , mShutdown(false) + , mDevice(device) + { + initThreads(); + } + + ~WorkerPool() + { + shutdown(); + } + + template >> + std::future enqueue(Function&& task) + { + if (mShutdown) + { + throw std::runtime_error("WorkerPool is shutdown cannot enqueue new tasks"); + } + + auto const taskPromise = std::make_shared>(); + std::lock_guard lock(mTasksMutex); + mTasks.push( + [task = std::forward(task), taskPromise]() + { + try + { + if constexpr (std::is_void_v) + { + task(); + taskPromise->set_value(); + } + else + { + taskPromise->set_value(task()); + } + } + catch (...) 
+ { + taskPromise->set_exception(std::current_exception()); + } + }); + mTasksCv.notify_one(); + return taskPromise->get_future(); + } + +private: + std::size_t mNumWorkers; + + std::queue> mTasks; + mutable std::mutex mTasksMutex; + std::condition_variable mTasksCv; + + std::atomic mShutdown = false; + + std::vector> mThreads; + + int mDevice{-1}; + + void shutdown() + { + if (mShutdown) + { + return; + } + mShutdown = true; + mTasksCv.notify_all(); + for (std::size_t i = 0; i < mThreads.size(); ++i) + { + mThreads.at(i)->join(); + } + } + + void initThreads() + { + for (std::size_t i = 0; i < mNumWorkers; ++i) + { + mThreads.push_back(std::make_shared(std::thread(&WorkerPool::doWork, this))); + } + } + + void doWork() + { + if (mDevice >= 0) + { + TLLM_CUDA_CHECK(cudaSetDevice(mDevice)); + } + else + { + TLLM_LOG_WARNING("WorkerPool did not set cuda device"); + } + while (!mShutdown) + { + std::function task; + { + std::unique_lock lock(mTasksMutex); + mTasksCv.wait(lock, [this]() { return !mTasks.empty() || mShutdown; }); + if (mTasks.empty()) + { + continue; + } + task = mTasks.front(); + mTasks.pop(); + } + + task(); + } + } +}; +} // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/worldConfig.cpp b/cpp/tensorrt_llm/runtime/worldConfig.cpp index ef6ffc6e9..8ebf810c8 100644 --- a/cpp/tensorrt_llm/runtime/worldConfig.cpp +++ b/cpp/tensorrt_llm/runtime/worldConfig.cpp @@ -72,10 +72,9 @@ WorldConfig::WorldConfig(SizeType tensorParallelism, SizeType pipelineParallelis TLLM_CHECK(mPipelineParallelism > 0); } -bool WorldConfig::validConfig(SizeType tensorParallelism, SizeType pipelineParallelism) +bool WorldConfig::validMpiConfig() const { - auto const mpiSize = COMM_SESSION.getSize(); - return mpiSize == tensorParallelism * pipelineParallelism; + return COMM_SESSION.getSize() == getSize(); } WorldConfig WorldConfig::mpi(SizeType gpusPerNode, std::optional tensorParallelism, diff --git a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp index 755fb5b38..8921556ee 100644 --- a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp +++ b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp @@ -44,8 +44,9 @@ FtDynamicDecode::FtDynamicDecode(const size_t max_batch_size, const size_t ma auto stream = at::cuda::getCurrentCUDAStream().stream(); auto allocator = std::make_shared(stream); - cudaDeviceProp prop; - tensorrt_llm::common::check_cuda_error(cudaGetDeviceProperties(&prop, 0)); + int deviceId; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&deviceId)); // Get the correct device id + tensorrt_llm::common::check_cuda_error(cudaGetDeviceProperties(&prop_, deviceId)); dynamic_decode_layer_ = std::make_shared>(tr::DecodingMode::None(), max_batch_size, max_beam_width, vocab_size_, vocab_size_padded_, stream, std::move(allocator), &prop_); diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index f08fa99f9..9933f5250 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -52,6 +52,8 @@ endfunction() add_gtest(loraManagerTest runtime/loraManagerTest.cpp) add_gtest(loraUtilsTest runtime/loraUtilsTest.cpp) +add_gtest(loraCacheTest runtime/loraCacheTest.cpp) +add_gtest(workerPoolTest runtime/workerPoolTest.cpp) add_gtest(attentionKernelTest runtime/transposeKVKernelTest.cpp) add_gtest(gptDecoderTest runtime/gptDecoderTest.cpp) add_gtest(gptDecoderBatchTest runtime/gptDecoderBatchTest.cpp) @@ -96,6 +98,7 @@ set(SAMPLING_LAYER_TEST_SRC layers/topKSamplingLayerTest.cpp layers/topPSamplingLayerTest.cpp) add_gtest(samplingLayerTest 
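// Illustrative usage sketch (not part of the patch) for the WorkerPool added above in
// cpp/tensorrt_llm/runtime/workerPool.h: enqueue() accepts a no-argument callable and returns a
// std::future for its result (std::future<void> for void callables). Passing device = -1 skips
// cudaSetDevice on the worker threads, which the pool reports with a warning.
#include "tensorrt_llm/runtime/workerPool.h"

#include <future>
#include <iostream>

int main()
{
    tensorrt_llm::runtime::WorkerPool pool(/*numWorkers=*/2, /*device=*/-1);

    std::future<int> answer = pool.enqueue([]() { return 40 + 2; });
    std::future<void> greeting = pool.enqueue([]() { std::cout << "hello from a worker\n"; });

    greeting.get();
    std::cout << "answer = " << answer.get() << std::endl;
    return 0; // ~WorkerPool joins the worker threads
}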
"${SAMPLING_LAYER_TEST_SRC}") add_gtest(dynamicDecodeLayerTest layers/dynamicDecodeLayerTest.cpp) +add_gtest(medusaDecodeLayerTest layers/medusaDecodeLayerTest.cpp) if(BUILD_BATCH_MANAGER) if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/batch_manager) diff --git a/cpp/tests/kernels/decodingKernelTest.cpp b/cpp/tests/kernels/decodingKernelTest.cpp index 0f36d14c1..dfef16a3f 100644 --- a/cpp/tests/kernels/decodingKernelTest.cpp +++ b/cpp/tests/kernels/decodingKernelTest.cpp @@ -236,6 +236,7 @@ class DecodingKernelsTest : public testing::Test mNumsDraftTokens = BufferManager::pinned( ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqPerStep}), nvinfer1::DataType::kINT32); mSequenceLengths = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mAcceptedLengths = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mContextLengths = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mDraftContextLengths = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mFinishedSteps = BufferManager::pinned(ITensor::makeShape({mMaxDraftTokens + 1, mMaxBatchSize}), @@ -402,6 +403,10 @@ class DecodingKernelsTest : public testing::Test for (SizeType si = 0; si < numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]; ++si) { auto const pathIdx = tc::flat_index3(bi, ti, si, mMaxDraftSeqPerStep, mMaxDraftTokens); + if (pathsPtr[pathIdx] == -1) + { + continue; + } auto const draftTokenIdx = bi * mMaxDraftSeqlen + draftContextLengthsPtr[bi] + pathsPtr[pathIdx]; // Avoid generating endId. We'll insert in manually later if needed. draftTokensPtr[draftTokenIdx] = generateAvoidingValues(vocabDistr, {mPadId, endIdsPtr[bi]}); @@ -443,28 +448,42 @@ class DecodingKernelsTest : public testing::Test // ti (!= di), ti+1 (!= di+1), ... for (targetPredictedLen[bi] - targetAcceptedLen[bi]), // EOS, EOS, EOS, ... for (numsDraftTokensPtr[bi] - targetPredictedLen[bi]) // padId, padId, .. to mMaxSeqLen] - for (SizeType si = 0; si < numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]; ++si) + auto numDraftTokens = numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]; + for (SizeType si = 0; si < numDraftTokens; ++si) { - auto const pathIdx = tc::flat_index3(bi, ti, si, mMaxDraftSeqPerStep, mMaxDraftTokens); - auto const pathId = pathsPtr[pathIdx]; - if (pathId == -1) + auto const curPathIdx = tc::flat_index3(bi, ti, si, mMaxDraftSeqPerStep, mMaxDraftTokens); + auto const nextPathIdx = si + 1 < numDraftTokens + ? tc::flat_index3(bi, ti, si + 1, mMaxDraftSeqPerStep, mMaxDraftTokens) + : -1; + auto const curPathId = pathsPtr[curPathIdx]; + auto nextPathId = curPathId; + if (mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH) + { + nextPathId = nextPathIdx > -1 ? 
pathsPtr[nextPathIdx] : -1; + } + + if (curPathId == -1) { continue; } - auto const draftTokenIdx = bi * mMaxDraftSeqlen + draftContextLengthsPtr[bi] + pathId; - auto const targetTokenIdx = bi * mMaxSeqLen + contextLengthsPtr[bi] + pathId; + auto const draftTokenIdx = bi * mMaxDraftSeqlen + draftContextLengthsPtr[bi] + nextPathId; + auto const targetTokenIdx = bi * mMaxSeqLen + contextLengthsPtr[bi] + curPathId; auto targetToken = mPadId; - if (0 <= si && si < targetAcceptedLen[bi]) + if (0 <= si && si < targetAcceptedLen[bi] && nextPathId != -1) { // Use draft token up to the accepted len targetToken = draftTokensPtr[draftTokenIdx]; } - else if (targetAcceptedLen[bi] <= si && si < targetPredictedLen[bi]) + else if (0 <= si && si < targetPredictedLen[bi]) { // Do not use draft token token up to the generated len - targetToken = generateAvoidingValues( - vocabDistr, {mPadId, endIdsPtr[bi], draftTokensPtr[draftTokenIdx]}); + std::unordered_set avoidValues = {mPadId, endIdsPtr[bi]}; + if (nextPathId != -1) + { + avoidValues.insert(draftTokensPtr[draftTokenIdx]); + } + targetToken = generateAvoidingValues(vocabDistr, avoidValues); } else if (targetPredictedLen[bi] <= si && si < numsDraftTokensPtr[bi]) { @@ -472,7 +491,7 @@ class DecodingKernelsTest : public testing::Test targetToken = endIdsPtr[bi]; } targetTokensPtr[targetTokenIdx] = targetToken; - TLLM_LOG_DEBUG("bi %d ti %d si %d pathId %d targetToken %d", bi, ti, si, pathId, targetToken); + TLLM_LOG_DEBUG("bi %d ti %d si %d pathId %d targetToken %d", bi, ti, si, curPathId, targetToken); } } } @@ -506,19 +525,32 @@ class DecodingKernelsTest : public testing::Test mAcceptedPathIdx.resize(mMaxBatchSize); mRefAcceptedTokens.resize(mMaxBatchSize); mFinishedByIdsPaths.resize(mMaxBatchSize); + mLastTargetIdx.resize(mMaxBatchSize); for (SizeType bi = 0; bi < mMaxBatchSize; ++bi) { SizeType maxAcceptedLen = -1; SizeType maxAcceptedPath = -1; + SizeType maxNextTargetTokenIdx = -1; bool maxFinished = false; std::vector maxAcceptedTokens; for (SizeType ti = 0; ti < mMaxDraftSeqPerStep; ++ti) { std::vector acceptedTokens; SizeType curAcceptedLen = mMaxDraftTokens; - SizeType curAcceptedPath = -1; + SizeType curAcceptedPath = ti; bool curFinished = false; - for (SizeType di = 0; di < mMaxDraftTokens; ++di) + + auto const pathIdx = tc::flat_index3(bi, ti, 0, mMaxDraftSeqPerStep, mMaxDraftTokens); + auto const pathId = pathsPtr[pathIdx]; + if (pathId == -1) + { + continue; + } + auto targetTokenIdx = bi * mMaxSeqLen + contextLengthsPtr[bi] + pathId; + auto targetToken = targetTokensPtr[targetTokenIdx]; + auto curNextTargetTokenIdx = pathId; + + for (SizeType di = 1; di < mMaxDraftTokens; ++di) { auto const pathIdx = tc::flat_index3(bi, ti, di, mMaxDraftSeqPerStep, mMaxDraftTokens); auto const pathId = pathsPtr[pathIdx]; @@ -527,12 +559,12 @@ class DecodingKernelsTest : public testing::Test curAcceptedLen = di; curAcceptedPath = ti; curFinished = false; + acceptedTokens.push_back(targetToken); break; } auto const draftTokenIdx = bi * mMaxDraftSeqlen + draftContextLengthsPtr[bi] + pathId; - auto const targetTokenIdx = bi * mMaxSeqLen + contextLengthsPtr[bi] + pathId; + auto targetTokenIdx = bi * mMaxSeqLen + contextLengthsPtr[bi] + pathId; auto const draftToken = draftTokensPtr[draftTokenIdx]; - auto const targetToken = targetTokensPtr[targetTokenIdx]; bool const hasEnd = targetToken == endIdsPtr[bi]; if (!hasEnd) { @@ -540,12 +572,19 @@ class DecodingKernelsTest : public testing::Test } if (draftToken != targetToken || hasEnd) { - auto const curLen = 
hasEnd ? di : di + 1; + auto const curLen = hasEnd ? di - 1 : di; curAcceptedLen = curLen; curAcceptedPath = ti; curFinished = hasEnd; + curNextTargetTokenIdx = pathId; break; } + targetToken = targetTokensPtr[targetTokenIdx]; + curNextTargetTokenIdx = pathId; + } + if (curAcceptedLen == mMaxDraftTokens) + { + acceptedTokens.push_back(targetToken); } if (curAcceptedLen > maxAcceptedLen) { @@ -553,12 +592,14 @@ class DecodingKernelsTest : public testing::Test maxAcceptedPath = curAcceptedPath; maxAcceptedTokens = acceptedTokens; maxFinished = curFinished; + maxNextTargetTokenIdx = curNextTargetTokenIdx; } } mAcceptedLen[bi] = maxAcceptedLen; mAcceptedPathIdx[bi] = maxAcceptedPath; mRefAcceptedTokens[bi] = maxAcceptedTokens; mFinishedByIdsPaths[bi] = maxFinished; + mLastTargetIdx[bi] = maxNextTargetTokenIdx; TLLM_LOG_DEBUG("bi %d maxAcceptedLen %d maxAcceptedPath %d", bi, maxAcceptedLen, maxAcceptedPath); std::ostringstream ss; for (auto& tk : maxAcceptedTokens) @@ -687,7 +728,7 @@ class DecodingKernelsTest : public testing::Test void callAcceptByIdsWithPaths() { tk::acceptDraftTokensByIdsWithPaths(bufferCast(*mDraftTokens), bufferCast(*mTargetTokens), - bufferCast(*mDraftContextLengths), + bufferCast(*mDraftContextLengths), bufferCast(*mAcceptedLengths), reinterpret_cast(bufferCast(*mFinishedFinal)), bufferCast(*mBatchSlots), bufferCast(*mPaths), bufferCast(*mEndIds), static_cast(nullptr), reinterpret_cast(bufferCast(*mMedusaLogitsPtrs)), @@ -779,6 +820,7 @@ class DecodingKernelsTest : public testing::Test auto batchSlotsPtr = BufferRange(*mBatchSlots); auto draftContextLengths = BufferRange(*mDraftContextLengths); auto draftContextLengthsInit = BufferRange(*mDraftContextLengthsCopy); + auto acceptedLengths = BufferRange(*mAcceptedLengths); auto draftTokensPtr = BufferRange(*mDraftTokens); auto finishedFinalPtr = reinterpret_cast(bufferCast(*mFinishedFinal)); @@ -787,17 +829,24 @@ class DecodingKernelsTest : public testing::Test { auto const batchSlot = batchSlotsPtr[bi]; auto const bestPathIdx = mAcceptedPathIdx[batchSlot]; + auto const lastTargetIdx = mLastTargetIdx[batchSlot]; + if (lastTargetIdx < 0) + { + continue; + } + auto const acceptedLen = mAcceptedLen[batchSlot]; auto acceptedTokens = mRefAcceptedTokens[batchSlot]; for (int32_t hi = 0; hi < mMaxNumHeads; ++hi) { auto refOffset - = tc::flat_index4(hi, bi, acceptedLen, 0, mMaxBatchSize, mMaxDraftSeqPerStep, mVocabSize); + = tc::flat_index4(hi, bi, lastTargetIdx, 0, mMaxBatchSize, mMaxDraftSeqPerStep, mVocabSize); auto outOffset = static_cast(medusaLogitsPtrsPtr[bi * mMaxNumHeads + hi] - static_cast(nullptr)); EXPECT_EQ(outOffset, refOffset) << " bi: " << bi << " hi: " << hi << " seed: " << seed; } + EXPECT_EQ(acceptedLengths[batchSlot], acceptedLen) << " bi: " << bi << " seed: " << seed; EXPECT_EQ(draftContextLengths[batchSlot], draftContextLengthsInit[batchSlot] + acceptedLen) << " bi: " << bi << " seed: " << seed << " out: " << draftContextLengths[batchSlot] << " ref: " << draftContextLengthsInit[batchSlot] + acceptedLen; @@ -858,6 +907,10 @@ class DecodingKernelsTest : public testing::Test for (SizeType seed = 0; seed < mSeeds; ++seed) { + // if (seed != 145) + // { + // continue; + // } TLLM_LOG_DEBUG("Seed %d", seed); initData(seed); @@ -889,6 +942,7 @@ class DecodingKernelsTest : public testing::Test TensorPtr mNumsDraftTokens; TensorPtr mSequenceLengths; + TensorPtr mAcceptedLengths; TensorPtr mContextLengths; TensorPtr mDraftContextLengthsCopy; TensorPtr mDraftContextLengths; @@ -907,6 +961,7 @@ class DecodingKernelsTest : 
public testing::Test std::vector mOutputLen; std::vector mAcceptedFinished; std::vector mAcceptedPathIdx; + std::vector mLastTargetIdx; std::vector> mRefAcceptedTokens; std::vector mFinishedByIdsPaths; diff --git a/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp b/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp index f52c74fc6..43d0ce11f 100644 --- a/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp +++ b/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp @@ -43,7 +43,8 @@ class AirTopPSamplingKernelTest : public SamplingKernelTest private: size_t getWorkspaceSize(SamplingKernelTestParam const& params) override { - return tensorrt_llm::kernels::getAirTopPWorkspaceSize(params.batchSize, params.vocabSize); + return tensorrt_llm::kernels::getAirTopPWorkspaceSize( + params.batchSize, params.vocabSize, params.isDeterministicTopP); } void callTestedFunction( @@ -56,7 +57,9 @@ class AirTopPSamplingKernelTest : public SamplingKernelTest TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCnt, cudaDevAttrMultiProcessorCount, dev)); auto const maxBatchSize = 2 * params.batchSize; - int blockNum = tk::calcAirTopPBlockNum(params.batchSize, params.vocabSize, smCnt); + int blockNum + = tk::calcAirTopPBlockNum(params.batchSize, params.vocabSize, smCnt, params.isDeterministicTopP); + // Perform batched TopP sampling tk::invokeBatchAirTopPSampling(workspaceDevice->data(), bufferCast(*this->mIdsPtrHost), bufferCast(*this->mSeqLengthsDevice), @@ -73,37 +76,65 @@ class AirTopPSamplingKernelTest : public SamplingKernelTest reinterpret_cast(bufferCast(*this->mCurandStatesDevice)), params.batchSize, maxBatchSize, params.vocabSize, bufferCast(*this->mEndIdsDevice), this->mMaxTopP, bufferCast(*this->mTopPsDevice), this->mStream->get(), blockNum, - bufferCast(*this->mSkipDecodeDevice), bufferCast(*this->mBatchSlots)); + bufferCast(*this->mSkipDecodeDevice), bufferCast(*this->mBatchSlots), + params.isDeterministicTopP); } }; TYPED_TEST_SUITE(AirTopPSamplingKernelTest, FloatAndHalfTypes); -TYPED_TEST(AirTopPSamplingKernelTest, CorrectnessSmallP) +TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessSmallP) { this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f)); }; -TYPED_TEST(AirTopPSamplingKernelTest, CorrectnessLargeP) +TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeP) { this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f)); }; -TYPED_TEST(AirTopPSamplingKernelTest, CorrectnessAncestral) +TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessAncestral) { this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f)); }; -TYPED_TEST(AirTopPSamplingKernelTest, CorrectnessLargeVocabSmallP) +TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeVocabSmallP) { this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f)); }; -TYPED_TEST(AirTopPSamplingKernelTest, CorrectnessLargeVocabLargeP) +TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeVocabLargeP) { this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f)); }; +TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessSmallP) +{ + this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f).setDeterministicTopP(true)); +}; + +TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeP) +{ + 
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f).setDeterministicTopP(true)); +}; + +TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessAncestral) +{ + this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f).setDeterministicTopP(true)); +}; + +TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeVocabSmallP) +{ + this->runTest( + SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f).setDeterministicTopP(true)); +}; + +TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeVocabLargeP) +{ + this->runTest( + SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f).setDeterministicTopP(true)); +}; + class AirTopPSamplingKernelUtilsTest : public SamplingKernelTest { }; diff --git a/cpp/tests/kernels/sampling/samplingPenaltyTest.cpp b/cpp/tests/kernels/sampling/samplingPenaltyTest.cpp index 71204138c..05bd7a6bf 100644 --- a/cpp/tests/kernels/sampling/samplingPenaltyTest.cpp +++ b/cpp/tests/kernels/sampling/samplingPenaltyTest.cpp @@ -173,6 +173,9 @@ class TemperaturePenaltyTest : public SamplingKernelTest trk::invokeFill(*mLogitsRefHost, T{0.0f}, *mStream); trk::invokeFill(*mOutLogitsDevice, T{0.0f}, *mStream); + trk::invokeFill(*mLogitsRefHost, T{0.0f}, *mStream); + trk::invokeFill(*mOutLogitsDevice, T{0.0f}, *mStream); + auto batchSlotsPtr = bufferCast(*mBatchSlots); for (SizeType bi = 0; bi < mBatchSize; ++bi) { @@ -256,9 +259,9 @@ class TemperaturePenaltyTest : public SamplingKernelTest InvokeBatchApplyPenaltyParams penaltyParams{reinterpret_cast(bufferCast(*mLogitsPtrs)), bufferCast(*mOutLogitsDevice), bufferCast(*mBiasDevice), bufferCast(*mPenaltyWorkspaceDevice), nullptr, bufferCast(*mTemperaturesDevice), nullptr, - nullptr, nullptr, false, static_cast(mBatchSize), 1, 1, static_cast(mVocabSize), - static_cast(mVocabSizePadded), nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - bufferCast(*mBatchSlots), mMaxTokensPerStep, bufferCast(*mTokensPerStep), mStream->get()}; + nullptr, nullptr, false, mBatchSize, 1, 1, mVocabSize, mVocabSizePadded, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, bufferCast(*mBatchSlots), mMaxTokensPerStep, + bufferCast(*mTokensPerStep), mStream->get()}; tk::invokeBatchApplyPenalty(penaltyParams); auto logitsOutHost = mBufferManager->copyFrom(*mOutLogitsDevice, MemoryType::kCPU); @@ -655,15 +658,14 @@ class RepetitionPenaltyTest : public SamplingKernelTest { subsetup(param); - InvokeBatchApplyPenaltyParams penalty_params{reinterpret_cast(bufferCast(*mLogitsPtrs)), + InvokeBatchApplyPenaltyParams penaltyParams{reinterpret_cast(bufferCast(*mLogitsPtrs)), bufferCast(*mOutLogitsDevice), nullptr, bufferCast(*mPenaltyWorkspaceDevice), nullptr, nullptr, bufferCast(*mRepetitionPenaltiesDevice), bufferCast(*mPresencePenaltiesDevice), - bufferCast(*mFrequencyPenaltiesDevice), true, static_cast(mBatchSize), 1, mSequenceLength, - static_cast(mVocabSize), static_cast(mVocabSizePadded), - reinterpret_cast(bufferCast(*mIdsPtrDevice)), nullptr, + bufferCast(*mFrequencyPenaltiesDevice), true, mBatchSize, 1, mSequenceLength, mVocabSize, + mVocabSizePadded, reinterpret_cast(bufferCast(*mIdsPtrDevice)), nullptr, bufferCast(*mContextLengthDevice), bufferCast(*mSeqLengthDevice), nullptr, nullptr, bufferCast(*mBatchSlots), mMaxTokensPerStep, bufferCast(*mTokensPerStep), mStream->get()}; - tk::invokeBatchApplyPenalty(penalty_params); + tk::invokeBatchApplyPenalty(penaltyParams); auto logitsOutHost = 
mBufferManager->copyFrom(*mOutLogitsDevice, MemoryType::kCPU); @@ -1252,14 +1254,13 @@ class MinLengthPenaltyTest : public SamplingKernelTest { subsetup(param); - InvokeBatchApplyPenaltyParams penalty_params{reinterpret_cast(bufferCast(*mLogitsPtrs)), + InvokeBatchApplyPenaltyParams penaltyParams{reinterpret_cast(bufferCast(*mLogitsPtrs)), bufferCast(*mOutLogitsDevice), nullptr, bufferCast(*mPenaltyWorkspaceDevice), nullptr, nullptr, - nullptr, nullptr, nullptr, false, static_cast(mBatchSize), 1, mSequenceLength, - static_cast(mVocabSize), static_cast(mVocabSizePadded), nullptr, nullptr, - bufferCast(*mContextLengthDevice), bufferCast(*mSeqLengthDevice), + nullptr, nullptr, nullptr, false, mBatchSize, 1, mSequenceLength, mVocabSize, mVocabSizePadded, nullptr, + nullptr, bufferCast(*mContextLengthDevice), bufferCast(*mSeqLengthDevice), bufferCast(*mMinLengthDevice), bufferCast(*mEndIdsDevice), bufferCast(*mBatchSlots), mMaxTokensPerStep, bufferCast(*mTokensPerStep), mStream->get()}; - tk::invokeBatchApplyPenalty(penalty_params); + tk::invokeBatchApplyPenalty(penaltyParams); mStream->synchronize(); diff --git a/cpp/tests/kernels/sampling/samplingTest.h b/cpp/tests/kernels/sampling/samplingTest.h index 102c1caea..ccf8c3fe6 100644 --- a/cpp/tests/kernels/sampling/samplingTest.h +++ b/cpp/tests/kernels/sampling/samplingTest.h @@ -197,6 +197,7 @@ struct SamplingKernelTestParam int32_t maxTokensPerStep{1}; bool returnAllTopK{false}; bool useLogitsPtrs{false}; + bool isDeterministicTopP{false}; SamplingKernelTestParam& setBatchSize(int32_t bs) { @@ -240,6 +241,12 @@ struct SamplingKernelTestParam return *this; } + SamplingKernelTestParam& setDeterministicTopP(bool isDeter) + { + isDeterministicTopP = isDeter; + return *this; + } + std::string toString() const { return tensorrt_llm::common::fmtstr( diff --git a/cpp/tests/kernels/sampling/samplingTopKTest.cpp b/cpp/tests/kernels/sampling/samplingTopKTest.cpp index 4cfc41317..10a9d2426 100644 --- a/cpp/tests/kernels/sampling/samplingTopKTest.cpp +++ b/cpp/tests/kernels/sampling/samplingTopKTest.cpp @@ -59,7 +59,7 @@ class TopKSamplingKernelTest : public SamplingKernelTest params.useLogitsPtrs ? nullptr : bufferCast(*this->mProbsDevice), params.useLogitsPtrs ? 
reinterpret_cast(bufferCast(*this->mProbsPtrsDevice)) : nullptr, - bufferCast(*this->mIdsPtrHost), bufferCast(*this->mSeqLengthsDevice), + bufferCast(*this->mIdsPtrHost), nullptr, bufferCast(*this->mSeqLengthsDevice), reinterpret_cast( bufferCast(*this->mFinishedDevice)), reinterpret_cast( @@ -69,7 +69,7 @@ class TopKSamplingKernelTest : public SamplingKernelTest bufferCast(*this->mTopKsDevice), params.topP, bufferCast(*this->mTopPsDevice), params.vocabSize, bufferCast(*this->mEndIdsDevice), bufferCast(*this->mBatchSlots), this->mStream->get(), params.batchSize, maxBatchSize, bufferCast(*this->mTokensPerStep), - params.maxTokensPerStep, bufferCast(*this->mSkipDecodeDevice), params.normalizeLogProbs, + params.maxTokensPerStep, 0, bufferCast(*this->mSkipDecodeDevice), params.normalizeLogProbs, params.logitsHasProbs, params.returnAllTopK); } }; diff --git a/cpp/tests/kernels/sampling/samplingUtilsTest.cu b/cpp/tests/kernels/sampling/samplingUtilsTest.cu index 04f42207e..ad0fb714c 100644 --- a/cpp/tests/kernels/sampling/samplingUtilsTest.cu +++ b/cpp/tests/kernels/sampling/samplingUtilsTest.cu @@ -29,8 +29,6 @@ namespace tk = tensorrt_llm::kernels; namespace { -static float constexpr HALF_FLT_MAX = 65504.F; - __global__ void generateRandomNumber(int32_t* vals, curandState_t* states, int const batch_size) { int idx = threadIdx.x; diff --git a/cpp/tests/kernels/weightOnly/weightOnlyKernelTest.cpp b/cpp/tests/kernels/weightOnly/weightOnlyKernelTest.cpp index 84d95d1f9..04b004bb4 100644 --- a/cpp/tests/kernels/weightOnly/weightOnlyKernelTest.cpp +++ b/cpp/tests/kernels/weightOnly/weightOnlyKernelTest.cpp @@ -6,7 +6,7 @@ #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/quantization.h" #include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h" -#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h" +#include "tensorrt_llm/kernels/preQuantScaleKernel.h" #include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h" #include @@ -18,51 +18,12 @@ #include #include #include +#include +#include #include #include -using tensorrt_llm::kernels::WeightOnlyParams; -using tensorrt_llm::kernels::WeightOnlyType; -using tensorrt_llm::kernels::WeightOnlyQuantType; -using tensorrt_llm::kernels::WeightOnlyActivationType; -using tensorrt_llm::kernels::WeightOnlyActivationFunctionType; -template -struct AType; - -template <> -struct AType -{ - using CudaKernelAType = half; - using CutlassKernelAType = half; -}; -#if defined(ENABLE_BF16) -template <> -struct AType -{ - using CudaKernelAType = __nv_bfloat16; - using CutlassKernelAType = __nv_bfloat16; -}; -#endif -template -struct BType; - -template <> -struct BType -{ - using CudaKernelBType = uint8_t; - using CutlassKernelBType = cutlass::uint4b_t; - static constexpr int elemsPerByte = 2; -}; - -template <> -struct BType -{ - using CudaKernelBType = uint8_t; - using CutlassKernelBType = uint8_t; - static constexpr int elemsPerByte = 1; -}; -struct CutlassKernel; -struct CudaKernel; +namespace wo = tensorrt_llm::kernels::weight_only; void simple_assert(bool flag) { @@ -72,192 +33,6 @@ void simple_assert(bool flag) } } -template -std::vector get_configs(T& runner, int k) -{ - auto configs = runner.getConfigs(); - std::vector rets; - for (auto config : configs) - { - if (config.stages >= 5) - { - continue; - } - if (config.split_k_style != tensorrt_llm::cutlass_extensions::SplitKStyle::NO_SPLIT_K) - { - int k_size = (k + config.split_k_factor - 1) / config.split_k_factor; - if (k_size % 64) - { - continue; - } - 
} - rets.push_back(config); - } - return rets; -} - -template -float benchmark_perchannel(void* act, void* weight, void* scales, void* zeros, void* bias, void* out, int m, int n, - int k, int group_size, int warmup, int iter) -{ - simple_assert(zeros == nullptr && bias == nullptr && group_size == 0); - cudaStream_t s; - cudaStreamCreate(&s); - cudaEvent_t begin, end; - cudaEventCreate(&begin); - cudaEventCreate(&end); - if constexpr (std::is_same_v) - { - WeightOnlyParams params{reinterpret_cast(weight), scales, zeros, act, nullptr, bias, out, m, n, k, - group_size, BFlag, WeightOnlyType::PerChannel, WeightOnlyActivationFunctionType::Identity, AFlag}; - for (int i = 0; i < warmup; ++i) - { - tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s); - } - cudaEventRecord(begin, s); - for (int i = 0; i < iter; ++i) - { - tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s); - } - } - else if (std::is_same_v) - { - tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner::CutlassKernelAType, - typename BType::CutlassKernelBType, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY> - gemm; - auto configs = get_configs(gemm, k); - int ws_bytes = gemm.getWorkspaceSize(m, n, k); - char* ws_ptr = nullptr; - if (ws_bytes) - cudaMalloc(&ws_ptr, ws_bytes); - float fast_time = 1e8; - auto best_config = configs[0]; - for (auto& config : configs) - { - for (int i = 0; i < 2; ++i) - { - gemm.gemm(act, weight, scales, out, m, n, k, config, ws_ptr, ws_bytes, s); - } - cudaEventRecord(begin, s); - for (int i = 0; i < 5; ++i) - { - gemm.gemm(act, weight, scales, out, m, n, k, config, ws_ptr, ws_bytes, s); - } - cudaEventRecord(end, s); - cudaEventSynchronize(end); - float time; - cudaEventElapsedTime(&time, begin, end); - if (time < fast_time) - { - fast_time = time; - best_config = config; - } - } - - for (int i = 0; i < warmup; ++i) - { - gemm.gemm(act, weight, scales, out, m, n, k, best_config, ws_ptr, ws_bytes, s); - } - cudaEventRecord(begin, s); - for (int i = 0; i < iter; ++i) - { - gemm.gemm(act, weight, scales, out, m, n, k, best_config, ws_ptr, ws_bytes, s); - } - if (ws_ptr) - cudaFree(ws_ptr); - } - - cudaEventRecord(end, s); - cudaEventSynchronize(end); - float time; - cudaEventElapsedTime(&time, begin, end); - cudaEventDestroy(begin); - cudaEventDestroy(end); - cudaStreamDestroy(s); - return time / iter; -} - -template -float benchmark_groupwise(void* act, void* weight, void* scales, void* zeros, void* bias, void* out, int m, int n, - int k, int group_size, int warmup, int iter) -{ - simple_assert(zeros && bias && (group_size == 64 || group_size == 128)); - cudaStream_t s; - cudaStreamCreate(&s); - cudaEvent_t begin, end; - cudaEventCreate(&begin); - cudaEventCreate(&end); - if constexpr (std::is_same_v) - { - WeightOnlyParams params{reinterpret_cast(weight), scales, zeros, act, nullptr, bias, out, m, n, k, - group_size, BFlag, WeightOnlyType::GroupWise, WeightOnlyActivationFunctionType::Identity, AFlag}; - for (int i = 0; i < warmup; ++i) - { - tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s); - } - cudaEventRecord(begin, s); - for (int i = 0; i < iter; ++i) - { - tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s); - } - } - else if (std::is_same_v) - { - tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner::CutlassKernelAType, - typename BType::CutlassKernelBType, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS> - gemm; - auto configs = get_configs(gemm, k); - int ws_bytes = gemm.getWorkspaceSize(m, n, k); 
- char* ws_ptr = nullptr; - if (ws_bytes) - cudaMalloc(&ws_ptr, ws_bytes); - float fast_time = 1e8; - auto best_config = configs[0]; - for (auto& config : configs) - { - for (int i = 0; i < 2; ++i) - { - gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, config, ws_ptr, ws_bytes, s); - } - cudaEventRecord(begin, s); - for (int i = 0; i < 5; ++i) - { - gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, config, ws_ptr, ws_bytes, s); - } - cudaEventRecord(end, s); - cudaEventSynchronize(end); - float time; - cudaEventElapsedTime(&time, begin, end); - if (time < fast_time) - { - fast_time = time; - best_config = config; - } - } - - for (int i = 0; i < warmup; ++i) - { - gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, best_config, ws_ptr, ws_bytes, s); - } - cudaEventRecord(begin, s); - for (int i = 0; i < iter; ++i) - { - gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, best_config, ws_ptr, ws_bytes, s); - } - if (ws_ptr) - cudaFree(ws_ptr); - } - - cudaEventRecord(end, s); - cudaEventSynchronize(end); - float time; - cudaEventElapsedTime(&time, begin, end); - cudaEventDestroy(begin); - cudaEventDestroy(end); - cudaStreamDestroy(s); - return time / iter; -} - struct CudaBuffer { void* _data; @@ -334,7 +109,7 @@ float compare(void* _pa, void* _pb, int size, float scale) template void random_fill(std::vector& vec, T2 minv, T2 maxv) { - std::mt19937 gen(20231205); + std::mt19937 gen(rand()); std::uniform_real_distribution dis(static_cast(minv), static_cast(maxv)); for (auto& v : vec) { @@ -342,50 +117,224 @@ void random_fill(std::vector& vec, T2 minv, T2 maxv) } } -template -bool benchmark(int m, int n, int k, int group_size, int warmup, int iter) +template +std::vector get_configs(T& runner, int k) { - printf("benchmark mnk (%d, %d, %d) ", m, n, k); - if (AFlag == WeightOnlyActivationType::FP16) + auto configs = runner.getConfigs(); + std::vector rets; + for (auto config : configs) { - printf("FP16 Activation "); + if (config.stages >= 5) + { + continue; + } + if (config.split_k_style != tensorrt_llm::cutlass_extensions::SplitKStyle::NO_SPLIT_K) + { + int k_size = (k + config.split_k_factor - 1) / config.split_k_factor; + if (k_size % 64) + { + continue; + } + } + rets.push_back(config); } - else + return rets; +} + +template +struct cutlassTypeMapper +{ +}; + +#define CUTLASS_TYPE_MAPPER_REGISTRY( \ + CudaKernelType, KernelInfoStr, CutlassAType, CutlassWType, WElemBits, CutlassQuantOp) \ + template <> \ + struct cutlassTypeMapper \ + { \ + using AType = CutlassAType; \ + using WType = CutlassWType; \ + static constexpr cutlass::WeightOnlyQuantOp QuantOp = CutlassQuantOp; \ + static constexpr int WSizeInBits = WElemBits; \ + static std::string str(int m, int n, int k, int gs) \ + { \ + std::stringstream ss; \ + ss << KernelInfoStr << " mnk(" << m << ", " << n << ", " << k << ")"; \ + if (gs != 0) \ + ss << ", gs " << gs; \ + return ss.str(); \ + } \ + }; +CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int4Groupwise, "FP16Int4Groupwise", half, cutlass::uint4b_t, 4, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS); +CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int4Groupwise, "BF16Int4Groupwise", __nv_bfloat16, cutlass::uint4b_t, + 4, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS); +CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int8PerChannel, "FP16Int8PerChannel", half, uint8_t, 8, + cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY); 
+CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int8PerChannel, "BF16Int8PerChannel", __nv_bfloat16, uint8_t, 8, + cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY); +CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int4PerChannel, "FP16Int4PerChannel", half, cutlass::uint4b_t, 4, + cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY); +CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int4PerChannel, "BF16Int4PerChannel", __nv_bfloat16, cutlass::uint4b_t, + 4, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY); + +float run_cuda_kernel(wo::Params& params, int warmup, int iter) +{ + int arch = tensorrt_llm::common::getSMVersion(); + simple_assert(wo::is_supported(arch, params.type)); + cudaStream_t s; + cudaStreamCreate(&s); + cudaEvent_t begin, end; + cudaEventCreate(&begin); + cudaEventCreate(&end); + for (int i = 0; i < warmup; ++i) { - printf("BF16 Activation "); + wo::kernel_launcher(arch, params, s); } - if (BFlag == WeightOnlyQuantType::Int8b) + cudaEventRecord(begin, s); + for (int i = 0; i < iter; ++i) { - printf("Int8b "); + wo::kernel_launcher(arch, params, s); } - else + cudaEventRecord(end, s); + cudaEventSynchronize(end); + float time; + cudaEventElapsedTime(&time, begin, end); + cudaEventDestroy(begin); + cudaEventDestroy(end); + cudaStreamDestroy(s); + return time / iter; +} + +template +void exec_cutlass_kernel( + void* scaled_act, Runner& runner, wo::Params& params, Config& config, char* ws, size_t ws_size, cudaStream_t stream) +{ + using AType = typename cutlassTypeMapper::AType; + static constexpr cutlass::WeightOnlyQuantOp QuantOp = cutlassTypeMapper::QuantOp; + void* act = params.act; + if (params.act_scale) { - printf("Int4b "); + tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher( + reinterpret_cast(scaled_act), reinterpret_cast(params.act), + reinterpret_cast(params.act_scale), params.m, params.k, stream); + act = scaled_act; } - if (group_size == 0) + if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) { - printf("PerChannel Weight Only\n"); + runner.gemm( + act, params.weight, params.scales, params.out, params.m, params.n, params.k, config, ws, ws_size, stream); } - else + else if (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) + { + runner.gemm(act, params.weight, params.scales, params.zeros, params.bias, params.out, params.m, params.n, + params.k, params.groupsize, config, ws, ws_size, stream); + } +} + +template +float run_cutlass_kernel(wo::Params& params, int warmup, int iter) +{ + int arch = tensorrt_llm::common::getSMVersion(); + simple_assert(KT == params.type); + simple_assert(wo::is_supported(arch, params.type)); + using AType = typename cutlassTypeMapper::AType; + using WType = typename cutlassTypeMapper::WType; + CudaBuffer scaled_act(params.m * params.k * sizeof(AType)); + auto runner = std::make_shared::QuantOp>>(); + auto& gemm = *runner; + cudaStream_t s; + cudaStreamCreate(&s); + cudaEvent_t begin, end; + cudaEventCreate(&begin); + cudaEventCreate(&end); + auto configs = get_configs(gemm, params.k); + int ws_bytes = gemm.getWorkspaceSize(params.m, params.n, params.k); + char* ws_ptr = nullptr; + if (ws_bytes) + cudaMalloc(&ws_ptr, ws_bytes); + float fast_time = 1e8; + auto best_config = configs[0]; + for (auto& config : configs) + { + for (int i = 0; i < 2; ++i) + { + exec_cutlass_kernel(scaled_act.data(), gemm, params, config, ws_ptr, ws_bytes, s); + } + cudaEventRecord(begin, s); + for (int i = 0; i < 5; ++i) + { + exec_cutlass_kernel(scaled_act.data(), gemm, params, config, ws_ptr, 
ws_bytes, s); + } + cudaEventRecord(end, s); + cudaEventSynchronize(end); + float time; + cudaEventElapsedTime(&time, begin, end); + if (time < fast_time) + { + fast_time = time; + best_config = config; + } + } + + for (int i = 0; i < warmup; ++i) { - printf("GroupWise%d Weight Only\n", group_size); + exec_cutlass_kernel(scaled_act.data(), gemm, params, best_config, ws_ptr, ws_bytes, s); } - using AT = typename AType::CudaKernelAType; - using BT = typename BType::CudaKernelBType; - constexpr int elem_per_byte = BType::elemsPerByte; - CudaBuffer d_act(m * k * sizeof(AT)); - CudaBuffer d_weight(k * n * sizeof(uint8_t) / elem_per_byte); - CudaBuffer d_scales(n * k * sizeof(AT)); - CudaBuffer d_zeros(n * k * sizeof(AT)); - CudaBuffer d_bias(n * sizeof(AT)); - CudaBuffer d_out(m * n * sizeof(AT)); - std::vector h_act(m * k); + cudaEventRecord(begin, s); + for (int i = 0; i < iter; ++i) + { + exec_cutlass_kernel(scaled_act.data(), gemm, params, best_config, ws_ptr, ws_bytes, s); + } + if (ws_ptr) + cudaFree(ws_ptr); + cudaEventRecord(end, s); + cudaEventSynchronize(end); + float time; + cudaEventElapsedTime(&time, begin, end); + cudaEventDestroy(begin); + cudaEventDestroy(end); + cudaStreamDestroy(s); + return time / iter; +} + +template +bool benchmark_and_verify(int m, int n, int k, int groupsize, int warmup, int iter) +{ + std::srand(20240123); + simple_assert(m <= 4); + if constexpr (cutlassTypeMapper::QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) + { + simple_assert(groupsize == 0); + } + else if (cutlassTypeMapper::QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) + { + simple_assert(groupsize == 64 || groupsize == 128); + } + using AType = typename cutlassTypeMapper::AType; + using WType = typename cutlassTypeMapper::WType; + static constexpr int ASizeInBits = sizeof(AType) * 8; + static constexpr int WSizeInBits = cutlassTypeMapper::WSizeInBits; + int gs_factor = groupsize == 0 ? 
1 : groupsize; + printf("Kernel %s\n", cutlassTypeMapper::str(m, n, k, groupsize).c_str()); + + CudaBuffer d_act(m * k * ASizeInBits / 8); + CudaBuffer d_act_scale(k * ASizeInBits / 8); + CudaBuffer d_weight(k * n * WSizeInBits / 8); + CudaBuffer d_scales(n * k / gs_factor * ASizeInBits / 8); + CudaBuffer d_zeros(n * k / gs_factor * ASizeInBits / 8); + CudaBuffer d_bias(n * ASizeInBits / 8); + CudaBuffer d_out(m * n * ASizeInBits / 8); + std::vector h_act(m * k), h_act_scale(k); std::vector h_weight(k * n); - std::vector h_scales(n * k), h_zeros(n * k), h_bias(n); - std::vector h_out1(m * n), h_out2(m * n); + std::vector h_scales(n * k), h_zeros(n * k), h_bias(n); + std::vector h_out1(m * n), h_out2(m * n); random_fill(h_act, -1.f, 1.f); + random_fill(h_act_scale, -1.f, 1.f); random_fill(h_scales, -1.f, 1.f); + random_fill(h_zeros, -1.f, 1.f); + random_fill(h_bias, -1.f, 1.f); for (uint8_t& v : h_weight) { @@ -393,37 +342,31 @@ bool benchmark(int m, int n, int k, int group_size, int warmup, int iter) } d_act.copy_from(h_act.data()); + d_act_scale.copy_from(h_act_scale.data()); d_weight.copy_from(h_weight.data()); d_scales.copy_from(h_scales.data()); d_zeros.copy_from(h_zeros.data()); d_bias.copy_from(h_bias.data()); + void* p_act_scale = nullptr; void* p_zeros = nullptr; void* p_bias = nullptr; - if (group_size == 64 || group_size == 128) + + if (groupsize != 0) { p_zeros = d_zeros.data(); p_bias = d_bias.data(); + p_act_scale = d_act_scale.data(); } - + wo::Params params(d_act.data(), p_act_scale, d_weight.data(), d_scales.data(), p_zeros, p_bias, d_out.data(), 1.f, + m, n, k, groupsize, KT); float time1, time2; - std::function)> benchmark_func_cuda - = benchmark_perchannel; - std::function)> benchmark_func_cutlass - = benchmark_perchannel; - if (group_size != 0) - { - benchmark_func_cuda = benchmark_groupwise; - benchmark_func_cutlass = benchmark_groupwise; - } - time1 = benchmark_func_cuda(d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias, d_out.data(), m, n, k, - group_size, warmup, iter); + time1 = run_cuda_kernel(params, warmup, iter); d_out.copy_to(h_out1.data()); - time2 = benchmark_func_cutlass(d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias, d_out.data(), m, n, - k, group_size, warmup, iter); + time2 = run_cutlass_kernel(params, warmup, iter); d_out.copy_to(h_out2.data()); - float quant_scale = 1.f / (1 << (8 / elem_per_byte - 1)); - bool pass = compare(h_out1.data(), h_out2.data(), m * n, quant_scale); + float quant_scale = 1.f / (1 << (8 / WSizeInBits - 1)); + bool pass = compare(h_out1.data(), h_out2.data(), m * n, quant_scale); printf( "cuda kernel cost time %.6f, cutlass kernel cost time %.6f, cuda speedup %.3f\n", time1, time2, time2 / time1); return pass; @@ -431,37 +374,36 @@ bool benchmark(int m, int n, int k, int group_size, int warmup, int iter) TEST(Kernel, WeightOnly) { - // Will re-enable 90 later when sm90 cuda kernels are ready - if (tensorrt_llm::common::getSMVersion() >= 90) - { - return; - } + int const arch = tensorrt_llm::common::getSMVersion(); bool pass; int warmup = 10, iter = 30; - std::vector ms{1, 2, 4}; - std::vector ns{512, 1024, 2048, 4096}; - std::vector ks{512, 1024, 2048, 4096}; - std::vector gss{0, 64, 128}; + std::vector ms{1, 2, 3, 4}; + std::vector ns{2048, 4096}; + std::vector ks{2048, 4096}; for (auto m : ms) { for (auto n : ns) { for (auto k : ks) { - for (auto gs : gss) + pass = benchmark_and_verify(m, n, k, 0, warmup, iter); + EXPECT_TRUE(pass); + pass = benchmark_and_verify(m, n, k, 0, warmup, iter); + 
EXPECT_TRUE(pass); + if (arch >= 80) { - pass = benchmark( - m, n, k, gs, warmup, iter); + pass = benchmark_and_verify(m, n, k, 64, warmup, iter); EXPECT_TRUE(pass); - pass = benchmark( - m, n, k, gs, warmup, iter); + pass = benchmark_and_verify(m, n, k, 128, warmup, iter); EXPECT_TRUE(pass); #if defined(ENABLE_BF16) - pass = benchmark( - m, n, k, gs, warmup, iter); + pass = benchmark_and_verify(m, n, k, 64, warmup, iter); + EXPECT_TRUE(pass); + pass = benchmark_and_verify(m, n, k, 128, warmup, iter); + EXPECT_TRUE(pass); + pass = benchmark_and_verify(m, n, k, 0, warmup, iter); EXPECT_TRUE(pass); - pass = benchmark( - m, n, k, gs, warmup, iter); + pass = benchmark_and_verify(m, n, k, 0, warmup, iter); EXPECT_TRUE(pass); #endif } diff --git a/cpp/tests/layers/baseSamplingLayerTest.cpp b/cpp/tests/layers/baseSamplingLayerTest.cpp index 9b3fbbc6e..f2d6ea1e9 100644 --- a/cpp/tests/layers/baseSamplingLayerTest.cpp +++ b/cpp/tests/layers/baseSamplingLayerTest.cpp @@ -91,7 +91,8 @@ void BaseSamplingLayerTest::setup(uint64_t seed, SamplingParams const& params typename TopKSamplingLayer::SetupParams setupParams; setupParams.randomSeed = std::make_optional>({seed}); setupParams.runtime_top_k - = params.topKs.size() ? std::make_optional>(params.topKs) : std::nullopt; + = params.topKs.size() ? std::make_optional>(params.topKs) : std::nullopt; + std::cout << "topP size " << params.topPs.size() << std::endl; setupParams.runtime_top_p = params.topPs.size() ? std::make_optional>(params.topPs) : std::nullopt; setupParams.top_p_decay = params.decay.size() ? std::make_optional>(params.decay) : std::nullopt; diff --git a/cpp/tests/layers/baseSamplingLayerTest.h b/cpp/tests/layers/baseSamplingLayerTest.h index 82c4807ba..1300991e1 100644 --- a/cpp/tests/layers/baseSamplingLayerTest.h +++ b/cpp/tests/layers/baseSamplingLayerTest.h @@ -77,7 +77,7 @@ void computeProb(T* probs, T const* logits, int batchSize, int vocabSize) struct SamplingParams { - std::vector topKs; + std::vector topKs; std::vector topPs; std::vector temperatures; std::vector repetitionPenalties; diff --git a/cpp/tests/layers/dynamicDecodeLayerTest.cpp b/cpp/tests/layers/dynamicDecodeLayerTest.cpp index f62570476..b71041cc1 100644 --- a/cpp/tests/layers/dynamicDecodeLayerTest.cpp +++ b/cpp/tests/layers/dynamicDecodeLayerTest.cpp @@ -116,49 +116,59 @@ void DynamicDecodeLayerTest::SetUp() int device; cudaGetDevice(&device); cudaGetDeviceProperties(&mDeviceProp, device); - - auto const decodingMode = mBeamWidth == 1 ? 
DecodingMode::TopKTopP() : DecodingMode::BeamSearch(); - - mDecodeLayer = std::make_shared>(decodingMode, mMaxBatchSize, - mBeamWidth, mVocabSize, mVocabSizePadded, mStream->get(), mAllocator, &mDeviceProp); } template -void DynamicDecodeLayerTest::setup(uint64_t seed, SamplingParams const& params) +void DynamicDecodeLayerTest::allocateData(SamplingParams const& params) { - auto const dataType = TRTDataType::value; + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - // clang-format off + auto const decodingMode = [this]() + { + if (this->mBeamWidth == 1) + { + if (this->mUseMedusa) + { + return DecodingMode::Medusa(); + } + else + { + return DecodingMode::TopKTopP(); + } + } + else + { + return DecodingMode::BeamSearch(); + } + }(); - // prob = (0.0, 0.0, 0.0, 0.0, 0.4, 0.3, 0.2, 0.1, 0.0) - mTestLogitsInit = { - -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // step 0 - -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 1 - -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 2 - -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX // step 3 - }; + mDecodeLayer = std::make_shared>(decodingMode, mMaxBatchSize, + mBeamWidth, mVocabSize, mVocabSizePadded, mStream->get(), mAllocator, &mDeviceProp, mMaxTokensPerStep, + params.maxNumMedusaHeads); - // clang-format on + auto const dataType = TRTDataType::value; - mLogitsDevice = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mBeamWidth, mVocabSizePadded}), dataType); + mLogitsDevice = mBufferManager->gpu( + ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mBeamWidth, mVocabSizePadded}), dataType); mRuntimeLogitsHost - = mBufferManager->pinned(ITensor::makeShape({mBatchSize, mBeamWidth, mVocabSizePadded}), dataType); + = BufferManager::pinned(ITensor::makeShape({mBatchSize, mBeamWidth, mVocabSizePadded}), dataType); mSeqLengthsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mContextLengthDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mFinishedDevice = mBufferManager->gpu( ITensor::makeShape({mMaxBatchSize}), TRTDataType::value); - mFinishedSumDevice = mBufferManager->pinned(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT); + mFinishedSumDevice = BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT); mOutputIdsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mBeamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32); - mNewTokens = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mNewTokens + = BufferManager::pinned(ITensor::makeShape({mMaxTokensPerStep, mMaxBatchSize}), nvinfer1::DataType::kINT32); mEndIdsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mEmbeddingBiasHost = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize, mVocabSizePadded}), dataType); + mEmbeddingBiasHost = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, mVocabSizePadded}), dataType); mEmbeddingBiasDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mVocabSizePadded}), dataType); mRefLogProbsHost - = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kFLOAT); mOutputLogProbsDevice = 
mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kFLOAT); mOutputLogProbsTiledDevice @@ -170,29 +180,63 @@ void DynamicDecodeLayerTest::setup(uint64_t seed, SamplingParams const& param mMaxStopWordsLen = getMaxWordsLen(params.stopWords); mBadWords - = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize, 2, mMaxBadWordsLen}), nvinfer1::DataType::kINT32); - mBadWordsLens = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mBadWordsPtrs = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT64); + = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, 2, mMaxBadWordsLen}), nvinfer1::DataType::kINT32); + mBadWordsLens = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mBadWordsPtrs = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT64); mStopWords - = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize, 2, mMaxStopWordsLen}), nvinfer1::DataType::kINT32); - mStopWordsLens = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mStopWordsPtrs = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT64); + = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, 2, mMaxStopWordsLen}), nvinfer1::DataType::kINT32); + mStopWordsLens = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mStopWordsPtrs = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT64); + + mBatchSlots = BufferManager::pinned(ITensor::makeShape({mBatchSize}), nvinfer1::DataType::kINT32); - mBatchSlots = mBufferManager->pinned(ITensor::makeShape({mBatchSize}), nvinfer1::DataType::kINT32); + if (mUseMedusa) + { + auto const maxMedusaHeads = params.maxNumMedusaHeads.value(); + mPathsDevice = mBufferManager->gpu( + ITensor::makeShape({mMaxBatchSize, mMaxTokensPerStep, maxMedusaHeads + 1}), nvinfer1::DataType::kINT32); + mAcceptedLengths = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mMedusaLogitsDevice = BufferManager::pinned( + ITensor::makeShape({maxMedusaHeads, mMaxBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mNextDraftTokensDevice + = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mMaxTokensPerStep}), nvinfer1::DataType::kINT32); + } + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void DynamicDecodeLayerTest::setup(uint64_t seed, SamplingParams const& params) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto const dataType = TRTDataType::value; + + // clang-format off + + // prob = (0.0, 0.0, 0.0, 0.0, 0.4, 0.3, 0.2, 0.1, 0.0) + mTestLogitsInit = { + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // step 0 + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 1 + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 2 + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX // step 3 + }; - trk::invokeFill(*mSeqLengthsDevice, int32_t{0}, *mStream); - trk::invokeFill(*mContextLengthDevice, int32_t{0}, *mStream); + // clang-format on + + trk::invokeFill(*mSeqLengthsDevice, SizeType{0}, *mStream); + trk::invokeFill(*mContextLengthDevice, SizeType{0}, *mStream); trk::invokeFill(*mFinishedDevice, uint8_t{0}, *mStream); - 
trk::invokeFill(*mOutputIdsDevice, int32_t{0}, *mStream); + trk::invokeFill(*mOutputIdsDevice, TokenIdType{0}, *mStream); trk::invokeFill(*mEmbeddingBiasDevice, T{0.0f}, *mStream); trk::invokeFill(*mCumLogProbsDevice, float{0.0f}, *mStream); trk::invokeFill(*mOutputLogProbsDevice, float{0.0f}, *mStream); trk::invokeFill(*mOutputLogProbsTiledDevice, float{0.0f}, *mStream); trk::invokeFill(*mRefLogProbsHost, float{0.0f}, *mStream); - trk::invokeFill(*mEndIdsDevice, int32_t{mEndId}, *mStream); + trk::invokeFill(*mEndIdsDevice, TokenIdType{mEndId}, *mStream); - auto batchSlotsPtr = bufferCast(*mBatchSlots); + auto batchSlotsPtr = bufferCast(*mBatchSlots); for (SizeType bi = 0; bi < mBatchSize; ++bi) { batchSlotsPtr[bi] = 2 * bi; @@ -218,12 +262,55 @@ void DynamicDecodeLayerTest::setup(uint64_t seed, SamplingParams const& param mLogitsVec[bi] = tcc::toTllmTensor(*logitsSlice); } + if (mUseMedusa) + { + auto const maxMedusaHeads = params.maxNumMedusaHeads.value(); + + trk::invokeFill(*mPathsDevice, SizeType{-1}, *mStream); + trk::invokeFill(*mAcceptedLengths, SizeType{0}, *mStream); + trk::invokeFill(*mNextDraftTokensDevice, TokenIdType{mEndId}, *mStream); + + auto const logitsHost + = ITensor::wrap(mTestLogitsInit, ITensor::makeShape({mMaxTokensPerStep, mVocabSizePadded})); + for (SizeType hi = 0; hi < maxMedusaHeads; ++hi) + { + TensorPtr logitsHeadDeviceView = ITensor::slice(mMedusaLogitsDevice, hi, 1); + logitsHeadDeviceView->squeeze(0); + for (SizeType bi = 0; bi < mBatchSize; ++bi) + { + TensorPtr logitsHeadBatchDeviceView = ITensor::slice(logitsHeadDeviceView, bi, 1); + mBufferManager->copy(*logitsHost, *logitsHeadBatchDeviceView); + } + } + + auto paths = params.paths.value(); + for (SizeType bi = 0; bi < mBatchSize; ++bi) + { + auto const numPaths = static_cast(paths[bi].size() / (maxMedusaHeads + 1)); + auto const pathsHost = ITensor::wrap(paths[bi], ITensor::makeShape({1, numPaths, maxMedusaHeads + 1})); + TensorPtr pathsDeviceSlice = ITensor::slice(mPathsDevice, batchSlotsPtr[bi], 1); + pathsDeviceSlice->squeeze(0); + TensorPtr pathsNumPathsDeviceSlice = ITensor::slice(pathsDeviceSlice, 0, numPaths); + pathsNumPathsDeviceSlice->unsqueeze(0); + mBufferManager->copy(*pathsHost, *pathsNumPathsDeviceSlice); + } + + auto outputIds = params.outputIds.value(); + for (SizeType bi = 0; bi < mBatchSize; ++bi) + { + auto const outputIdsBatchHost = ITensor::wrap(outputIds[bi], ITensor::makeShape({mMaxSeqLen})); + + auto outputIdsDevice = ITensor::slice(mOutputIdsDevice, batchSlotsPtr[bi], 1); + mBufferManager->copy(*outputIdsBatchHost, *outputIdsDevice); + } + } + typename DynamicDecodeLayer::SetupParams setupParams; setupParams.randomSeed = std::make_optional>({seed}); setupParams.temperature = params.temperatures.size() ? std::make_optional>(params.temperatures) : std::nullopt; setupParams.runtime_top_k - = params.topKs.size() ? std::make_optional>(params.topKs) : std::nullopt; + = params.topKs.size() ? std::make_optional>(params.topKs) : std::nullopt; setupParams.runtime_top_p = params.topPs.size() ? std::make_optional>(params.topPs) : std::nullopt; setupParams.repetition_penalty = params.repetitionPenalties.size() @@ -236,14 +323,17 @@ void DynamicDecodeLayerTest::setup(uint64_t seed, SamplingParams const& param ? std::make_optional>(params.frequencyPenalties) : std::nullopt; setupParams.min_length - = params.minLengths.size() ? std::make_optional>(params.minLengths) : std::nullopt; + = params.minLengths.size() ? 
std::make_optional>(params.minLengths) : std::nullopt; setupParams.top_p_decay = params.decay.size() ? std::make_optional>(params.decay) : std::nullopt; setupParams.top_p_min = params.minTopP.size() ? std::make_optional>(params.minTopP) : std::nullopt; setupParams.top_p_reset_ids - = params.topPResetIds.size() ? std::make_optional>(params.topPResetIds) : std::nullopt; + = params.topPResetIds.size() ? std::make_optional>(params.topPResetIds) : std::nullopt; setupParams.normalize_log_probs = {false}; + setupParams.topKMedusaHeads = params.topKMedusaHeads; + setupParams.tokensPerStep = params.tokensPerStep; + initXWordsTensors(batchSlotsPtr, bufferCast(*mBadWords), reinterpret_cast(bufferCast(*mBadWordsPtrs)), bufferCast(*mBadWordsLens), mMaxBadWordsLen, params.badWords); @@ -254,6 +344,8 @@ void DynamicDecodeLayerTest::setup(uint64_t seed, SamplingParams const& param mDecodeLayer->setup(mBatchSize, mBeamWidth, batchSlotsPtr, setupParams); mStream->synchronize(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template @@ -310,9 +402,11 @@ void DynamicDecodeLayerTest::initXWordsTensors(SizeType* batchSlotsPtr, SizeT } template -typename DynamicDecodeLayer::ForwardParams DynamicDecodeLayerTest::createInputTensors(int32_t step) +typename DynamicDecodeLayer::ForwardParams DynamicDecodeLayerTest::createInputTensors(SizeType step) { - constexpr int32_t ite = 0; + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + constexpr SizeType ite = 0; typename DynamicDecodeLayer::ForwardParams forwardParams( step, ite, mMaxInputLen, mMaxSeqLen, mSinkTokenLength, mBatchSize, tcc::toTllmTensor(*mEndIdsDevice)); @@ -339,6 +433,12 @@ typename DynamicDecodeLayer::ForwardParams DynamicDecodeLayerTest::createI forwardParams.stop_words_lengths = tcc::toTllmTensor(*mStopWordsLens); forwardParams.max_stop_words_len = mMaxStopWordsLen; + if (mUseMedusa) + { + forwardParams.paths = tcc::toTllmTensor(*mPathsDevice); + forwardParams.medusaLogits = tcc::toTllmTensor(*mMedusaLogitsDevice); + } + // TODO(nkorobov): extend to // std::optional src_cache_indirection; // std::optional sequence_limit_length; @@ -346,12 +446,16 @@ typename DynamicDecodeLayer::ForwardParams DynamicDecodeLayerTest::createI // std::optional no_repeat_ngram_size; // std::optional> logits_vec; + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return forwardParams; } template typename DynamicDecodeLayer::OutputParams DynamicDecodeLayerTest::createOutputTensors() { + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + typename DynamicDecodeLayer::OutputParams outputParams(tcc::toTllmTensor(*mOutputIdsDevice)); outputParams.sequence_length = tcc::toTllmTensor(*mSeqLengthsDevice); @@ -360,55 +464,69 @@ typename DynamicDecodeLayer::OutputParams DynamicDecodeLayerTest::createOu outputParams.finished_sum = tcc::toTllmTensor(*mFinishedSumDevice); - outputParams.cum_log_probs = tcc::toTllmTensor(*mCumLogProbsDevice); - outputParams.newTokens = tcc::toTllmTensor(*mNewTokens); - outputParams.output_log_probs = tcc::toTllmTensor(*mOutputLogProbsDevice); + if (!mUseMedusa) + { + // Output log probs are not supported in Medusa + outputParams.cum_log_probs = tcc::toTllmTensor(*mCumLogProbsDevice); + + outputParams.output_log_probs = tcc::toTllmTensor(*mOutputLogProbsDevice); - outputParams.output_log_probs_tiled = tcc::toTllmTensor(*mOutputLogProbsTiledDevice); + outputParams.output_log_probs_tiled = tcc::toTllmTensor(*mOutputLogProbsTiledDevice); + } + + if (mUseMedusa) + { + outputParams.nextDraftTokens = tcc::toTllmTensor(*mNextDraftTokensDevice); + + 
outputParams.acceptedLengths = tcc::toTllmTensor(*mAcceptedLengths); + } // TODO(nkorobov): extend to // std::optional parent_ids; // std::optional tgt_cache_indirection; // std::shared_ptr beamHypotheses; + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return outputParams; } template -void DynamicDecodeLayerTest::batchCopy(int32_t step) +void DynamicDecodeLayerTest::batchCopy(SizeType step) { auto const logitsHost = ITensor::wrap(mTestLogitsInit.data() + step * mVocabSizePadded, std::is_same_v ? nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kHALF, - ITensor::makeShape({1, mVocabSizePadded})); - for (int32_t bi = 0; bi < mBatchSize; ++bi) + ITensor::makeShape({mMaxTokensPerStep, mVocabSizePadded})); + for (SizeType bi = 0; bi < mBatchSize; ++bi) { - auto logitsDeviceView = ITensor::slice(mLogitsDevice, bi, 1); + TensorPtr logitsDeviceView = ITensor::slice(mLogitsDevice, bi, 1); + logitsDeviceView->squeeze(0); mBufferManager->copy(*logitsHost, *logitsDeviceView); } mLogitsRefHost = mBufferManager->copyFrom(*mLogitsDevice, tensorrt_llm::runtime::MemoryType::kCPU); } template -bool DynamicDecodeLayerTest::checkResult(int32_t* outputIds, std::vector> const& expectedIds, - int32_t* seqLens, int32_t leadingDim, int32_t stride, int32_t step) +bool DynamicDecodeLayerTest::checkResult(TokenIdType* outputIds, + std::vector> const& expectedIds, SizeType* seqLens, SizeType leadingDim, SizeType stride, + SizeType step, bool outputIdsTransposed, SizeType strideTransposed) { - assert(expectedIds.size() == leadingDim * stride); - int failures = 0; - auto const batchSlotsPtr = bufferCast(*mBatchSlots); - for (int32_t i = 0; i < leadingDim * stride; ++i) + SizeType failures = 0; + auto const batchSlotsPtr = bufferCast(*mBatchSlots); + for (SizeType i = 0; i < leadingDim * stride; ++i) { - int32_t s = i / stride; - int32_t b = i % stride; + auto const s = i / stride; + auto const b = i % stride; auto const batchSlot = batchSlotsPtr[b]; - if (seqLens[batchSlot] <= step) + if (seqLens[batchSlot] <= step + s) { continue; } - std::set expts = expectedIds.at(i + step * stride); - auto bid = batchSlot; - auto const outputId = outputIds[bid * leadingDim + s]; + auto const& expts = expectedIds.at(i + step * stride); + auto const outputIdIdx = outputIdsTransposed ? s * strideTransposed + batchSlot : batchSlot * leadingDim + s; + auto const outputId = outputIds[outputIdIdx]; if (expts.count(outputId) == 0) { if (failures < 10) @@ -417,7 +535,7 @@ bool DynamicDecodeLayerTest::checkResult(int32_t* outputIds, std::vector::checkResult(int32_t* outputIds, std::vector void DynamicDecodeLayerTest::fillRefLogits( - int32_t const* seqLenHost, std::vector> const& expectedOutputIds, SizeType step) + SizeType const* seqLenHost, std::vector> const& expectedOutputIds, SizeType step) { - auto const batchSlotsPtr = bufferCast(*mBatchSlots); + auto const batchSlotsPtr = bufferCast(*mBatchSlots); auto const runtimeLogitsHost = bufferCast(*mRuntimeLogitsHost); for (SizeType bi = 0; bi < mBatchBeam; ++bi) { @@ -454,9 +572,15 @@ void DynamicDecodeLayerTest::fillRefLogits( template void DynamicDecodeLayerTest::runTestImpl( - std::vector> const& expectedOutputIds, SamplingParams const& params, int32_t endId) + std::vector> const& expectedOutputIds, SamplingParams const& params, TokenIdType endId) { + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + mEndId = endId == -1 ? mVocabSize - 1 : endId; + mUseMedusa = params.useMedusa; + mMaxTokensPerStep = mUseMedusa ? 
mMaxOutputLen - mMaxInputLen : 1; + + allocateData(params); bool greedySearch = std::all_of(expectedOutputIds.begin(), expectedOutputIds.end(), [](auto v) { return v.size() == 1; }); @@ -464,11 +588,11 @@ void DynamicDecodeLayerTest::runTestImpl( { setup(seed, params); - int32_t step = mMaxInputLen; + auto step = mMaxInputLen; auto inputTensors = createInputTensors(step); auto outputTensors = createOutputTensors(); - for (step = mMaxInputLen; step < mMaxOutputLen; ++step) + for (step = mMaxInputLen; step < mMaxOutputLen; step += mMaxTokensPerStep) { // Reset by the test value since the sampling layer internally update the logit buffer. batchCopy(step); @@ -483,14 +607,15 @@ void DynamicDecodeLayerTest::runTestImpl( mDecodeLayer->getRuntimeLogitsDevice(), *mRuntimeLogitsHost, tensorrt_llm::runtime::MemoryType::kGPU); mStream->synchronize(); - if (greedySearch) + if (greedySearch && !mUseMedusa) { - fillRefLogits(bufferCast(*seqLenHost), expectedOutputIds, step); + fillRefLogits(bufferCast(*seqLenHost), expectedOutputIds, step); } { - bool passed = checkResult(bufferCast(*newTokensHost), expectedOutputIds, - bufferCast(*seqLenHost), 1, mBatchBeam, step); + auto const passed = checkResult(bufferCast(*newTokensHost), expectedOutputIds, + bufferCast(*seqLenHost), mMaxTokensPerStep, mBatchBeam, step, /* transposed */ true, + /* stride transposed */ mMaxBatchSize * mBeamWidth); EXPECT_TRUE(passed) << "New tokens check failed at seed " << seed; if (!passed) { @@ -502,22 +627,22 @@ void DynamicDecodeLayerTest::runTestImpl( // Check if logits were not modified in-place { - bool passed = compareValues(bufferCast(*mLogitsRefHost), bufferCast(*logitsHost), - mBatchSize * mBeamWidth * mVocabSizePadded); + auto const passed = compareValues(bufferCast(*mLogitsRefHost), bufferCast(*logitsHost), + mBatchSize * mMaxTokensPerStep * mBeamWidth * mVocabSizePadded); EXPECT_TRUE(passed) << "Unmodified logits check failed at seed " << seed; } } - mStream->synchronize(); - auto const outputIdsHost = mBufferManager->copyFrom(*mOutputIdsDevice, tensorrt_llm::runtime::MemoryType::kCPU); auto const seqLenHost = mBufferManager->copyFrom(*mSeqLengthsDevice, tensorrt_llm::runtime::MemoryType::kCPU); auto const logProbsHost = mBufferManager->copyFrom(*mOutputLogProbsDevice, tensorrt_llm::runtime::MemoryType::kCPU); + mStream->synchronize(); + { - bool passed = checkResult(bufferCast(*outputIdsHost), expectedOutputIds, - bufferCast(*seqLenHost), mMaxSeqLen, mBatchBeam, /* step */ 0); + auto const passed = checkResult(bufferCast(*outputIdsHost), expectedOutputIds, + bufferCast(*seqLenHost), mMaxSeqLen, mBatchBeam, /* step */ 0); EXPECT_TRUE(passed) << "Output Ids check failed at seed " << seed; if (!passed) { @@ -527,22 +652,27 @@ void DynamicDecodeLayerTest::runTestImpl( } } - if (greedySearch) + if (greedySearch && !mUseMedusa) { - bool passed = compareValues( + auto const passed = compareValues( bufferCast(*logProbsHost), bufferCast(*mRefLogProbsHost), mMaxSeqLen * mMaxBatchSize); EXPECT_TRUE(passed) << "Log probs check failed at seed " << seed; } } + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void DynamicDecodeLayerTest::runTest( - std::vector> const& expectedOutputIds, SamplingParams const& params, int32_t endId) + std::vector> const& expectedOutputIds, SamplingParams const& params, TokenIdType endId) { - TLLM_LOG_DEBUG("Run test with linear logits"); - mUseLogitsVec = false; - runTestImpl(expectedOutputIds, params, endId); + if (!params.useMedusa) + { + TLLM_LOG_DEBUG("Run test with linear logits"); + 
mUseLogitsVec = false; + runTestImpl(expectedOutputIds, params, endId); + } TLLM_LOG_DEBUG("Run test with vectorized logits"); mUseLogitsVec = true; runTestImpl(expectedOutputIds, params, endId); @@ -555,12 +685,12 @@ TYPED_TEST_SUITE(DynamicDecodeLayerTest, FloatAndHalfTypes); TYPED_TEST(DynamicDecodeLayerTest, TopK) { - uint32_t topK = 2; + SizeType topK = 2; float topP = 0.0f; SamplingParams params; params.topKs = {topK}; params.topPs = {topP}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4, 5}, {4, 5}, {4, 5}, {4, 5}, {4, 5}, {4, 5}, // step 0 {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}, // step 1 @@ -572,12 +702,12 @@ TYPED_TEST(DynamicDecodeLayerTest, TopK) TYPED_TEST(DynamicDecodeLayerTest, TopK1TopP0) { - uint32_t topK = 1; + SizeType topK = 1; float topP = 0.0f; SamplingParams params; params.topKs = {topK}; params.topPs = {topP}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -589,11 +719,11 @@ TYPED_TEST(DynamicDecodeLayerTest, TopK1TopP0) TYPED_TEST(DynamicDecodeLayerTest, BatchTopK) { - std::vector topKs = {2, 1, 1, 2, 1, 1}; + std::vector topKs = {2, 1, 1, 2, 1, 1}; SamplingParams params; params.topKs = topKs; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4, 5}, {4}, {4}, {4, 5}, {4}, {4}, // step 0 {0, 1}, {0}, {0}, {0, 1}, {0}, {0}, // step 1 @@ -605,12 +735,12 @@ TYPED_TEST(DynamicDecodeLayerTest, BatchTopK) TYPED_TEST(DynamicDecodeLayerTest, TopKTopP) { - uint32_t topK = 2; + SizeType topK = 2; float topP = 0.3; SamplingParams params; params.topKs = {topK}; params.topPs = {topP}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -622,12 +752,12 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKTopP) TYPED_TEST(DynamicDecodeLayerTest, BatchTopKTopP) { - std::vector topKs = {2, 2, 1, 2, 2, 1}; + std::vector topKs = {2, 2, 1, 2, 2, 1}; float topP = 0.3; SamplingParams params; params.topKs = topKs; params.topPs = {topP}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -639,12 +769,12 @@ TYPED_TEST(DynamicDecodeLayerTest, BatchTopKTopP) TYPED_TEST(DynamicDecodeLayerTest, TopKBatchTopP) { - uint32_t topK = 2; + SizeType topK = 2; std::vector topPs = {0.5, 0.3, 0.5, 0.5, 0.3, 0.5}; SamplingParams params; params.topKs = {topK}; params.topPs = topPs; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4, 5}, {4}, {4, 5}, {4, 5}, {4}, {4, 5}, // step 0 {0, 1}, {0}, {0, 1}, {0, 1}, {0}, {0, 1}, // step 1 @@ -656,12 +786,12 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKBatchTopP) TYPED_TEST(DynamicDecodeLayerTest, BatchTopKBatchTopP) { - std::vector topKs = {2, 2, 0, 2, 2, 1}; + std::vector topKs = {2, 2, 0, 2, 2, 1}; std::vector topPs = {0.0, 0.3, 0.5, 0.0, 0.3, 0.5}; SamplingParams params; params.topKs = topKs; params.topPs = topPs; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4, 5}, {4}, {4, 5}, {4, 5}, {4}, {4}, // step 0 {0, 1}, {0}, {0, 1}, {0, 1}, {0}, {0}, // step 1 @@ -673,10 +803,10 @@ TYPED_TEST(DynamicDecodeLayerTest, BatchTopKBatchTopP) TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsZeroTopK) { - uint32_t topK = 0; + SizeType topK = 0; SamplingParams params; params.topKs = {topK}; - std::vector> 
expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -691,7 +821,7 @@ TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsZeroTopP) float topP = 0; SamplingParams params; params.topPs = {topP}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -703,12 +833,12 @@ TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsZeroTopP) TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsZeroTopKTopP) { - uint32_t topK = 0; + SizeType topK = 0; float topP = 0; SamplingParams params; params.topPs = {topP}; params.topKs = {topK}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -720,12 +850,12 @@ TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsZeroTopKTopP) TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsZeroBatchTopKTopP) { - std::vector topKs = {0, 0, 0, 0, 0, 0}; + std::vector topKs = {0, 0, 0, 0, 0, 0}; float topP = 0; SamplingParams params; params.topPs = {topP}; params.topKs = topKs; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -737,12 +867,12 @@ TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsZeroBatchTopKTopP) TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsZeroTopKBatchTopP) { - uint32_t topK = 0; + SizeType topK = 0; std::vector topPs = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; SamplingParams params; params.topPs = topPs; params.topKs = {topK}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -754,10 +884,10 @@ TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsZeroTopKBatchTopP) TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsBatchTopKContainZero) { - std::vector topKs = {2, 1, 0, 0, 2, 1}; + std::vector topKs = {2, 1, 0, 0, 2, 1}; SamplingParams params; params.topKs = topKs; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4, 5}, {4}, {4}, {4}, {4, 5}, {4}, // step 0 {0, 1}, {0}, {0}, {0}, {0, 1}, {0}, // step 1 @@ -769,12 +899,12 @@ TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsBatchTopKContainZero) TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsBatchTopKTopPContainZero) { - std::vector topKs = {2, 2, 1, 0, 2, 0}; + std::vector topKs = {2, 2, 1, 0, 2, 0}; float topP = 0.0; SamplingParams params; params.topPs = {topP}; params.topKs = topKs; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4, 5}, {4, 5}, {4}, {4}, {4, 5}, {4}, // step 0 {0, 1}, {0, 1}, {0}, {0}, {0, 1}, {0}, // step 1 @@ -786,12 +916,12 @@ TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsBatchTopKTopPContainZero) TYPED_TEST(DynamicDecodeLayerTest, InvalidArgsBatchTopKBatchTopPContainZero) { - std::vector topKs = {0, 2, 1, 2, 2, 0}; + std::vector topKs = {0, 2, 1, 2, 2, 0}; std::vector topPs = {0.0, 0.3, 0.9, 0.0, 0.3, 0.5}; SamplingParams params; params.topPs = topPs; params.topKs = topKs; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4, 5}, {4}, {4, 5}, // step 0 {0}, {0}, {0}, {0, 1}, {0}, {0, 1}, // step 1 @@ -807,7 +937,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPTemperature) SamplingParams params; params.temperatures = {temperature}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ {4}, 
{4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 {2}, {2}, {2}, {2}, {2}, {2}, // step 2 @@ -822,7 +952,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPTemperatureBatch) SamplingParams params; params.temperatures = temperatures; params.topPs = {0.5f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ {4}, {4, 5, 6, 7}, {4, 5}, {4, 5}, {4}, {4, 5}, // step 0 {0}, {0, 1, 2, 3}, {0, 1}, {0, 1}, {0}, {0, 1}, // step 1 {2}, {2, 3, 4, 5}, {2, 3}, {2, 3}, {2}, {2, 3}, // step 2 @@ -833,12 +963,12 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPTemperatureBatch) TYPED_TEST(DynamicDecodeLayerTest, TopPRepetitionPenalty) { - uint32_t topK = 1; + SizeType topK = 1; float repetitionPenalty = 1e9f; SamplingParams params; params.repetitionPenalties = {repetitionPenalty}; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -854,7 +984,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPRepetitionPenaltiesBatch) SamplingParams params; params.repetitionPenalties = repetitionPenalties; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -870,7 +1000,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPPresencePenalty) SamplingParams params; params.presencePenalties = {presencePenalty}; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -886,7 +1016,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPPresencePenaltiesBatch) SamplingParams params; params.presencePenalties = presencePenalties; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -902,7 +1032,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPFrequencyPenalty) SamplingParams params; params.frequencyPenalties = {frequencyPenalty}; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -918,7 +1048,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPFrequencyPenaltiesBatch) SamplingParams params; params.frequencyPenalties = frequencyPenalties; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -936,7 +1066,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPRepetitionPresencePenalty) params.repetitionPenalties = {repetitionPenalty}; params.presencePenalties = {presencePenalty}; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -954,7 +1084,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPRepetitionPresencePenaltiesBatch) params.repetitionPenalties = {repetitionPenalties}; params.presencePenalties = {presencePenalties}; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -972,7 +1102,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPRepetitionFrequencyPenalty) params.repetitionPenalties = {repetitionPenalty}; params.frequencyPenalties = 
{frequencyPenalty}; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -990,7 +1120,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPRepetitionFrequencyPenaltiesBatch) params.repetitionPenalties = {repetitionPenalties}; params.frequencyPenalties = {frequencyPenalties}; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1008,7 +1138,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPPresenceFrequencyPenalty) params.presencePenalties = {presencePenalty}; params.frequencyPenalties = {frequencyPenalty}; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1026,7 +1156,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPPresenceFrequencyPenaltiesBatch) params.presencePenalties = {presencePenalties}; params.frequencyPenalties = {frequencyPenalties}; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1046,7 +1176,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPFullPenalty) params.presencePenalties = {presencePenalty}; params.frequencyPenalties = {frequencyPenalty}; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1066,7 +1196,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPFullPenaltiesBatch) params.presencePenalties = {presencePenalties}; params.frequencyPenalties = {frequencyPenalties}; params.topPs = {0.3f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1078,12 +1208,12 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPFullPenaltiesBatch) TYPED_TEST(DynamicDecodeLayerTest, TopPMinLengthBatch) { - std::vector minLengths = {3, 1, 1, 3, 0, 3}; + std::vector minLengths = {3, 1, 1, 3, 0, 3}; SamplingParams params; params.minLengths = minLengths; params.topPs = {0.3f}; - int32_t const endId = 0; - std::vector> expectedOutputIds{ + TokenIdType const endId = 0; + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {1}, {0}, {0}, {1}, {0}, {1}, // step 1 @@ -1098,7 +1228,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPBias) SamplingParams params; params.topPs = {0.5f}; params.useBias = true; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4, 5}, {4, 5}, {4, 5}, {4, 5}, {4, 5}, {4, 5}, // step 0 {2, 3}, {2, 3}, {2, 3}, {2, 3}, {2, 3}, {2, 3}, // step 1 @@ -1110,13 +1240,13 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPBias) TYPED_TEST(DynamicDecodeLayerTest, TopKTemperature) { - uint32_t topK = 2; + SizeType topK = 2; float temperature = 0.05f; SamplingParams params; params.temperatures = {temperature}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 {2}, {2}, {2}, {2}, {2}, {2}, // step 2 @@ -1127,13 +1257,13 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKTemperature) TYPED_TEST(DynamicDecodeLayerTest, TopKTemperatureBatch) { - uint32_t topK = 2; + SizeType topK = 2; std::vector 
temperatures = {0.05f, 1e3f, 1.0f, 0.5f, 0.05f, 1.0f}; SamplingParams params; params.temperatures = temperatures; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ {4}, {4, 5, 6, 7}, {4, 5}, {4, 5}, {4}, {4, 5}, // step 0 {0}, {0, 1, 2, 3}, {0, 1}, {0, 1}, {0}, {0, 1}, // step 1 {2}, {2, 3, 4, 5}, {2, 3}, {2, 3}, {2}, {2, 3}, // step 2 @@ -1144,13 +1274,13 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKTemperatureBatch) TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionPenalty) { - uint32_t topK = 1; + SizeType topK = 1; float repetitionPenalty = 1e9f; SamplingParams params; params.repetitionPenalties = {repetitionPenalty}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1162,13 +1292,13 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionPenalty) TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionPenaltiesBatch) { - uint32_t topK = 1; + SizeType topK = 1; std::vector repetitionPenalties = {1e9f, 1e9f, 1.0f, 1.0f, 1.0f, 1e9f}; SamplingParams params; params.repetitionPenalties = repetitionPenalties; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1180,13 +1310,13 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionPenaltiesBatch) TYPED_TEST(DynamicDecodeLayerTest, TopKPresencePenalty) { - uint32_t topK = 1; + SizeType topK = 1; float presencePenalty = 1e9f; SamplingParams params; params.presencePenalties = {presencePenalty}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1198,13 +1328,13 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKPresencePenalty) TYPED_TEST(DynamicDecodeLayerTest, TopKPresencePenaltiesBatch) { - uint32_t topK = 1; + SizeType topK = 1; std::vector presencePenalties = {1e9f, 1e9f, 0.0f, 0.0f, 0.0f, 1e9f}; SamplingParams params; params.presencePenalties = presencePenalties; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1216,13 +1346,13 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKPresencePenaltiesBatch) TYPED_TEST(DynamicDecodeLayerTest, TopKFrequencyPenalty) { - uint32_t topK = 1; + SizeType topK = 1; float frequencyPenalty = 1e9f; SamplingParams params; params.frequencyPenalties = {frequencyPenalty}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1234,13 +1364,13 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKFrequencyPenalty) TYPED_TEST(DynamicDecodeLayerTest, TopKFrequencyPenaltiesBatch) { - uint32_t topK = 1; + SizeType topK = 1; std::vector frequencyPenalties = {1e9f, 1e9f, 0.0f, 0.0f, 0.0f, 1e9f}; SamplingParams params; params.frequencyPenalties = frequencyPenalties; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1252,7 +1382,7 @@ TYPED_TEST(DynamicDecodeLayerTest, 
TopKFrequencyPenaltiesBatch) TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionPresencePenalty) { - uint32_t topK = 1; + SizeType topK = 1; float repetitionPenalty = 1e9f; float presencePenalty = 1e9f; SamplingParams params; @@ -1260,7 +1390,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionPresencePenalty) params.presencePenalties = {presencePenalty}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1272,7 +1402,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionPresencePenalty) TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionPresencePenaltiesBatch) { - uint32_t topK = 1; + SizeType topK = 1; std::vector repetitionPenalties = {1e9f, 1e9f, 1.0f, 1.0f, 1.0f, 1e9f}; std::vector presencePenalties = {1e9f, 1e9f, 0.0f, 0.0f, 0.0f, 1e9f}; SamplingParams params; @@ -1280,7 +1410,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionPresencePenaltiesBatch) params.presencePenalties = {presencePenalties}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1292,7 +1422,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionPresencePenaltiesBatch) TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionFrequencyPenalty) { - uint32_t topK = 1; + SizeType topK = 1; float repetitionPenalty = 1e9f; float frequencyPenalty = 1e9f; SamplingParams params; @@ -1300,7 +1430,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionFrequencyPenalty) params.frequencyPenalties = {frequencyPenalty}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1312,7 +1442,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionFrequencyPenalty) TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionFrequencyPenaltiesBatch) { - uint32_t topK = 1; + SizeType topK = 1; std::vector repetitionPenalties = {1e9f, 1e9f, 1.0f, 1.0f, 1.0f, 1e9f}; std::vector frequencyPenalties = {1e9f, 1e9f, 0.0f, 0.0f, 0.0f, 1e9f}; SamplingParams params; @@ -1320,7 +1450,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionFrequencyPenaltiesBatch) params.frequencyPenalties = {frequencyPenalties}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1332,7 +1462,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKRepetitionFrequencyPenaltiesBatch) TYPED_TEST(DynamicDecodeLayerTest, TopKPresenceFrequencyPenalty) { - uint32_t topK = 1; + SizeType topK = 1; float presencePenalty = 1e9f; float frequencyPenalty = 1e9f; SamplingParams params; @@ -1340,7 +1470,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKPresenceFrequencyPenalty) params.frequencyPenalties = {frequencyPenalty}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1352,7 +1482,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKPresenceFrequencyPenalty) TYPED_TEST(DynamicDecodeLayerTest, TopKPresenceFrequencyPenaltiesBatch) { - uint32_t topK = 1; + SizeType topK = 1; std::vector presencePenalties = {1e9f, 1e9f, 0.0f, 0.0f, 0.0f, 1e9f}; std::vector 
frequencyPenalties = {1e9f, 1e9f, 0.0f, 0.0f, 0.0f, 1e9f}; SamplingParams params; @@ -1360,7 +1490,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKPresenceFrequencyPenaltiesBatch) params.frequencyPenalties = {frequencyPenalties}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1372,7 +1502,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKPresenceFrequencyPenaltiesBatch) TYPED_TEST(DynamicDecodeLayerTest, TopKFullPenalty) { - uint32_t topK = 1; + SizeType topK = 1; float repetitionPenalty = 1e9f; float presencePenalty = 1e9f; float frequencyPenalty = 1e9f; @@ -1382,7 +1512,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKFullPenalty) params.frequencyPenalties = {frequencyPenalty}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1394,7 +1524,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKFullPenalty) TYPED_TEST(DynamicDecodeLayerTest, TopKFullPenaltiesBatch) { - uint32_t topK = 1; + SizeType topK = 1; std::vector repetitionPenalties = {1e9f, 1e9f, 1.0f, 1.0f, 1.0f, 1e9f}; std::vector presencePenalties = {1e9f, 1e9f, 0.0f, 0.0f, 0.0f, 1e9f}; std::vector frequencyPenalties = {1e9f, 1e9f, 0.0f, 0.0f, 0.0f, 1e9f}; @@ -1404,7 +1534,7 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKFullPenaltiesBatch) params.frequencyPenalties = {frequencyPenalties}; params.topKs = {topK}; params.topPs = {1.0f}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1416,14 +1546,14 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKFullPenaltiesBatch) TYPED_TEST(DynamicDecodeLayerTest, TopKMinLengthBatch) { - uint32_t topK = 1; - std::vector minLengths = {3, 1, 1, 3, 0, 3}; + SizeType topK = 1; + std::vector minLengths = {3, 1, 1, 3, 0, 3}; SamplingParams params; params.minLengths = minLengths; params.topKs = {topK}; params.topPs = {1.0f}; - int32_t const endId = 0; - std::vector> expectedOutputIds{ + TokenIdType const endId = 0; + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {1}, {0}, {0}, {1}, {0}, {1}, // step 1 @@ -1435,12 +1565,12 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKMinLengthBatch) TYPED_TEST(DynamicDecodeLayerTest, TopKBias) { - uint32_t topK = 2; + SizeType topK = 2; SamplingParams params; params.topKs = {topK}; params.topPs = {1.0f}; params.useBias = true; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4, 5}, {4, 5}, {4, 5}, {4, 5}, {4, 5}, {4, 5}, // step 0 {2, 3}, {2, 3}, {2, 3}, {2, 3}, {2, 3}, {2, 3}, // step 1 @@ -1452,12 +1582,12 @@ TYPED_TEST(DynamicDecodeLayerTest, TopKBias) TYPED_TEST(DynamicDecodeLayerTest, BadWords) { - uint32_t topK = 1; + SizeType topK = 1; SamplingParams params; params.topKs = {topK}; params.topPs = {1.0f}; params.badWords = {{{4, 0}, {2}}, {{0, 2}}, {{4, 0, 2}, {4, 0, 3, 0}}, {{3}}, {{4}, {5}}, {{0}, {3}}}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {6}, {4}, // step 0 {1}, {0}, {0}, {0}, {0}, {1}, // step 1 @@ -1469,12 +1599,12 @@ TYPED_TEST(DynamicDecodeLayerTest, BadWords) TYPED_TEST(DynamicDecodeLayerTest, StopWords) { - uint32_t topK = 1; + SizeType topK = 1; SamplingParams params; params.topKs = {topK}; params.topPs = {1.0f}; params.stopWords = {{{4, 0}, 
{2}}, {{0, 2}}, {{4, 0, 2}}, {{3}}, {{4}, {5}}, {{4, 0, 2, 0}}}; - std::vector> expectedOutputIds{ + std::vector> expectedOutputIds{ // batch {4}, {4}, {4}, {4}, {4}, {4}, // step 0 {0}, {0}, {0}, {0}, {0}, {0}, // step 1 @@ -1483,4 +1613,34 @@ TYPED_TEST(DynamicDecodeLayerTest, StopWords) }; this->runTest(expectedOutputIds, params); } + +TYPED_TEST(DynamicDecodeLayerTest, MedusaSimpleTest) +{ + SamplingParams params; + params.topKs = {1, 1, 1, 1, 1, 1}; + params.topKMedusaHeads = {{3, 1}, {1, 3}, {3, 1}, {2, 2}, {2, 2}, {1, 3}}; + params.tokensPerStep = {4, 4, 4, 4, 4, 4}; + params.maxNumMedusaHeads = 2; + // clang-format off + params.paths = {{0, 1, 2, + 0, 3, -1}, + {0, 1, -1, + 0, -1, -1}, + {0, 1, 3}, + {0, 2, 3}, + {0, 2, -1}, + {0, 3, -1}}; + // clang-format on + params.outputIds = {{0, 4, 0, 2}, {0, 4, 0, 2}, {0, 4, 0, 0}, {0, 4, 4, 2}, {0, 4, 0, 2}, {0, 4, 0, 2}}; + params.useMedusa = true; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {2}, {4}, {4}, // step 1 + {2}, {0}, {0}, {0}, {0}, {0}, // step 2 + {2}, {2}, {0}, {2}, {2}, {2} // step 3 + }; + this->runTest(expectedOutputIds, params); +} + } // namespace tensorrt_llm::tests::layers::sampling diff --git a/cpp/tests/layers/dynamicDecodeLayerTest.h b/cpp/tests/layers/dynamicDecodeLayerTest.h index a82e81f0f..9ec49fe3f 100644 --- a/cpp/tests/layers/dynamicDecodeLayerTest.h +++ b/cpp/tests/layers/dynamicDecodeLayerTest.h @@ -27,6 +27,7 @@ #include "tensorrt_llm/kernels/samplingTopKKernels.h" #include "tensorrt_llm/kernels/samplingTopPKernels.h" #include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/cudaStream.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/tllmLogger.h" @@ -40,19 +41,27 @@ namespace tensorrt_llm::tests::layers::sampling struct SamplingParams { - std::vector topKs; + std::vector topKs; std::vector topPs; std::vector temperatures; std::vector repetitionPenalties; std::vector presencePenalties; std::vector frequencyPenalties; - std::vector minLengths; + std::vector minLengths; std::vector decay; std::vector minTopP; - std::vector topPResetIds; - std::vector>> badWords; - std::vector>> stopWords; - bool useBias = false; + std::vector topPResetIds; + std::vector>> badWords; + std::vector>> stopWords; + bool useBias{false}; + + // Medusa setup + bool useMedusa{false}; + std::optional maxNumMedusaHeads{std::nullopt}; + std::optional>> topKMedusaHeads{std::nullopt}; + std::optional> tokensPerStep{std::nullopt}; + std::optional>> paths; + std::optional>> outputIds; }; template @@ -64,22 +73,22 @@ class DynamicDecodeLayerTest : public testing::Test using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr; using BufferPtr = tensorrt_llm::runtime::IBuffer::SharedPtr; - int32_t seed = 0; - const static uint64_t mMaxSeed = 32; - int32_t const mBatchSize = 6; - int32_t const mMaxBatchSize = 2 * mBatchSize; - int32_t const mBeamWidth = 1; - int32_t const mBatchBeam = mBatchSize * mBeamWidth; - int32_t const mVocabSize = 9; - int32_t const mVocabSizePadded = mVocabSize; + static const uint64_t mMaxSeed{32}; + runtime::SizeType const mBatchSize{6}; + runtime::SizeType const mMaxBatchSize{2 * mBatchSize}; + runtime::SizeType const mBeamWidth{1}; + runtime::SizeType const mBatchBeam{mBatchSize * mBeamWidth}; + runtime::SizeType const mVocabSize{9}; + runtime::SizeType const mVocabSizePadded{mVocabSize}; - int32_t const mMaxInputLen = 0; // has no effect. 
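Throughout these test hunks, raw `uint32_t`/`int32_t` values are being replaced with the runtime's `SizeType` and `TokenIdType` aliases. As a point of reference, the sketch below shows what those aliases are assumed to resolve to (mirroring `tensorrt_llm/runtime/common.h` at this revision; the definitions here are an assumption, not part of the patch). Both are 32-bit signed integers today, so the migration documents intent, counts and indices versus vocabulary token ids, rather than changing layout.

```
#include <cstdint>
#include <set>
#include <type_traits>
#include <vector>

// Assumed to mirror tensorrt_llm/runtime/common.h at this revision: both aliases are
// 32-bit signed integers, so only the stated meaning of a value changes.
namespace tensorrt_llm::runtime
{
using SizeType = std::int32_t;    // sizes, counts, top-k values, lengths, indices
using TokenIdType = std::int32_t; // vocabulary token ids (output ids, end ids, word lists)
} // namespace tensorrt_llm::runtime

int main()
{
    using namespace tensorrt_llm::runtime;
    static_assert(std::is_same_v<SizeType, TokenIdType>, "same width today; only the meaning differs");

    // The updated test containers read exactly as before, but the element type now
    // documents what the numbers are.
    std::vector<SizeType> const topKs{2, 1, 1, 2, 1, 1};
    TokenIdType const endId{0};
    std::vector<std::set<TokenIdType>> const expectedOutputIds{{4}, {4}, {4}, {4}, {4}, {4}};
    return (topKs.front() == 2 && expectedOutputIds.front().count(4) == 1 && endId == 0) ? 0 : 1;
}
```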
- int32_t const mMaxOutputLen = 4; - int32_t const mMaxSeqLen = mMaxInputLen + mMaxOutputLen; - int32_t const mSinkTokenLength = 0; - int32_t mEndId = mVocabSize; + runtime::SizeType const mMaxInputLen{0}; // has no effect. + runtime::SizeType const mMaxOutputLen{4}; + runtime::SizeType const mMaxSeqLen{mMaxInputLen + mMaxOutputLen}; + runtime::SizeType const mSinkTokenLength{0}; + runtime::TokenIdType mEndId = mVocabSize; + runtime::SizeType mMaxTokensPerStep{1}; - bool mUseLogitsVec = false; + bool mUseLogitsVec{false}; TensorPtr mLogitsDevice; TensorPtr mRuntimeLogitsHost; @@ -110,12 +119,16 @@ class DynamicDecodeLayerTest : public testing::Test TensorPtr mCumLogProbsDevice; + // Medusa tensors + TensorPtr mPathsDevice; + TensorPtr mAcceptedLengths; + TensorPtr mMedusaLogitsDevice; + TensorPtr mNextDraftTokensDevice; + std::vector mLogitsVec; struct cudaDeviceProp mDeviceProp; - const tensorrt_llm::common::DataType data_type = tensorrt_llm::common::getTensorType(); - // Order is important because we pass mAllocator to mDecodeLayer and it is used in destructor std::shared_ptr mStream; std::shared_ptr mBufferManager; @@ -124,33 +137,39 @@ class DynamicDecodeLayerTest : public testing::Test std::vector mTestLogitsInit; - int32_t mMaxBadWordsLen{0}; - int32_t mMaxStopWordsLen{0}; + runtime::SizeType mMaxBadWordsLen{0}; + runtime::SizeType mMaxStopWordsLen{0}; + + bool mUseMedusa{false}; private: + void allocateData(SamplingParams const& params); + void setup(uint64_t seed, SamplingParams const& params); - int32_t getMaxWordsLen(std::vector>> const& inputWords); - void initXWordsTensors(int32_t* batchSlotsPtr, int32_t* wordsData, int32_t** wordsPtr, int32_t* wordsLenData, - int32_t maxWordsLen, std::vector>> const& inputWords); + runtime::SizeType getMaxWordsLen(std::vector>> const& inputWords); + void initXWordsTensors(runtime::SizeType* batchSlotsPtr, runtime::TokenIdType* wordsData, + runtime::TokenIdType** wordsPtr, runtime::SizeType* wordsLenData, runtime::SizeType maxWordsLen, + std::vector>> const& inputWords); - typename tensorrt_llm::layers::DynamicDecodeLayer::ForwardParams createInputTensors(int32_t step); + typename tensorrt_llm::layers::DynamicDecodeLayer::ForwardParams createInputTensors(runtime::SizeType step); typename tensorrt_llm::layers::DynamicDecodeLayer::OutputParams createOutputTensors(); - void batchCopy(int32_t step); - bool checkResult(int32_t* outputIds, std::vector> const& expectedIds, int32_t* seqLens, - int32_t leadingDim, int32_t stride, int32_t step); + void batchCopy(runtime::SizeType step); + bool checkResult(runtime::TokenIdType* outputIds, std::vector> const& expectedIds, + runtime::SizeType* seqLens, runtime::SizeType leadingDim, runtime::SizeType stride, runtime::SizeType step, + bool outputIdsTransposed = false, runtime::SizeType strideTransposed = 0); - void runTestImpl( - std::vector> const& expectedOutputIds, SamplingParams const& params, int32_t endId = -1); + void runTestImpl(std::vector> const& expectedOutputIds, SamplingParams const& params, + runtime::TokenIdType endId = -1); - void fillRefLogits( - int32_t const* seqLenHost, std::vector> const& expectedOutputIds, int32_t step); + void fillRefLogits(runtime::SizeType const* seqLenHost, + std::vector> const& expectedOutputIds, runtime::SizeType step); public: - void runTest( - std::vector> const& expectedOutputIds, SamplingParams const& params, int32_t endId = -1); + void runTest(std::vector> const& expectedOutputIds, SamplingParams const& params, + runtime::TokenIdType endId = -1); }; typedef 
testing::Types FloatAndHalfTypes; diff --git a/cpp/tests/layers/medusaDecodeLayerTest.cpp b/cpp/tests/layers/medusaDecodeLayerTest.cpp new file mode 100644 index 000000000..c74444c55 --- /dev/null +++ b/cpp/tests/layers/medusaDecodeLayerTest.cpp @@ -0,0 +1,438 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tests/layers/medusaDecodeLayerTest.h" +#include "tensorrt_llm/kernels/decodingCommon.h" +#include "tensorrt_llm/runtime/runtimeKernels.h" +#include + +namespace tensorrt_llm::tests::layers +{ + +// TODO(nkorobov): +// Add tests for +// - finished states +// - finished sum +// - max length +// - repeat n grams +// - padded vocab +// - beam search + +using namespace tensorrt_llm::runtime; +using namespace tensorrt_llm::layers; +using namespace tensorrt_llm::common; + +namespace tk = tensorrt_llm::kernels; +namespace tcc = tensorrt_llm::common::conversion; +namespace trk = tensorrt_llm::runtime::kernels; + +constexpr float EPSILON = 1e-20f; + +template +void MedusaDecodingLayerTest::SetUp() +{ + mStream = std::make_shared(); + mBufferManager = std::make_shared(mStream); + + mAllocator = std::make_shared(*mBufferManager); +} + +template +void MedusaDecodingLayerTest::allocateBuffers() +{ + mMedusaDecodingLayer = std::make_shared>( + mMaxBatchSize, mVocabSize, mVocabSizePadded, mMaxTokensPerStep, mMaxNumHeads, mStream->get(), mAllocator); + + auto const dataType = TRTDataType::value; + + // clang-format off + + // prob = (0.0, 0.0, 0.0, 0.0, 0.4, 0.3, 0.2, 0.1, 0.0) + std::vector targetLogitsInit = { + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 0 + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 1 + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 2 + -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 3 + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 4 + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 5 + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 6 + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 7 + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 8 + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 9 + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 10 + -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX // token 11 + }; + // Sampled tokens with K=1 + // [4, 0, 2, 1, 3, 4, 3, 0, 2, 3, 4, 1] + + std::vector medusaLogitsInit = { + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 0 head=0 ids: [4, 
5, 6, 7] + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 1 head=0 ids: [0, 1, 2, 3] + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 2 head=0 ids: [2, 3, 4, 5] + -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 3 head=0 ids: [1, 2, 3, 4] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 4 head=0 ids: [3, 4, 5, 6] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 5 head=0 ids: [4, 5, 6, 7] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 6 head=0 ids: [3, 4, 5, 6] + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 7 head=0 ids: [0, 1, 2, 3] + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 8 head=0 ids: [2, 3, 4, 5] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 9 head=0 ids: [3, 4, 5, 6] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 10 head=0 ids: [4, 5, 6, 7] + -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 11 head=0 ids: [1, 2, 3, 4] + + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 0 head=1 ids: [2, 3, 4, 5] + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 1 head=1 ids: [0, 1, 2, 3] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 2 head=1 ids: [4, 5, 6, 7] + -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 3 head=1 ids: [1, 2, 3, 4] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 4 head=1 ids: [4, 5, 6, 7] + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 5 head=1 ids: [2, 3, 4, 5] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 6 head=1 ids: [3, 4, 5, 6] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 7 head=1 ids: [3, 4, 5, 6] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 8 head=1 ids: [3, 4, 5, 6] + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 9 head=1 ids: [0, 1, 2, 3] + -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 10 head=1 ids: [1, 2, 3, 4] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 11 head=1 ids: [4, 5, 6, 7] + + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 0 head=2 ids: [0, 1, 2, 3] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 1 head=2 ids: [4, 5, 6, 7] + -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 2 head=2 ids: [1, 2, 3, 4] + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 3 head=2 ids: [2, 3, 4, 5] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 4 head=2 ids: [4, 5, 6, 7] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, 
-2.3026, -FLT_MAX, -FLT_MAX, // token 5 head=2 ids: [3, 4, 5, 6] + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 6 head=2 ids: [0, 1, 2, 3] + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 7 head=2 ids: [2, 3, 4, 5] + -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 8 head=2 ids: [1, 2, 3, 4] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 9 head=2 ids: [3, 4, 5, 6] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 10 head=2 ids: [4, 5, 6, 7] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 11 head=2 ids: [3, 4, 5, 6] + + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 0 head=3 ids: [4, 5, 6, 7] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 1 head=3 ids: [4, 5, 6, 7] + -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 2 head=3 ids: [1, 2, 3, 4] + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 3 head=3 ids: [0, 1, 2, 3] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // token 4 head=3 ids: [4, 5, 6, 7] + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 5 head=3 ids: [2, 3, 4, 5] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 6 head=3 ids: [3, 4, 5, 6] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 7 head=3 ids: [3, 4, 5, 6] + -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // token 8 head=3 ids: [3, 4, 5, 6] + -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 9 head=3 ids: [1, 2, 3, 4] + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // token 10 head=3 ids: [0, 1, 2, 3] + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX // token 11 head=3 ids: [2, 3, 4, 5] + }; + + // clang-format on + + auto const targetLogitsHost + = ITensor::wrap(targetLogitsInit.data(), dataType, ITensor::makeShape({mMaxTokensPerStep, mVocabSizePadded})); + + TensorPtr medusaLogitsHost = ITensor::wrap( + medusaLogitsInit.data(), dataType, ITensor::makeShape({mMaxNumHeads, mMaxTokensPerStep, mVocabSizePadded})); + + mTargetLogitsDevice + = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + + mFinishedDevice = mBufferManager->gpu( + ITensor::makeShape({mMaxBatchSize}), TRTDataType::value); + + mOutputIdsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kINT32); + + mBatchSlots = BufferManager::pinned(ITensor::makeShape({mBatchSize}), nvinfer1::DataType::kINT32); + + mEndIdsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + + mPathsDevice = mBufferManager->gpu( + ITensor::makeShape({mMaxBatchSize, mMaxTokensPerStep, mMaxNumHeads + 1}), nvinfer1::DataType::kINT32); + + mSeqLengthsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + + mAcceptedLengths = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), 
nvinfer1::DataType::kINT32); + + mMedusaLogitsDevice = mBufferManager->gpu( + ITensor::makeShape({mMaxNumHeads, mMaxBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + + mNextDraftTokensDevice + = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mMaxTokensPerStep}), nvinfer1::DataType::kINT32); + + for (int32_t bi = 0; bi < mBatchSize; ++bi) + { + auto logitsDeviceView = ITensor::slice(mTargetLogitsDevice, bi, 1); + mBufferManager->copy(*targetLogitsHost, *logitsDeviceView); + } + + for (int32_t hi = 0; hi < mMaxNumHeads; ++hi) + { + TensorPtr logitsHeadDeviceView = ITensor::slice(mMedusaLogitsDevice, hi, 1); + TensorPtr logitsHeadHostView = ITensor::slice(medusaLogitsHost, hi, 1); + logitsHeadDeviceView->squeeze(0); + for (int32_t bi = 0; bi < mBatchSize; ++bi) + { + TensorPtr logitsHeadBatchDeviceView = ITensor::slice(logitsHeadDeviceView, bi, 1); + mBufferManager->copy(*logitsHeadHostView, *logitsHeadBatchDeviceView); + } + } +} + +template +void MedusaDecodingLayerTest::setup(SamplingParams& params) +{ + auto const endId = params.endId.value_or(mEndId); + trk::invokeFill(*mSeqLengthsDevice, SizeType{0}, *mStream); + trk::invokeFill(*mAcceptedLengths, SizeType{0}, *mStream); + trk::invokeFill(*mFinishedDevice, uint8_t{0}, *mStream); + trk::invokeFill(*mOutputIdsDevice, SizeType{0}, *mStream); + trk::invokeFill(*mEndIdsDevice, TokenIdType{endId}, *mStream); + trk::invokeFill(*mNextDraftTokensDevice, TokenIdType{-1}, *mStream); + trk::invokeFill(*mPathsDevice, SizeType{-1}, *mStream); + + auto batchSlotsPtr = bufferCast(*mBatchSlots); + for (SizeType bi = 0; bi < mBatchSize; ++bi) + { + batchSlotsPtr[bi] = 2 * bi; + } + + for (SizeType bi = 0; bi < mBatchSize; ++bi) + { + auto const outputIdsHost = ITensor::wrap(reinterpret_cast(params.draftIds[bi].data()), + nvinfer1::DataType::kINT32, ITensor::makeShape({1, mMaxTokensPerStep})); + auto outputIdsDeviceSlice = ITensor::slice(mOutputIdsDevice, batchSlotsPtr[bi], 1); + mBufferManager->copy(*outputIdsHost, *outputIdsDeviceSlice); + } + + for (SizeType bi = 0; bi < mBatchSize; ++bi) + { + auto& path = params.paths[bi]; + auto const numPaths = static_cast(params.paths[bi].size() / (mMaxNumHeads + 1)); + auto const pathsHost = ITensor::wrap(reinterpret_cast(path.data()), nvinfer1::DataType::kINT32, + ITensor::makeShape({1, numPaths, mMaxNumHeads + 1})); + TensorPtr pathsDeviceSlice = ITensor::slice(mPathsDevice, batchSlotsPtr[bi], 1); + pathsDeviceSlice->squeeze(0); + TensorPtr pathsNumPathsDeviceSlice = ITensor::slice(pathsDeviceSlice, 0, numPaths); + pathsNumPathsDeviceSlice->unsqueeze(0); + mBufferManager->copy(*pathsHost, *pathsNumPathsDeviceSlice); + } + + typename MedusaDecodingLayer::MedusaSetupParams setupParams; + setupParams.runtimeTopK = std::make_optional>(params.runtimeTopK); + setupParams.runtimeHeadsTopK = std::make_optional>>(params.runtimeHeadsTopK); + setupParams.tokensPerStep = std::make_optional>(params.tokensPerStep); + setupParams.randomSeed = {{0}}; + + mMedusaDecodingLayer->setup(mBatchSize, batchSlotsPtr, setupParams); + + mStream->synchronize(); +} + +template +typename MedusaDecodingLayer::MedusaForwardParams MedusaDecodingLayerTest::createInputTensors() +{ + typename MedusaDecodingLayer::MedusaForwardParams forwardParams( + tcc::toTllmTensor(*mTargetLogitsDevice), tcc::toTllmTensor(*mEndIdsDevice)); + + forwardParams.finished = tcc::toTllmTensor(*mFinishedDevice); + + forwardParams.batch_slots = tcc::toTllmTensor(*mBatchSlots); + + forwardParams.paths = tcc::toTllmTensor(*mPathsDevice); + + 
forwardParams.medusaLogits = tcc::toTllmTensor(*mMedusaLogitsDevice); + + return forwardParams; +} + +template +DecodingOutputParams MedusaDecodingLayerTest::createOutputTensors() +{ + DecodingOutputParams outputParams(tcc::toTllmTensor(*mOutputIdsDevice)); + + outputParams.sequence_length = tcc::toTllmTensor(*mSeqLengthsDevice); + + outputParams.finished = tcc::toTllmTensor(*mFinishedDevice); + + outputParams.nextDraftTokens = tcc::toTllmTensor(*mNextDraftTokensDevice); + + outputParams.acceptedLengths = tcc::toTllmTensor(*mAcceptedLengths); + + return outputParams; +} + +template +void MedusaDecodingLayerTest::checkResult(std::vector>> const& expectedOutTokens, + std::vector> const& expectedDraftTokens, std::vector const& finished) +{ + auto const nextDraftTokensHost = mBufferManager->copyFrom(*mNextDraftTokensDevice, runtime::MemoryType::kCPU); + auto const outputIdsHost = mBufferManager->copyFrom(*mOutputIdsDevice, runtime::MemoryType::kCPU); + auto const seqLenHost = mBufferManager->copyFrom(*mSeqLengthsDevice, runtime::MemoryType::kCPU); + auto const acceptedLengthsHost = mBufferManager->copyFrom(*mAcceptedLengths, runtime::MemoryType::kCPU); + auto const finishedHost = mBufferManager->copyFrom(*mFinishedDevice, runtime::MemoryType::kCPU); + + mStream->synchronize(); + + auto nextDraftTokens = BufferRange(*nextDraftTokensHost); + auto outputIds = BufferRange(*outputIdsHost); + auto seqLen = BufferRange(*seqLenHost); + auto batchSlots = BufferRange(*mBatchSlots); + auto acceptedLengths = BufferRange(*acceptedLengthsHost); + auto finishedPtr + = reinterpret_cast(bufferCast(*finishedHost)); + + for (SizeType bi = 0; bi < mBatchSize; ++bi) + { + auto& expectedOutTokensBatch = expectedOutTokens[bi]; + auto const slot = batchSlots[bi]; + EXPECT_EQ(expectedOutTokensBatch.size(), seqLen[slot]); + EXPECT_EQ(expectedOutTokensBatch.size(), acceptedLengths[slot]); + for (SizeType ti = 0; ti < expectedOutTokensBatch.size(); ++ti) + { + EXPECT_GE(expectedOutTokensBatch[ti].count(outputIds[slot * mMaxSeqLen + ti]), 1); + } + } + + for (SizeType bi = 0; bi < mBatchSize; ++bi) + { + auto& expectedDraftTokensBatch = expectedDraftTokens[bi]; + auto const slot = batchSlots[bi]; + for (SizeType ti = 0; ti < expectedDraftTokensBatch.size(); ++ti) + { + EXPECT_EQ(expectedDraftTokensBatch[ti], nextDraftTokens[slot * mMaxTokensPerStep + ti]); + } + } + for (SizeType bi = 0; bi < mBatchSize; ++bi) + { + auto const slot = batchSlots[bi]; + EXPECT_EQ(finished[bi], finishedPtr[slot].isFinished()); + } +} + +template +void MedusaDecodingLayerTest::runTest(std::vector>> const& expectedOutTokens, + std::vector> const& expectedDraftTokens, std::vector const& finished, + SamplingParams& params) +{ + mBatchSize = params.batchSize; + mMaxBatchSize = 2 * mBatchSize; + + allocateBuffers(); + + setup(params); + + auto inputTensors = createInputTensors(); + auto outputTensors = createOutputTensors(); + + mMedusaDecodingLayer->forward(outputTensors, inputTensors); + + mStream->synchronize(); + + checkResult(expectedOutTokens, expectedDraftTokens, finished); +} + +template class MedusaDecodingLayerTest; +template class MedusaDecodingLayerTest; + +TYPED_TEST_SUITE(MedusaDecodingLayerTest, FloatAndHalfTypes); + +TYPED_TEST(MedusaDecodingLayerTest, SimpleTestBS1) +{ + SamplingParams params; + params.runtimeTopK = {1}; + params.runtimeHeadsTopK = {{2, 3, 2, 1}}; + params.draftIds = {{0, 4, 0, 2, 1, 3, 4, 3, 0, 2, 3, 4}}; + params.paths = {{0, 1, 2, 3, -1}}; + params.tokensPerStep = {12}; + params.batchSize = 1; + + std::vector>> 
expectedOutTokens = {{{4}, {0}, {2}, {1}}}; + std::vector> expectedDraftTokens = {{1, 2, 1, 2, 3, 2, 3, 0}}; + std::vector finished = {false}; + this->runTest(expectedOutTokens, expectedDraftTokens, finished, params); +} + +TYPED_TEST(MedusaDecodingLayerTest, SimpleTestBS4) +{ + // Target Ids to be sampled + // [4, 0, 2, 1, 3, 4, 3, 0, 2, 3, 4, 1] + SamplingParams params; + params.runtimeTopK = {1, 1, 1, 1}; + params.runtimeHeadsTopK = {{2, 3, 2, 1}, {1, 2, 3, 4}, {3, 1, 1, 1}, {1, 1, 1, 1}}; + // clang-format off + params.draftIds = {{0, 4, 0, 2, 1, 3, 4, 4, 0, 2, 3, 4}, + {0, 4, 0, 2, 1, 4, 4, 4, 0, 2, 2, 4}, + {0, 4, 0, 4, 1, 1, 4, 4, 0, 2, 0, 4}, + {0, 4, 0, 2, 1, 3, 2, 4, 0, 2, 3, 4}}; + params.paths = {{0, 7, 2, 8, -1, + 0, 3, -1, -1, -1}, + {0, 5, 7, 8, 10, + 0, 3, -1, -1, -1}, + {0, 8, 2, 9, -1, + 0, 3, 5, 6, -1, + 0, 3, 5, 7, 10}, + {0, 1, 2, 6, -1, + 0, 3, -1, -1, -1}}; + + // clang-format on + params.tokensPerStep = {12, 11, 11, 7}; + params.batchSize = 4; + + std::vector>> expectedOutTokens + = {{{4}, {0}, {2}}, {{4}, {4}, {0}, {2}, {4}}, {{4}, {1}, {4}, {0}, {4}}, {{4}, {0}, {2}, {3}}}; + std::vector> expectedDraftTokens + = {{2, 3, 3, 4, 5, 1, 2, 3}, {4, 1, 2, 4, 5, 6, 0, 1, 2, 3}, {4, 5, 6, 1, 4, 0}, {3, 3, 0, 3}}; + std::vector finished = {false, false, false, false}; + this->runTest(expectedOutTokens, expectedDraftTokens, finished, params); +} + +TYPED_TEST(MedusaDecodingLayerTest, SimpleTestEndIdNotSelected) +{ + // Target Ids to be sampled + // [4, 0, 2, 1, 3, 4, 3, 0, 2, 3, 4, 1] + SamplingParams params; + params.runtimeTopK = {1}; + params.runtimeHeadsTopK = {{1, 1, 1, 1}}; + params.draftIds = {{0, 4, 0, 4, 1, 3, 2, 3, 0, 2, 3, 4}}; + // clang-format off + params.paths = {{0, 3, 4, 5, -1, + 0, 1, 2, 6, -1}}; + // clang-format on + params.tokensPerStep = {12}; + params.batchSize = 1; + params.endId = 1; + + std::vector>> expectedOutTokens = {{{4}, {0}, {2}, {3}}}; + std::vector> expectedDraftTokens = {{3, 3, 0, 3}}; + std::vector finished = {false}; + this->runTest(expectedOutTokens, expectedDraftTokens, finished, params); +} + +TYPED_TEST(MedusaDecodingLayerTest, SimpleTestEndIdSelected) +{ + // Target Ids to be sampled + // [4, 0, 2, 1, 3, 4, 3, 0, 2, 3, 4, 1] + SamplingParams params; + params.runtimeTopK = {1}; + params.runtimeHeadsTopK = {{1, 1, 1, 1}}; + params.draftIds = {{0, 4, 0, 4, 1, 3, 2, 3, 0, 2, 3, 4}}; + // clang-format off + params.paths = {{0, 3, 4, 5, -1, + 0, 11, 7, 9, -1}}; + // clang-format on + params.tokensPerStep = {12}; + params.batchSize = 1; + params.endId = 1; + + std::vector>> expectedOutTokens = {{{4}}}; + std::vector> expectedDraftTokens = {{3, 4, 4, 4}}; + std::vector finished = {true}; + this->runTest(expectedOutTokens, expectedDraftTokens, finished, params); +} +} // namespace tensorrt_llm::tests::layers diff --git a/cpp/tests/layers/medusaDecodeLayerTest.h b/cpp/tests/layers/medusaDecodeLayerTest.h new file mode 100644 index 000000000..5914b9cb4 --- /dev/null +++ b/cpp/tests/layers/medusaDecodeLayerTest.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "tensorrt_llm/layers/medusaDecodingLayer.h"
+#include "tensorrt_llm/runtime/bufferManager.h"
+#include "tensorrt_llm/runtime/cudaStream.h"
+
+#include "tensorrt_llm/runtime/bufferManager.h"
+#include "tensorrt_llm/runtime/cudaStream.h"
+#include "tensorrt_llm/runtime/runtimeKernels.h"
+#include "tensorrt_llm/runtime/tllmLogger.h"
+
+#include "tensorrt_llm/common/cudaAllocator.h"
+#include "tensorrt_llm/common/tensorConversion.h"
+#include "tensorrt_llm/common/tllmException.h"
+
+namespace tensorrt_llm::tests::layers
+{
+
+struct SamplingParams
+{
+    tensorrt_llm::runtime::SizeType batchSize;
+    std::vector<tensorrt_llm::runtime::SizeType> runtimeTopK;
+    std::vector<std::vector<tensorrt_llm::runtime::SizeType>> runtimeHeadsTopK;
+    std::vector<std::vector<tensorrt_llm::runtime::TokenIdType>> draftIds;
+    std::vector<std::vector<tensorrt_llm::runtime::SizeType>> paths;
+    std::vector<tensorrt_llm::runtime::SizeType> tokensPerStep;
+    std::optional<tensorrt_llm::runtime::TokenIdType> endId;
+};
+
+template <typename T>
+class MedusaDecodingLayerTest : public testing::Test
+{
+private:
+    void SetUp() override;
+
+    using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr;
+    using BufferPtr = tensorrt_llm::runtime::IBuffer::SharedPtr;
+    using SizeType = tensorrt_llm::runtime::SizeType;
+    using TokenIdType = tensorrt_llm::runtime::TokenIdType;
+
+    SizeType mBatchSize{6};
+    SizeType mMaxBatchSize{2 * mBatchSize};
+    SizeType const mVocabSize{9};
+    SizeType const mVocabSizePadded{mVocabSize};
+    SizeType const mMaxTokensPerStep{12};
+    SizeType const mMaxNumHeads{4};
+
+    SizeType const mMaxSeqLen{mMaxTokensPerStep};
+    TokenIdType mEndId{mVocabSize};
+
+    bool mUseLogitsVec{false};
+
+    TensorPtr mTargetLogitsDevice;
+    TensorPtr mMedusaLogitsDevice;
+
+    TensorPtr mFinishedDevice;
+    TensorPtr mSeqLengthsDevice;
+    TensorPtr mAcceptedLengths;
+    TensorPtr mOutputIdsDevice;
+    TensorPtr mNextDraftTokensDevice;
+
+    TensorPtr mPathsDevice;
+    TensorPtr mEndIdsDevice;
+    TensorPtr mBatchSlots;
+
+    std::vector<TensorPtr> mLogitsVec;
+
+    std::shared_ptr<tensorrt_llm::runtime::CudaStream> mStream;
+    std::shared_ptr<tensorrt_llm::runtime::BufferManager> mBufferManager;
+    std::shared_ptr<tensorrt_llm::common::CudaAllocator> mAllocator;
+    std::shared_ptr<tensorrt_llm::layers::MedusaDecodingLayer<T>> mMedusaDecodingLayer;
+
+private:
+    void allocateBuffers();
+
+    void setup(SamplingParams& params);
+
+    typename tensorrt_llm::layers::MedusaDecodingLayer<T>::MedusaForwardParams createInputTensors();
+
+    tensorrt_llm::layers::DecodingOutputParams createOutputTensors();
+
+    void checkResult(std::vector<std::vector<std::set<TokenIdType>>> const& expectedOutTokens,
+        std::vector<std::vector<TokenIdType>> const& expectedDraftTokens, std::vector<bool> const& finished);
+
+public:
+    void runTest(std::vector<std::vector<std::set<TokenIdType>>> const& expectedOutTokens,
+        std::vector<std::vector<TokenIdType>> const& expectedDraftTokens, std::vector<bool> const& finished,
+        SamplingParams& params);
+};
+
+typedef testing::Types<float, half> FloatAndHalfTypes;
+
+} // namespace tensorrt_llm::tests::layers
diff --git a/cpp/tests/layers/samplingLayerTest.cpp b/cpp/tests/layers/samplingLayerTest.cpp
index e4b5dfaa2..1650e7a03 100644
--- a/cpp/tests/layers/samplingLayerTest.cpp
+++ b/cpp/tests/layers/samplingLayerTest.cpp
@@ -65,7 +65,7 @@ TYPED_TEST_SUITE(SamplingLayerTest, FloatAndHalfTypes);
 
 TYPED_TEST(SamplingLayerTest, TopKToPPSkipDecode)
 {
-    uint32_t topK = 2;
+    SizeType topK = 2;
     float topP = 0.0f;
     SamplingParams params;
     params.topKs = {topK};
@@ -82,7 +82,7 @@ TYPED_TEST(SamplingLayerTest, TopKToPPSkipDecode)
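The Medusa cases above (MedusaSimpleTest in dynamicDecodeLayerTest.cpp and the MedusaDecodingLayerTest suite) all encode the same acceptance rule: walk one candidate `path` through the token tree, keep a drafted token only while it matches the greedy token the target model produced at the previous tree position, and always take one extra bonus token from the last matched position; the layer then keeps whichever path accepts the most tokens. The sketch below is a minimal CPU-only reading of that rule for the greedy (target top-K = 1) case these tests use, not the layer's actual implementation; the helper name `acceptAlongPath` is invented for illustration, and end-id handling plus next-step draft generation are left out.

```
#include <cassert>
#include <cstdint>
#include <vector>

// targetIds[i] is the greedy (top-K = 1) token the target model produces at tree
// position i, draftIds[i] is the token that was drafted there, and path holds tree
// positions with -1 as padding. Accepted tokens are the target tokens confirmed by
// the draft along the path, plus one bonus token from the last matched position.
std::vector<int32_t> acceptAlongPath(std::vector<int32_t> const& targetIds,
    std::vector<int32_t> const& draftIds, std::vector<int32_t> const& path)
{
    std::vector<int32_t> accepted;
    std::size_t k = 1;
    for (; k < path.size() && path[k] >= 0; ++k)
    {
        auto const prev = path[k - 1];
        if (draftIds[path[k]] != targetIds[prev])
        {
            break; // first drafted token the target model disagrees with
        }
        accepted.push_back(targetIds[prev]);
    }
    accepted.push_back(targetIds[path[k - 1]]); // bonus token from the last matched position
    return accepted;
}

int main()
{
    // Numbers from SimpleTestBS1 above: the greedy target ids per tree position, the
    // drafted ids, and the single path {0, 1, 2, 3, -1}.
    std::vector<int32_t> const targetIds{4, 0, 2, 1, 3, 4, 3, 0, 2, 3, 4, 1};
    std::vector<int32_t> const draftIds{0, 4, 0, 2, 1, 3, 4, 3, 0, 2, 3, 4};
    std::vector<int32_t> const path{0, 1, 2, 3, -1};
    auto const accepted = acceptAlongPath(targetIds, draftIds, path);
    assert((accepted == std::vector<int32_t>{4, 0, 2, 1})); // accepted length 4
    return 0;
}
```

Fed the SimpleTestBS1 numbers, this reproduces the expected accepted sequence {4, 0, 2, 1} and accepted length 4; for SimpleTestBS4, evaluating every candidate path the same way and keeping the longest acceptance matches the expectedOutTokens of each batch entry.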
TYPED_TEST(SamplingLayerTest, TopKSkipDecodeTopP) { - uint32_t topK = 0; + SizeType topK = 0; float topP = 0.5f; SamplingParams params; params.topKs = {topK}; @@ -99,7 +99,7 @@ TYPED_TEST(SamplingLayerTest, TopKSkipDecodeTopP) TYPED_TEST(SamplingLayerTest, BatchTopKTopP) { - std::vector topKs = {0, 2, 1, 0, 1, 0}; + std::vector topKs = {0, 2, 1, 0, 1, 0}; std::vector topPs = {0.3f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; SamplingParams params; params.topKs = topKs; @@ -133,7 +133,7 @@ TYPED_TEST(SamplingLayerTest, TopPDecay) TYPED_TEST(SamplingLayerTest, TopK) { - uint32_t topK = 2; + SizeType topK = 2; float topP = 0.0f; SamplingParams params; params.topKs = {topK}; @@ -150,7 +150,7 @@ TYPED_TEST(SamplingLayerTest, TopK) TYPED_TEST(SamplingLayerTest, TopK1TopP0) { - uint32_t topK = 1; + SizeType topK = 1; float topP = 0.0f; SamplingParams params; params.topKs = {topK}; @@ -167,7 +167,7 @@ TYPED_TEST(SamplingLayerTest, TopK1TopP0) TYPED_TEST(SamplingLayerTest, BatchTopK) { - std::vector topKs = {2, 1, 1, 2, 1, 1}; + std::vector topKs = {2, 1, 1, 2, 1, 1}; SamplingParams params; params.topKs = topKs; params.topPs = {1.0f}; @@ -183,7 +183,7 @@ TYPED_TEST(SamplingLayerTest, BatchTopK) TYPED_TEST(SamplingLayerTest, TopKTopP) { - uint32_t topK = 2; + SizeType topK = 2; float topP = 0.3; SamplingParams params; params.topKs = {topK}; @@ -200,7 +200,7 @@ TYPED_TEST(SamplingLayerTest, TopKTopP) TYPED_TEST(SamplingLayerTest, BatchTopKTopP1) { - std::vector topKs = {2, 2, 1, 2, 2, 1}; + std::vector topKs = {2, 2, 1, 2, 2, 1}; float topP = 0.3; SamplingParams params; params.topKs = topKs; @@ -217,7 +217,7 @@ TYPED_TEST(SamplingLayerTest, BatchTopKTopP1) TYPED_TEST(SamplingLayerTest, BatchTopKBatchTopP) { - std::vector topKs = {2, 2, 0, 2, 2, 1}; + std::vector topKs = {2, 2, 0, 2, 2, 1}; std::vector topPs = {0.0, 0.3, 0.5, 0.0, 0.3, 0.5}; SamplingParams params; params.topKs = topKs; @@ -234,7 +234,7 @@ TYPED_TEST(SamplingLayerTest, BatchTopKBatchTopP) TYPED_TEST(SamplingLayerTest, InvalidArgsZeroTopK) { - uint32_t topK = 0; + SizeType topK = 0; SamplingParams params; params.topKs = {topK}; std::vector> expectedOutputIds{ @@ -250,7 +250,7 @@ TYPED_TEST(SamplingLayerTest, InvalidArgsZeroTopK) TYPED_TEST(SamplingLayerTest, InvalidArgsZeroTopP) { float topP = 0; - uint32_t topK = 0; + SizeType topK = 0; SamplingParams params; params.topPs = {topP}; params.topKs = {topK}; @@ -266,7 +266,7 @@ TYPED_TEST(SamplingLayerTest, InvalidArgsZeroTopP) TYPED_TEST(SamplingLayerTest, InvalidArgsZeroTopKTopP) { - uint32_t topK = 0; + SizeType topK = 0; float topP = 0; SamplingParams params; params.topPs = {topP}; @@ -283,7 +283,7 @@ TYPED_TEST(SamplingLayerTest, InvalidArgsZeroTopKTopP) TYPED_TEST(SamplingLayerTest, InvalidArgsZeroBatchTopKTopP) { - std::vector topKs = {0, 0, 0, 0, 0, 0}; + std::vector topKs = {0, 0, 0, 0, 0, 0}; float topP = 0; SamplingParams params; params.topPs = {topP}; @@ -300,7 +300,7 @@ TYPED_TEST(SamplingLayerTest, InvalidArgsZeroBatchTopKTopP) TYPED_TEST(SamplingLayerTest, InvalidArgsZeroTopKBatchTopP) { - uint32_t topK = 0; + SizeType topK = 0; std::vector topPs = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; SamplingParams params; params.topPs = topPs; @@ -317,7 +317,7 @@ TYPED_TEST(SamplingLayerTest, InvalidArgsZeroTopKBatchTopP) TYPED_TEST(SamplingLayerTest, InvalidArgsBatchTopKContainZero) { - std::vector topKs = {2, 1, 0, 0, 2, 1}; + std::vector topKs = {2, 1, 0, 0, 2, 1}; SamplingParams params; params.topKs = topKs; std::vector> expectedOutputIds{ @@ -332,7 +332,7 @@ TYPED_TEST(SamplingLayerTest, 
InvalidArgsBatchTopKContainZero) TYPED_TEST(SamplingLayerTest, InvalidArgsBatchTopKTopPContainZero) { - std::vector topKs = {2, 2, 1, 0, 2, 0}; + std::vector topKs = {2, 2, 1, 0, 2, 0}; float topP = 0.0; SamplingParams params; params.topPs = {topP}; @@ -349,7 +349,7 @@ TYPED_TEST(SamplingLayerTest, InvalidArgsBatchTopKTopPContainZero) TYPED_TEST(SamplingLayerTest, OnlyTopK) { - std::vector topKs = {2, 2, 1, 0, 2, 0}; + std::vector topKs = {2, 2, 1, 0, 2, 0}; SamplingParams params; params.topKs = topKs; std::vector> expectedOutputIds{ diff --git a/cpp/tests/layers/topKSamplingLayerTest.cpp b/cpp/tests/layers/topKSamplingLayerTest.cpp index eb1438792..6528795bf 100644 --- a/cpp/tests/layers/topKSamplingLayerTest.cpp +++ b/cpp/tests/layers/topKSamplingLayerTest.cpp @@ -20,6 +20,7 @@ namespace { using namespace tensorrt_llm::tests::layers::sampling; +using namespace tensorrt_llm::runtime; template class TopKSamplingLayerTest : public BaseSamplingLayerTest @@ -43,7 +44,7 @@ TYPED_TEST_SUITE(TopKSamplingLayerTest, FloatAndHalfTypes); TYPED_TEST(TopKSamplingLayerTest, TopK) { - uint32_t topK = 2; + SizeType topK = 2; float topP = 0.0f; SamplingParams params; params.topKs = {topK}; @@ -60,7 +61,7 @@ TYPED_TEST(TopKSamplingLayerTest, TopK) TYPED_TEST(TopKSamplingLayerTest, TopK1TopP0) { - uint32_t topK = 1; + SizeType topK = 1; float topP = 0.0f; SamplingParams params; params.topKs = {topK}; @@ -77,7 +78,7 @@ TYPED_TEST(TopKSamplingLayerTest, TopK1TopP0) TYPED_TEST(TopKSamplingLayerTest, BatchTopK) { - std::vector topKs = {2, 1, 1, 2, 1, 1}; + std::vector topKs = {2, 1, 1, 2, 1, 1}; SamplingParams params; params.topKs = topKs; params.topPs = {1.0f}; @@ -109,7 +110,7 @@ TYPED_TEST(TopKSamplingLayerTest, SkipDecode) TYPED_TEST(TopKSamplingLayerTest, TopKTopP) { - uint32_t topK = 2; + SizeType topK = 2; float topP = 0.3; SamplingParams params; params.topKs = {topK}; @@ -126,7 +127,7 @@ TYPED_TEST(TopKSamplingLayerTest, TopKTopP) TYPED_TEST(TopKSamplingLayerTest, BatchTopKTopP) { - std::vector topKs = {2, 2, 1, 2, 2, 1}; + std::vector topKs = {2, 2, 1, 2, 2, 1}; float topP = 0.3; SamplingParams params; params.topKs = topKs; @@ -143,7 +144,7 @@ TYPED_TEST(TopKSamplingLayerTest, BatchTopKTopP) TYPED_TEST(TopKSamplingLayerTest, TopKBatchTopP) { - uint32_t topK = 2; + SizeType topK = 2; std::vector topPs = {0.5, 0.3, 0.5, 0.5, 0.3, 0.5}; SamplingParams params; params.topKs = {topK}; @@ -160,7 +161,7 @@ TYPED_TEST(TopKSamplingLayerTest, TopKBatchTopP) TYPED_TEST(TopKSamplingLayerTest, BatchTopKBatchTopP) { - std::vector topKs = {2, 2, 1, 2, 2, 1}; + std::vector topKs = {2, 2, 1, 2, 2, 1}; std::vector topPs = {0.0, 0.3, 0.5, 0.0, 0.3, 0.5}; SamplingParams params; params.topKs = topKs; @@ -177,7 +178,7 @@ TYPED_TEST(TopKSamplingLayerTest, BatchTopKBatchTopP) TYPED_TEST(TopKSamplingLayerTest, InvalidArgsZeroTopK) { - uint32_t topK = 0; + SizeType topK = 0; SamplingParams params; params.topKs = {topK}; std::vector> expectedOutputIds{ @@ -207,7 +208,7 @@ TYPED_TEST(TopKSamplingLayerTest, InvalidArgsZeroTopP) TYPED_TEST(TopKSamplingLayerTest, InvalidArgsZeroTopKTopP) { - uint32_t topK = 0; + SizeType topK = 0; float topP = 0; SamplingParams params; params.topPs = {topP}; @@ -224,7 +225,7 @@ TYPED_TEST(TopKSamplingLayerTest, InvalidArgsZeroTopKTopP) TYPED_TEST(TopKSamplingLayerTest, InvalidArgsZeroBatchTopKTopP) { - std::vector topKs = {0, 0, 0, 0, 0, 0}; + std::vector topKs = {0, 0, 0, 0, 0, 0}; float topP = 0; SamplingParams params; params.topPs = {topP}; @@ -241,7 +242,7 @@ TYPED_TEST(TopKSamplingLayerTest, 
InvalidArgsZeroBatchTopKTopP) TYPED_TEST(TopKSamplingLayerTest, InvalidArgsZeroTopKBatchTopP) { - uint32_t topK = 0; + SizeType topK = 0; std::vector topPs = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; SamplingParams params; params.topPs = topPs; @@ -258,7 +259,7 @@ TYPED_TEST(TopKSamplingLayerTest, InvalidArgsZeroTopKBatchTopP) TYPED_TEST(TopKSamplingLayerTest, InvalidArgsBatchTopKContainZero) { - std::vector topKs = {2, 1, 0, 0, 2, 1}; + std::vector topKs = {2, 1, 0, 0, 2, 1}; SamplingParams params; params.topKs = topKs; std::vector> expectedOutputIds{ @@ -273,7 +274,7 @@ TYPED_TEST(TopKSamplingLayerTest, InvalidArgsBatchTopKContainZero) TYPED_TEST(TopKSamplingLayerTest, InvalidArgsBatchTopKTopPContainZero) { - std::vector topKs = {2, 2, 1, 0, 2, 0}; + std::vector topKs = {2, 2, 1, 0, 2, 0}; float topP = 0.0; SamplingParams params; params.topPs = {topP}; @@ -290,7 +291,7 @@ TYPED_TEST(TopKSamplingLayerTest, InvalidArgsBatchTopKTopPContainZero) TYPED_TEST(TopKSamplingLayerTest, InvalidArgsBatchTopKBatchTopPContainZero) { - std::vector topKs = {0, 2, 1, 2, 2, 0}; + std::vector topKs = {0, 2, 1, 2, 2, 0}; std::vector topPs = {0.0, 0.3, 0.9, 0.0, 0.3, 0.5}; SamplingParams params; params.topPs = topPs; diff --git a/cpp/tests/layers/topPSamplingLayerTest.cpp b/cpp/tests/layers/topPSamplingLayerTest.cpp index 68fd22a7e..243f80a74 100644 --- a/cpp/tests/layers/topPSamplingLayerTest.cpp +++ b/cpp/tests/layers/topPSamplingLayerTest.cpp @@ -20,6 +20,7 @@ namespace { using namespace tensorrt_llm::tests::layers::sampling; +using namespace tensorrt_llm::runtime; template class TopPSamplingLayerTest : public BaseSamplingLayerTest @@ -51,7 +52,7 @@ TYPED_TEST_SUITE(TopPSamplingLayerTest, FloatAndHalfTypes); TYPED_TEST(TopPSamplingLayerTest, TopKSkipDecode) { - uint32_t topK = 2; + SizeType topK = 2; float topP = 0.0f; SamplingParams params; params.topKs = {topK}; @@ -68,7 +69,7 @@ TYPED_TEST(TopPSamplingLayerTest, TopKSkipDecode) TYPED_TEST(TopPSamplingLayerTest, TopKTopPSkipDecode) { - uint32_t topK = 2; + SizeType topK = 2; float topP = 1.0f; SamplingParams params; params.topKs = {topK}; @@ -85,7 +86,7 @@ TYPED_TEST(TopPSamplingLayerTest, TopKTopPSkipDecode) TYPED_TEST(TopPSamplingLayerTest, BatchTopKTopP) { - std::vector topKs = {0, 1, 1, 0, 1, 0}; + std::vector topKs = {0, 1, 1, 0, 1, 0}; std::vector topPs = {0.3f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; SamplingParams params; params.topKs = topKs; @@ -102,7 +103,7 @@ TYPED_TEST(TopPSamplingLayerTest, BatchTopKTopP) TYPED_TEST(TopPSamplingLayerTest, TopP) { - uint32_t topK = 0; + SizeType topK = 0; float topP = 0.3f; SamplingParams params; params.topKs = {topK}; diff --git a/cpp/tests/resources/scripts/build_chatglm_engines.py b/cpp/tests/resources/scripts/build_chatglm_engines.py index 266edd3b1..4a802bb88 100755 --- a/cpp/tests/resources/scripts/build_chatglm_engines.py +++ b/cpp/tests/resources/scripts/build_chatglm_engines.py @@ -37,7 +37,6 @@ def convert_ckpt(model_dir: str, output_dir: str, world_size: int): f"--model_dir={model_dir}", f"--output_dir={output_dir}", f"--tp_size={world_size}" ] - print("Running: " + " ".join(convert_cmd)) run_command(convert_cmd) @@ -50,7 +49,6 @@ def build_engine(ckpt_dir: str, engine_dir: str): "--builder_opt=0", "--remove_input_padding=disable", "--paged_kv_cache=disable" ] - print("Running: " + " ".join(build_cmd)) run_command(build_cmd) diff --git a/cpp/tests/resources/scripts/build_gpt_engines.py b/cpp/tests/resources/scripts/build_gpt_engines.py index da4bb112c..3cb22186a 100755 --- 
a/cpp/tests/resources/scripts/build_gpt_engines.py +++ b/cpp/tests/resources/scripts/build_gpt_engines.py @@ -14,50 +14,60 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse as _arg -import os as _os -import pathlib as _pl -import platform as _pf -import typing as _tp - -import hf_gpt_convert as _egc -import torch.multiprocessing as _mp +import argparse +import os +import platform +import sys +from pathlib import Path +from typing import Optional + from build_engines_utils import run_command, wincopy -import build as _egb # isort:skip + +def convert_ckpt(model_dir: str, + output_dir: str, + *args, + world_size: int = 1, + dtype: str = 'float16'): + convert_cmd = [ + sys.executable, "examples/gpt/convert_checkpoint.py", + f"--model_dir={model_dir}", f"--output_dir={output_dir}", + f"--dtype={dtype}", f"--tp_size={world_size}" + ] + list(args) + run_command(convert_cmd) def build_engine( - weight_dir: _pl.Path, - engine_dir: _pl.Path, - world_size, + checkpoint_dir: str, + engine_dir: str, *args, - max_input_len=256, - max_output_len=128, + max_input_len: int = 256, + max_output_len: int = 128, ): - args = [ - '--log_level=error', - '--model_dir', - str(weight_dir), - '--output_dir', - str(engine_dir), - '--max_batch_size=64', - f'--max_input_len={max_input_len}', - f'--max_output_len={max_output_len}', - '--max_beam_width=2', - '--builder_opt=0', - f'--world_size={world_size}', - ] + list(args) - print("Running: build " + " ".join(args)) - _egb.run_build(args) + build_cmd = [ + "trtllm-build", '--log_level=error', + f'--checkpoint_dir={checkpoint_dir}', f'--output_dir={engine_dir}', + '--max_batch_size=64', f'--max_input_len={max_input_len}', + f'--max_output_len={max_output_len}', '--max_beam_width=2', + '--builder_opt=0' + ] + legacy_args = [ + "--gpt_attention_plugin=disable", + "--context_fmha=disable", + "--paged_kv_cache=disable", + "--remove_input_padding=disable", + "--enable_xqa=disable", + ] + build_cmd = build_cmd + legacy_args + list(args) + run_command(build_cmd) -def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1): +def build_engines(model_cache: Optional[str] = None, world_size: int = 1): # TODO add support of Pipeline parallelism to GPT tp_size = world_size pp_size = 1 - resources_dir = _pl.Path(__file__).parent.resolve().parent + resources_dir = Path(__file__).parent.resolve().parent models_dir = resources_dir / 'models' model_name = 'gpt2' @@ -67,13 +77,13 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1): assert hf_dir.is_dir() run_command(["git", "pull"], cwd=hf_dir) else: - if _pf.system() == "Windows": + if platform.system() == "Windows": url_prefix = "" else: url_prefix = "file://" model_url = url_prefix + str( - _pl.Path(model_cache) / + Path(model_cache) / model_name) if model_cache else "https://huggingface.co/gpt2" run_command([ "git", "clone", model_url, "--single-branch", "--no-local", @@ -81,7 +91,7 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1): ], cwd=hf_dir.parent, env={ - **_os.environ, "GIT_LFS_SKIP_SMUDGE": "1" + **os.environ, "GIT_LFS_SKIP_SMUDGE": "1" }) assert hf_dir.is_dir() @@ -89,16 +99,16 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1): # Download the model file model_file_name = "pytorch_model.bin" if model_cache: - if _pf.system() == "Windows": + if platform.system() == "Windows": wincopy(source=str( - _pl.Path(model_cache) / model_name / 
model_file_name), + Path(model_cache) / model_name / model_file_name), dest=model_file_name, isdir=False, cwd=hf_dir) else: run_command([ "rsync", "-av", - str(_pl.Path(model_cache) / model_name / model_file_name), "." + str(Path(model_cache) / model_name / model_file_name), "." ], cwd=hf_dir) else: @@ -112,108 +122,105 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1): assert (hf_dir / model_file_name).is_file() - weight_dir = models_dir / 'c-model' / model_name + ckpt_dir = models_dir / 'c-model' / model_name engine_dir = models_dir / 'rt_engine' / model_name tp_pp_dir = f"tp{tp_size}-pp{pp_size}-gpu" tp_dir = f"{world_size}-gpu" print("\nConverting to fp32") - fp32_weight_dir = weight_dir / 'fp32' - _egc.run_conversion( - _egc.ProgArgs(in_file=str(hf_dir), - out_dir=str(fp32_weight_dir), - storage_type='float32', - tensor_parallelism=tp_size)) + fp32_ckpt_dir = ckpt_dir / 'fp32' / tp_dir + convert_ckpt(str(hf_dir), + str(fp32_ckpt_dir), + world_size=tp_size, + dtype='float32') print("\nBuilding fp32 engines") - fp32_weight_dir_x_gpu = fp32_weight_dir / tp_dir - build_engine(fp32_weight_dir_x_gpu, engine_dir / 'fp32-default' / tp_pp_dir, - tp_size, '--dtype=float32') - build_engine(fp32_weight_dir_x_gpu, engine_dir / 'fp32-plugin' / tp_pp_dir, - tp_size, '--dtype=float32', - '--use_gpt_attention_plugin=float32') + build_engine(str(fp32_ckpt_dir), + str(engine_dir / 'fp32-default' / tp_pp_dir)) + build_engine(str(fp32_ckpt_dir), + str(engine_dir / 'fp32-plugin' / tp_pp_dir), + '--gpt_attention_plugin=float32', '--context_fmha=enable', + '--context_fmha_fp32_acc=enable') print("\nConverting to fp16") - fp16_weight_dir = weight_dir / 'fp16' - _egc.run_conversion( - _egc.ProgArgs(in_file=str(hf_dir), - out_dir=str(fp16_weight_dir), - storage_type='float16', - tensor_parallelism=tp_size)) + fp16_ckpt_dir = ckpt_dir / 'fp16' / tp_dir + convert_ckpt(str(hf_dir), + str(fp16_ckpt_dir), + world_size=tp_size, + dtype='float16') print("\nBuilding fp16 engines") - fp16_weight_dir_x_gpu = fp16_weight_dir / tp_dir - build_engine(fp16_weight_dir_x_gpu, engine_dir / 'fp16-default' / tp_pp_dir, - tp_size, '--dtype=float16') - build_engine(fp16_weight_dir_x_gpu, engine_dir / 'fp16-plugin' / tp_pp_dir, - tp_size, '--dtype=float16', - '--use_gpt_attention_plugin=float16') - build_engine(fp16_weight_dir_x_gpu, - engine_dir / 'fp16-plugin-packed' / tp_pp_dir, tp_size, - '--dtype=float16', '--use_gpt_attention_plugin=float16', - '--remove_input_padding') + build_engine(str(fp16_ckpt_dir), + str(engine_dir / 'fp16-default' / tp_pp_dir)) + build_engine(str(fp16_ckpt_dir), + str(engine_dir / 'fp16-plugin' / tp_pp_dir), + '--gpt_attention_plugin=float16') + build_engine(str(fp16_ckpt_dir), + str(engine_dir / 'fp16-plugin-packed' / tp_pp_dir), + '--gpt_attention_plugin=float16', + '--remove_input_padding=enable') + # this engine can be use for in-flight batching ifb_args = [ - '--dtype=float16', - '--use_gpt_attention_plugin=float16', - '--remove_input_padding', - '--paged_kv_cache', - '--enable_context_fmha_fp32_acc', + '--gpt_attention_plugin=float16', + '--remove_input_padding=enable', + '--paged_kv_cache=enable', + '--context_fmha=enable', + '--context_fmha_fp32_acc=enable', '--max_num_tokens=10000', - '--use_paged_context_fmha', + '--use_paged_context_fmha=enable', ] - build_engine(fp16_weight_dir_x_gpu, - engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir, tp_size, + build_engine(str(fp16_ckpt_dir), + str(engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir), '--max_draft_len=5', 
*ifb_args) - build_engine(fp16_weight_dir_x_gpu, - engine_dir / 'fp16-plugin-packed-paged-in128' / tp_pp_dir, - tp_size, - max_input_len=128, - *ifb_args) + build_engine(str(fp16_ckpt_dir), + str(engine_dir / 'fp16-plugin-packed-paged-in128' / tp_pp_dir), + *ifb_args, + max_input_len=128) # We build almost the same engine twice. But this engine has gather_all_token_logits # to extract logits from python runtime and uses context FMHA for generation to match draft model executions, # which uses context FMHA for draft tokens prediction. # Currently the gather_all_token_logits is not supported with target model of speculative decoding - build_engine(fp16_weight_dir_x_gpu, - engine_dir / 'fp16-plugin-packed-paged-gather' / tp_pp_dir, - tp_size, '--gather_all_token_logits', *ifb_args) + build_engine( + str(fp16_ckpt_dir), + str(engine_dir / 'fp16-plugin-packed-paged-gather' / tp_pp_dir), + '--gather_all_token_logits', *ifb_args) # '--use_context_fmha_for_generation', *ifb_args) # Commented out because of `--use_context_fmha_for_generation` has bugs now: https://nvbugspro.nvidia.com/bug/4476681 build_engine( - fp16_weight_dir_x_gpu, engine_dir / - 'fp16-plugin-packed-paged-context-fmha-for-gen' / tp_pp_dir, tp_size, - '--use_context_fmha_for_generation', '--max_draft_len=5', *ifb_args) + str(fp16_ckpt_dir), + str(engine_dir / 'fp16-plugin-packed-paged-context-fmha-for-gen' / + tp_pp_dir), '--use_context_fmha_for_generation=enable', + '--max_draft_len=5', *ifb_args) # build engine with lora enabled - build_engine(fp16_weight_dir_x_gpu, - engine_dir / "fp16-plugin-packed-paged-lora" / tp_pp_dir, - tp_size, '--use_lora_plugin=float16', - '--lora_target_modules=attn_qkv', *ifb_args) + build_engine(str(fp16_ckpt_dir), + str(engine_dir / "fp16-plugin-packed-paged-lora" / tp_pp_dir), + "--lora_target_modules=attn_qkv", '--lora_plugin=float16', + *ifb_args) print("\nConverting to fp16 SQ") - fp16_weight_dir = weight_dir / 'fp16-sq' - fp16_weight_dir_x_gpu = fp16_weight_dir / tp_dir - _egc.run_conversion( - _egc.ProgArgs(in_file=str(hf_dir), - out_dir=str(fp16_weight_dir), - storage_type='float16', - tensor_parallelism=tp_size, - smoothquant=0.5)) + fp16_sq_ckpt_dir = ckpt_dir / 'fp16-sq' / tp_dir + convert_ckpt(str(hf_dir), + str(fp16_sq_ckpt_dir), + "--smoothquant=0.5", + world_size=tp_size, + dtype='float16') print("\nBuilding fp16 SQ engines") - build_engine(fp16_weight_dir_x_gpu, - engine_dir / 'fp16-plugin-packed-paged-sq' / tp_pp_dir, - tp_size, *ifb_args) + build_engine(str(fp16_sq_ckpt_dir), + str(engine_dir / 'fp16-plugin-packed-paged-sq' / tp_pp_dir), + *ifb_args) if has_safetensor: - _pl.Path(str(safetensor_file) + ".bak").rename(safetensor_file) + Path(str(safetensor_file) + ".bak").rename(safetensor_file) print("Done.") if __name__ == "__main__": - parser = _arg.ArgumentParser() + parser = argparse.ArgumentParser() parser.add_argument("--model_cache", type=str, help="Directory where models are stored") @@ -223,6 +230,4 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1): default=1, help='world size, only support tensor parallelism now') - _mp.set_start_method("spawn") - build_engines(**vars(parser.parse_args())) diff --git a/cpp/tests/resources/scripts/generate_batch_manager_data.py b/cpp/tests/resources/scripts/generate_batch_manager_data.py deleted file mode 100644 index 3f40cc297..000000000 --- a/cpp/tests/resources/scripts/generate_batch_manager_data.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import json -from pathlib import Path - - -def generate_dataset(num_samples, output_filename, long_prompt=False): - resources_dir = Path(__file__).parent.resolve().parent - data_dir = resources_dir / 'data' - dummy_cnn_dataset = data_dir / output_filename - - input = ' '.join(['test' for _ in range(10)]) - output = ' '.join(['test' for _ in range(10)]) - - instruction = "Summarize the following news article:" - if long_prompt: - instruction = ( - "TensorRT-LLM provides users with an easy-to-use Python " - "API to define Large Language Models (LLMs) and build TensorRT engines " - "that contain state-of-the-art optimizations to perform inference " - "efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to " - "create Python and C++ runtimes that execute those TensorRT engines. " - "It also includes a backend for integration with the NVIDIA Triton Inference " - "Server; a production-quality system to serve LLMs. Models built with " - "TensorRT-LLM can be executed on a wide range of configurations going from " - "a single GPU to multiple nodes with multiple GPUs (using Tensor Parallelism " - "and/or Pipeline Parallelism). The Python API of TensorRT-LLM is architectured " - "to look similar to the PyTorch API. It provides users with a functional module " - "containing functions like einsum, softmax, matmul or view. The layers module " - "bundles useful building blocks to assemble LLMs; like an Attention block, a MLP " - "or the entire Transformer layer. 
Model-specific components, like GPTAttention " - "or BertAttention, can be found in the models module.") - - samples = [] - for _ in range(num_samples): - samples.append({ - "input": input, - "instruction": instruction, - "output": output - }) - - with open(dummy_cnn_dataset, 'w') as f: - json.dump(samples, f) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--long_prompt', - default=False, - action='store_true', - help='Using long prompts to generate tokens.') - parser.add_argument('--output_filename', - type=str, - default='dummy_cnn.json', - help=('The name of the json output file.')) - FLAGS = parser.parse_args() - generate_dataset(num_samples=50, - output_filename=FLAGS.output_filename, - long_prompt=FLAGS.long_prompt) diff --git a/cpp/tests/resources/scripts/generate_test_lora_weights.py b/cpp/tests/resources/scripts/generate_test_lora_weights.py index d409e422b..9a5fe779b 100644 --- a/cpp/tests/resources/scripts/generate_test_lora_weights.py +++ b/cpp/tests/resources/scripts/generate_test_lora_weights.py @@ -58,11 +58,68 @@ def pad_tensors(weights_list): return torch.concatenate(padded_weights, dim=0) +def copy_to_cache_pages(weights, + lora_config, + page_blocks, + configs, + tp_rank=0, + tp_size=1): + page_slots = page_blocks.shape[1] + page_width = page_blocks.shape[2] + + curr_page = 0 + curr_slot = 0 + for i in range(lora_config.shape[0]): + module = configs[lora_config[i, 0]] + adapter_size = module[2] + in_dim = module[3] + out_dim = module[4] + mod_id = module[0] + split_in = mod_id in (4, 6, 12) + + local_in_dim = in_dim // tp_size + local_out_dim = out_dim // tp_size + + local_size = 0 + if split_in: + local_size = adapter_size * (local_in_dim + out_dim) + else: + local_size = adapter_size * (in_dim + local_out_dim) + + num_slots = (local_size + page_width - 1) // page_width + if num_slots + curr_slot > page_slots: + curr_slot = 0 + curr_page += 1 + + flattend_size = adapter_size * (in_dim + out_dim) + + if split_in: + in_weights = weights[i, :adapter_size * in_dim].reshape( + (adapter_size, tp_size, + local_in_dim))[:, tp_rank, :].contiguous().flatten() + out_weights = weights[i, adapter_size * + in_dim:flattend_size].contiguous().flatten() + else: + in_weights = weights[i, :adapter_size * + in_dim].contiguous().flatten() + out_weights = weights[i, + adapter_size * in_dim:flattend_size].reshape( + (tp_size, local_out_dim, adapter_size + ))[tp_rank, :, :].contiguous().flatten() + + page_blocks[curr_page, curr_slot:curr_slot + num_slots, :].view( + -1)[0:in_weights.shape[0] + + out_weights.shape[0]] = torch.concatenate( + (in_weights, out_weights)).contiguous().flatten() + curr_slot += num_slots + + def main(): torch.manual_seed(12345) parser = argparse.ArgumentParser() parser.add_argument('--tp-size', type=int, default=1) parser.add_argument('--out-dir', type=Path, required=True) + parser.add_argument('--num-loras', type=int, default=1) args = parser.parse_args() @@ -93,35 +150,55 @@ def main(): hidden_size), # cross_attn_dense ] - all_source = [] - all_config = [] - - all_target = [] - for c in configs: - source_weights, config = generate_source_weights(*c) - all_source.append(source_weights) - all_config.append(config) - - mod_id, _, adapter_size, in_dim, out_dim = c - split_in = mod_id in (4, 6, 12) - - target_weights = format_tensors(source_weights, adapter_size, in_dim, - out_dim, args.tp_size, split_in) - all_target.append(target_weights) - - all_source = pad_tensors(all_source) - all_config = pad_tensors(all_config) - 
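The `copy_to_cache_pages` helper added above packs each module's low-rank factors into fixed-width cache-page slots, moving on to the next page when a module does not fit in the current one. The sketch below restates only the slot-count arithmetic in isolation; the function name and the sample numbers (taken from the tp2 test fixtures, with a 64-element page width) are illustrative assumptions, not part of the script.

```
import math

def lora_slot_count(adapter_size, in_dim, out_dim, page_width, tp_size=1, split_in=False):
    # Rank-local element count of one module's A/B factors, mirroring the sizing
    # branch in copy_to_cache_pages: modules whose input side is sharded
    # (split_in) keep the full output factor, and vice versa.
    local_in = in_dim // tp_size
    local_out = out_dim // tp_size
    local_size = adapter_size * (local_in + out_dim) if split_in \
        else adapter_size * (in_dim + local_out)
    # Round up to whole slots of page_width elements.
    return math.ceil(local_size / page_width)

# attn_qkv in the tp2 fixtures: adapter_size=8, in_dim=16, out_dim=3*16,
# split over 2 ranks with 64-element slots -> 5 slots per rank.
assert lora_slot_count(8, 16, 48, 64, tp_size=2) == 5
```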
all_target = pad_tensors(all_target) - - source_out_path = args.out_dir / 'source.npy' - config_out_path = args.out_dir / 'config.npy' - target_out_path = args.out_dir / 'target.npy' - - os.makedirs(args.out_dir, exist_ok=True) - - np.save(source_out_path, all_source) - np.save(config_out_path, all_config) - np.save(target_out_path, all_target) + for lora_idx in range(args.num_loras): + all_source = [] + all_config = [] + + all_target = [] + for c in configs: + source_weights, config = generate_source_weights(*c) + all_source.append(source_weights) + all_config.append(config) + + mod_id, _, adapter_size, in_dim, out_dim = c + split_in = mod_id in (4, 6, 12) + + target_weights = format_tensors(source_weights, adapter_size, + in_dim, out_dim, args.tp_size, + split_in) + all_target.append(target_weights) + + all_source = pad_tensors(all_source) + all_config = pad_tensors(all_config) + all_target = pad_tensors(all_target) + + output_dir = Path(args.out_dir) + if args.num_loras > 1: + output_dir = output_dir / str(lora_idx) + + os.makedirs(output_dir, exist_ok=True) + # copy weights into cache pages + for rank in range(args.tp_size): + page_block = torch.zeros((8, 18, 128), + dtype=torch.float32, + device='cpu') + copy_to_cache_pages(all_source, + all_config, + page_block, + configs, + tp_rank=rank, + tp_size=args.tp_size) + + out_path = output_dir / f'cache_pages_rank{rank}.npy' + np.save(out_path, page_block) + + source_out_path = output_dir / 'source.npy' + config_out_path = output_dir / 'config.npy' + target_out_path = output_dir / 'target.npy' + + np.save(source_out_path, all_source) + np.save(config_out_path, all_config) + np.save(target_out_path, all_target) if __name__ == "__main__": diff --git a/cpp/tests/resources/scripts/test_cpp.py b/cpp/tests/resources/scripts/test_cpp.py index b4530e19f..88076a6ae 100755 --- a/cpp/tests/resources/scripts/test_cpp.py +++ b/cpp/tests/resources/scripts/test_cpp.py @@ -135,8 +135,17 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None, "--tp-size=2" ] + generate_multi_lora_tp2_args = [ + python_exe, + str(resources_dir / "scripts" / "generate_test_lora_weights.py"), + "--out-dir=cpp/tests/resources/data/multi_lora", + "--tp-size=2", + "--num-loras=128", + ] + run_command(generate_lora_data_args_tp1, cwd=root_dir, timeout=100) run_command(generate_lora_data_args_tp2, cwd=root_dir, timeout=100) + run_command(generate_multi_lora_tp2_args, cwd=root_dir, timeout=100) if not skip_unit_tests: run_unit_tests(build_dir=build_dir) @@ -367,18 +376,19 @@ def run_multi_gpu_tests(build_dir: _pl.Path): "mpirun", "-n", "4", "--allow-run-as-root", "gptSessionTest", "--gtest_filter=*TP4*:*PP4*" ] - run_command(session_test, cwd=tests_dir, env=cpp_env, timeout=300) + run_command(session_test, cwd=tests_dir, env=cpp_env, + timeout=300) # expecting ~250s trt_model_test = [ "mpirun", "-n", "4", "--allow-run-as-root", "batch_manager/trtGptModelRealDecoderTest", "--gtest_filter=*TP*:*PP*" ] - run_command(trt_model_test, cwd=tests_dir, env=cpp_env, timeout=300) + run_command(trt_model_test, cwd=tests_dir, env=cpp_env, + timeout=1500) # expecting ~ 1200s def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, resources_dir: _pl.Path): - scripts_dir = resources_dir / "scripts" make_benchmarks = [ "cmake", "--build", ".", "--config", "Release", "-j", "--target", @@ -395,21 +405,30 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, ] run_command(benchmark, cwd=root_dir, timeout=600) - prompt_flags = [None, 
"--long_prompt"] - prompt_files = ["dummy_cnn.json", "dummy_long_cnn.json"] - token_files = ["prepared_" + s for s in prompt_files] - max_input_lens = ["20", "512"] + prompt_datasets_args = [{ + '--dataset-name': "cnn_dailymail", + '--dataset-config-name': "3.0.0", + '--dataset-split': "validation", + '--dataset-input-key': "article", + '--dataset-prompt': "Summarize the following article:", + '--dataset-output-key': "highlights" + }, { + '--dataset-name': "Open-Orca/1million-gpt-4", + '--dataset-split': "train", + '--dataset-input-key': "question", + '--dataset-prompt-key': "system_prompt", + '--dataset-output-key': "response" + }] + token_files = [ + "prepared_" + s['--dataset-name'].replace('/', '_') + for s in prompt_datasets_args + ] + max_input_lens = ["512", "20"] + num_reqs = ["50", "10"] - for flag, prompt_f, tokens_f, len in zip(prompt_flags, prompt_files, - token_files, max_input_lens): - generate_batch_manager_data = [ - python_exe, - str(scripts_dir / "generate_batch_manager_data.py"), - "--output_filename", prompt_f - ] - if flag is not None: - generate_batch_manager_data.append(flag) - run_command(generate_batch_manager_data, cwd=root_dir, timeout=300) + for prompt_ds_args, tokens_f, len, num_req in zip(prompt_datasets_args, + token_files, + max_input_lens, num_reqs): benchmark_src_dir = _pl.Path("benchmarks") / "cpp" data_dir = resources_dir / "data" @@ -417,9 +436,11 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, python_exe, str(benchmark_src_dir / "prepare_dataset.py"), "--tokenizer", str(resources_dir / "models" / "gpt2"), "--output", - str(data_dir / tokens_f), "dataset", "--dataset", - str(data_dir / prompt_f), "--max-input-len", len + str(data_dir / tokens_f), "dataset", "--max-input-len", len, + "--num-requests", num_req ] + for k, v in prompt_ds_args.items(): + prepare_dataset += [k, v] run_command(prepare_dataset, cwd=root_dir, timeout=300) benchmark = [ @@ -438,14 +459,6 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, ] run_command(benchmark, cwd=root_dir, timeout=600) - benchmark = [ - str(benchmark_exe_dir / "gptManagerBenchmark"), "--engine_dir", - str(gpt_engine_dir / "fp16-plugin-packed-paged" / "tp1-pp1-gpu"), - "--type", "IFB", "--static_emulated_batch_size", "50", "--dataset", - str(data_dir / "prepared_dummy_cnn.json") - ] - run_command(benchmark, cwd=root_dir) - if __name__ == "__main__": _log.basicConfig(level=_log.INFO) diff --git a/cpp/tests/runtime/bufferManagerTest.cpp b/cpp/tests/runtime/bufferManagerTest.cpp index f94e802f9..0d0dde084 100644 --- a/cpp/tests/runtime/bufferManagerTest.cpp +++ b/cpp/tests/runtime/bufferManagerTest.cpp @@ -43,6 +43,16 @@ class BufferManagerTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro void TearDown() override {} + std::size_t memoryPoolReserved() + { + return BufferManager::memoryPoolReserved(mStream->getDevice()); + } + + std::size_t memoryPoolFree() + { + return BufferManager::memoryPoolFree(mStream->getDevice()); + } + int mDeviceCount; BufferManager::CudaStreamPtr mStream; }; @@ -170,3 +180,20 @@ TEST_F(BufferManagerTest, MemPoolAttributes) EXPECT_LE(manager.memoryPoolReserved(), reserved); EXPECT_LE(manager.memoryPoolFree(), free); } + +TEST_F(BufferManagerTest, TrimPoolOnDestruction) +{ + auto manager = std::make_unique(mStream, true); // trim the pool on destruction + + manager->memoryPoolTrimTo(0); + auto const reserved = manager->memoryPoolReserved(); + auto const free = manager->memoryPoolFree(); + auto constexpr kBytesToReserve 
= 1 << 20; + { + auto const mem = manager->allocate(MemoryType::kGPU, kBytesToReserve); + } + EXPECT_GE(manager->memoryPoolFree(), free + kBytesToReserve); + manager.reset(); + EXPECT_LE(memoryPoolReserved(), reserved); + EXPECT_LE(memoryPoolFree(), free); +} diff --git a/cpp/tests/runtime/gptSessionTest.cpp b/cpp/tests/runtime/gptSessionTest.cpp index 98f5d5dbb..65f1d35d5 100644 --- a/cpp/tests/runtime/gptSessionTest.cpp +++ b/cpp/tests/runtime/gptSessionTest.cpp @@ -20,6 +20,7 @@ #include #include "tensorrt_llm/common/memoryUtils.h" +#include "tensorrt_llm/common/mpiUtils.h" #include "tensorrt_llm/common/stlUtils.h" #include "tensorrt_llm/plugins/api/tllmPlugin.h" #include "tensorrt_llm/runtime/gptJsonConfig.h" @@ -29,7 +30,6 @@ #include #include -#include using namespace tensorrt_llm::runtime; @@ -583,7 +583,7 @@ TEST_P(ParamTest, Test) // Warning: This should be the last check before running the test. // It will initialize MPI which can take significant time. - if (!WorldConfig::validConfig(modelSpec.mTPSize, modelSpec.mPPSize)) + if (modelSpec.mTPSize * modelSpec.mPPSize != COMM_SESSION.getSize()) { GTEST_SKIP() << "Model's world size " << modelSpec.mPPSize * modelSpec.mTPSize << " is not equal to the system world size"; diff --git a/cpp/tests/runtime/loraCacheTest.cpp b/cpp/tests/runtime/loraCacheTest.cpp new file mode 100644 index 000000000..0e3a60341 --- /dev/null +++ b/cpp/tests/runtime/loraCacheTest.cpp @@ -0,0 +1,590 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
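The first test in this new `loraCacheTest.cpp` exercises `LoraCachePageManager`: pages are claimed in groups, a claim that exceeds the remaining capacity fails as a whole, and released pages are handed out again. As a rough orientation, here is a minimal free-list model of that bookkeeping, written as a Python sketch; the class and method names are made up for illustration and are not the C++ API.

```
from typing import List, Optional

class PagePool:
    # Toy stand-in for the page bookkeeping exercised by LoraCachePageManagerTest.
    def __init__(self, num_pages: int):
        self._free: List[int] = list(range(num_pages))

    def claim(self, n: int) -> Optional[List[int]]:
        if n > len(self._free):
            return None                      # the whole claim fails, nothing is taken
        pages, self._free = self._free[:n], self._free[n:]
        return pages

    def release(self, pages: List[int]) -> None:
        self._free = pages + self._free      # freed pages become claimable again

pool = PagePool(8)
single = pool.claim(1)        # -> [0]
rest = pool.claim(7)          # -> [1, ..., 7]
assert pool.claim(1) is None  # pool exhausted
pool.release(single)
assert pool.claim(1) == [0]   # the released page is reused, as the test checks
```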
+ */ + +#include "tensorrt_llm/runtime/loraCache.h" +#include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/cudaStream.h" +#include "tensorrt_llm/runtime/iBuffer.h" +#include "tensorrt_llm/runtime/loraCachePageManagerConfig.h" +#include "tensorrt_llm/runtime/loraModule.h" +#include "tensorrt_llm/runtime/loraUtils.h" +#include "tensorrt_llm/runtime/utils/numpyUtils.h" +#include "tensorrt_llm/runtime/worldConfig.h" +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +namespace +{ + +auto const TEST_RESOURCE_PATH = fs::path{TOP_LEVEL_DIR} / "cpp/tests/resources/data"; +auto const TEST_SOURCE_LORA_TP1 = TEST_RESOURCE_PATH / "lora-test-weights-tp1/source.npy"; +auto const TEST_DEST_LORA_TP1 = TEST_RESOURCE_PATH / "lora-test-weights-tp1/target.npy"; +auto const TEST_KEYS_LORA_TP1 = TEST_RESOURCE_PATH / "lora-test-weights-tp1/config.npy"; +auto const TEST_KEYS_LORA_TP1_PAGES_RANK0 = TEST_RESOURCE_PATH / "lora-test-weights-tp1/cache_pages_rank0.npy"; +auto const TEST_SOURCE_LORA_TP2 = TEST_RESOURCE_PATH / "lora-test-weights-tp2/source.npy"; +auto const TEST_DEST_LORA_TP2 = TEST_RESOURCE_PATH / "lora-test-weights-tp2/target.npy"; +auto const TEST_KEYS_LORA_TP2 = TEST_RESOURCE_PATH / "lora-test-weights-tp2/config.npy"; +auto const TEST_KEYS_LORA_TP2_PAGES_RANK0 = TEST_RESOURCE_PATH / "lora-test-weights-tp2/cache_pages_rank0.npy"; +auto const TEST_KEYS_LORA_TP2_PAGES_RANK1 = TEST_RESOURCE_PATH / "lora-test-weights-tp2/cache_pages_rank1.npy"; +} // namespace + +namespace tensorrt_llm::runtime +{ + +using TensorPtr = ITensor::SharedPtr; + +class LoraCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init) +{ +protected: + LoraCacheTest() {} + + void SetUp() override + { + mModelConfig = std::make_unique(0, 2, 1, 16, nvinfer1::DataType::kFLOAT); + mModelConfig->setMlpHiddenSize(32); + mWorldConfig = std::make_unique(2, 1, 0); + std::vector modules{ + LoraModule(LoraModule::ModuleType::kATTN_QKV, 16, 3 * 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_Q, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_K, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_V, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_DENSE, 16, 16, false, true, 1, -1), + LoraModule(LoraModule::ModuleType::kMLP_H_TO_4H, 16, 32, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kMLP_4H_TO_H, 32, 16, false, true, 1, -1), + LoraModule(LoraModule::ModuleType::kMLP_GATE, 16, 32, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_QKV, 16, 3 * 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_Q, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_K, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_V, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_DENSE, 16, 16, false, true, 1, -1), + }; + mModelConfig->setLoraModules(modules); + mStream = std::make_shared(); + mManager = std::make_unique(mStream); + + auto pageConfig = LoraCachePageManagerConfig( + runtime::MemoryType::kCPU, nvinfer1::DataType::kFLOAT, 2 * 8, 6, 64, 4 * 16, 1); + pageConfig.setInitToZero(true); + auto pageConfig2 = pageConfig; + pageConfig2.setInitToZero(true); + pageConfig2.setMemoryType(runtime::MemoryType::kGPU); + mLoraCache = std::make_unique(pageConfig, *mModelConfig, *mWorldConfig, *mManager); + mLoraCache2 = std::make_unique(pageConfig2, 
*mModelConfig, *mWorldConfig, *mManager); + } + + std::shared_ptr mManager; + BufferManager::CudaStreamPtr mStream; + std::unique_ptr mModelConfig; + std::unique_ptr mWorldConfig; + std::unique_ptr mLoraCache; + std::unique_ptr mLoraCache2; +}; + +TEST_F(LoraCacheTest, LoraCachePageManagerTest) +{ + SizeType constexpr maxAdapterSize = 4; + SizeType constexpr maxAdapterWeights = 8; + auto pageShape = ITensor::makeShape({maxAdapterSize, maxAdapterWeights}); + + LoraCachePageManagerConfig config( + runtime::MemoryType::kCPU, nvinfer1::DataType::kFLOAT, 8, 6, maxAdapterSize, maxAdapterWeights, 1); + LoraCachePageManager manager(config, *mManager); + + auto block0 = manager.blockPtr(0); + auto block1 = manager.blockPtr(1); + + auto expectedBlockShape0 = ITensor::makeShape({6, maxAdapterSize, maxAdapterWeights}); + auto expectedBlockShape1 = ITensor::makeShape({2, maxAdapterSize, maxAdapterWeights}); + EXPECT_TRUE(ITensor::shapeEquals(block0->getShape(), expectedBlockShape0)); + EXPECT_TRUE(ITensor::shapeEquals(block1->getShape(), expectedBlockShape1)); + + std::vector expectedPages; + for (SizeType blockIdx = 0; blockIdx < 2; ++blockIdx) + { + auto block = blockIdx == 0 ? block0 : block1; + for (SizeType i = 0; i < (blockIdx == 0 ? 6 : 2); ++i) + { + TensorPtr page = ITensor::slice( + ITensor::wrap(const_cast(block->data()), block->getDataType(), block->getShape()), i, 1); + page->squeeze(0); + expectedPages.push_back(page); + } + } + + // auto [claimed, singlePageId] = manager.claimPages(1); + auto singlePageId = manager.claimPages(1); + ASSERT_TRUE(singlePageId.has_value()); + ASSERT_EQ(singlePageId.value().size(), 1); + auto singlePage = manager.pagePtr(singlePageId.value().at(0)); + EXPECT_EQ(singlePage->data(), expectedPages.at(0)->data()); + + // auto [claimed2, pages] = manager.claimPages(7); + auto pages = manager.claimPages(7); + ASSERT_TRUE(pages.has_value()); + EXPECT_EQ(pages.value().size(), 7); + for (std::size_t i = 1; i < 8; ++i) + { + EXPECT_EQ(manager.pagePtr(pages.value().at(i - 1))->data(), expectedPages.at(i)->data()); + } + + // auto [claimed3, empty1] = manager.claimPages(1); + auto empty1 = manager.claimPages(1); + ASSERT_FALSE(empty1.has_value()); + // EXPECT_EQ(empty1.size(), 0); + + manager.releasePages(std::move(singlePageId.value())); + // auto [claimed4, singlePageId2] = manager.claimPages(1); + auto singlePageId2 = manager.claimPages(1); + + // check that page slots are freed when it's released + ASSERT_TRUE(singlePageId2.has_value()); + EXPECT_EQ(singlePageId2.value().size(), 1); + EXPECT_EQ(manager.pagePtr(singlePageId2.value().at(0))->data(), expectedPages.at(0)->data()); +} + +TEST_F(LoraCacheTest, determineNumPages) +{ + GptModelConfig modelConfig(0, 2, 1, 4, nvinfer1::DataType::kFLOAT); + modelConfig.setLoraModules(LoraModule::createLoraModules({"attn_dense", "attn_qkv"}, 4, 4, 1, 1, 2, 2)); + WorldConfig worldConfig(1, 1, 0); + + LoraCachePageManagerConfig pageConfig(MemoryType::kCPU, nvinfer1::DataType::kFLOAT, 12393, 40, 80, 16, 1); + + LoraCache cache(pageConfig, modelConfig, worldConfig, *mManager); + + std::vector loraConfigVector{ + 0, + 0, + 64, + 0, + 1, + 64, + }; + TensorPtr loraConfig + = ITensor::wrap(loraConfigVector, ITensor::makeShape({static_cast(loraConfigVector.size()) / 3, 3})); + auto numPages = cache.determineNumPages(loraConfig); + EXPECT_EQ(numPages, 2); + + loraConfigVector = std::vector{ + 0, + 0, + 32, + 0, + 0, + 32, + 0, + 0, + 64, + }; + + loraConfig + = ITensor::wrap(loraConfigVector, 
ITensor::makeShape({static_cast(loraConfigVector.size()) / 3, 3})); + numPages = cache.determineNumPages(loraConfig); + EXPECT_EQ(numPages, 2); + + loraConfigVector = std::vector{ + 0, + 0, + 32, + 0, + 0, + 32, + 0, + 0, + 64, + 0, + 0, + 24, + }; + + loraConfig + = ITensor::wrap(loraConfigVector, ITensor::makeShape({static_cast(loraConfigVector.size()) / 3, 3})); + numPages = cache.determineNumPages(loraConfig); + EXPECT_EQ(numPages, 3); + + loraConfigVector = std::vector{ + 0, + 0, + 60, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + }; + + loraConfig + = ITensor::wrap(loraConfigVector, ITensor::makeShape({static_cast(loraConfigVector.size()) / 3, 3})); + numPages = cache.determineNumPages(loraConfig); + EXPECT_EQ(numPages, 2); + + loraConfigVector = std::vector{ + 0, + 0, + 60, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 64, + 0, + 0, + 1, + }; + + loraConfig + = ITensor::wrap(loraConfigVector, ITensor::makeShape({static_cast(loraConfigVector.size()) / 3, 3})); + numPages = cache.determineNumPages(loraConfig); + EXPECT_EQ(numPages, 4); +} + +TEST_F(LoraCacheTest, basicPutGet) +{ + TensorPtr loraReqWeights = utils::loadNpy(*mManager, TEST_SOURCE_LORA_TP2, MemoryType::kCPU); + TensorPtr loraReqKeys = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP2, MemoryType::kCPU); + TensorPtr loraDestWeights = utils::loadNpy(*mManager, TEST_DEST_LORA_TP2, MemoryType::kCPU); + + EXPECT_FALSE(mLoraCache->has(1234)); + mLoraCache->put(1234, loraReqWeights, loraReqKeys); + EXPECT_TRUE(mLoraCache->has(1234)); + EXPECT_TRUE(mLoraCache->isLoaded(1234)); + auto const& values = *mLoraCache->get(1234); + + std::vector expectedValues{{0, 0, 128, 192, 0, 0, 8, 5}, + {0, 5, 128, 192, 0, 1, 8, 5}, {0, 10, 64, 32, 1, 0, 4, 2}, {0, 12, 64, 32, 1, 1, 4, 2}, + {0, 14, 64, 32, 2, 0, 4, 2}, {0, 16, 64, 32, 2, 1, 4, 2}, {0, 18, 64, 32, 3, 0, 4, 2}, + {0, 20, 64, 32, 3, 1, 4, 2}, {0, 22, 64, 128, 4, 0, 8, 3}, {0, 25, 64, 128, 4, 1, 8, 3}, + {0, 28, 128, 128, 5, 0, 8, 4}, {0, 32, 128, 128, 5, 1, 8, 4}, {0, 36, 128, 128, 6, 0, 8, 4}, + {0, 40, 128, 128, 6, 1, 8, 4}, {0, 44, 128, 128, 7, 0, 8, 4}, {0, 48, 128, 128, 7, 1, 8, 4}, + {0, 52, 128, 192, 8, 0, 8, 5}, {0, 57, 128, 192, 8, 1, 8, 5}, {0, 62, 64, 32, 9, 0, 4, 2}, + {1, 0, 64, 32, 9, 1, 4, 2}, {1, 2, 64, 32, 10, 0, 4, 2}, {1, 4, 64, 32, 10, 1, 4, 2}, + {1, 6, 64, 32, 11, 0, 4, 2}, {1, 8, 64, 32, 11, 1, 4, 2}, {1, 10, 64, 128, 12, 0, 8, 3}, + {1, 13, 64, 128, 12, 1, 8, 3}}; + + ASSERT_EQ(expectedValues.size(), values.size()); + for (std::size_t i = 0; i < expectedValues.size(); ++i) + { + EXPECT_EQ(expectedValues.at(i), values.at(i)); + } + + ASSERT_EQ(values.size(), loraDestWeights->getShape().d[0]); + auto const tpSize = mWorldConfig->getTensorParallelism(); + for (size_t i = 0; i < values.size(); ++i) + { + auto const configRowPtr = bufferCast(*ITensor::slice(loraReqKeys, i, 1)); + auto const modId = configRowPtr[lora::kLORA_CONFIG_MODULE_OFF]; + auto const adapterSize = configRowPtr[lora::kLORA_CONFIG_ADAPTER_SIZE_OFF]; + auto modIt = std::find_if(mModelConfig->getLoraModules().begin(), mModelConfig->getLoraModules().end(), + [modId = modId](LoraModule const& m) { return m.value() == modId; }); + auto const inSize = modIt->inSize(adapterSize); + auto const outSize = modIt->outSize(adapterSize); + + float const* weightsInPtr = reinterpret_cast(values[i].weightsInPointer); + float const* weightsOutPtr = reinterpret_cast(values[i].weightsOutPointer); + + TensorPtr row = ITensor::slice(loraDestWeights, i, 1); + auto const rowSize = 
static_cast(ITensor::volume(row->getShape())); + TensorPtr rowFlatView = ITensor::view(row, ITensor::makeShape({rowSize})); + TensorPtr expectedIn = ITensor::slice(rowFlatView, 0, inSize); + TensorPtr expectedOut = ITensor::slice(rowFlatView, inSize, outSize); + auto const expectedInPtr = bufferCast(*expectedIn); + auto const expectedOutPtr = bufferCast(*expectedOut); + for (size_t j = 0; j < values.at(i).inSize; ++j) + { + EXPECT_FLOAT_EQ(weightsInPtr[j], expectedInPtr[j]); + } + for (size_t j = 0; j < values.at(i).outSize; ++j) + { + EXPECT_FLOAT_EQ(weightsOutPtr[j], expectedOutPtr[j]); + } + } + + mLoraCache->copyTask(1234, *mLoraCache2); + + auto const& values2 = *mLoraCache2->get(1234); + ASSERT_EQ(values.size(), values2.size()); + for (size_t i = 0; i < values.size(); ++i) + { + EXPECT_EQ(values.at(i), values2.at(i)); + auto page1 = mLoraCache->getPagePtr(values.at(i).pageId); + auto page2 = mLoraCache2->getPagePtr(values2.at(i).pageId); + auto hostPage2 = mManager->copyFrom(*page2, runtime::MemoryType::kCPU); + ASSERT_TRUE(ITensor::shapeEquals(page1->getShape(), page2->getShape())); + auto const pageSize = page1->getSize(); + auto const p1 = bufferCast(*page1); + auto const p2 = bufferCast(*hostPage2); + for (size_t i = 0; i < static_cast(pageSize); ++i) + { + ASSERT_FLOAT_EQ(p1[i], p2[i]); + } + } +} + +TEST_F(LoraCacheTest, splitTransposeCpu) +{ + auto modelConfig = GptModelConfig(0, 2, 1, 16, nvinfer1::DataType::kFLOAT); + auto worldConfig = WorldConfig(2, 1, 0); + + SizeType const split{2}; + std::vector const input{28524, 287, 5093, 12, 23316, 4881, 11, 30022, 263, 8776, 355, 257}; + std::vector const outputRank0{28524, 5093, 23316, 11, 263, 355}; + std::vector const outputRank1{287, 12, 4881, 30022, 8776, 257}; + std::vector const output2Rank0{28524, 287, 23316, 4881, 263, 8776}; + std::vector const output2Rank1{5093, 12, 11, 30022, 355, 257}; + + { + SizeType const batchSize{6}; + auto const inputLength = static_cast(input.size() / batchSize); + auto const inputShape = ITensor::makeShape({batchSize, inputLength}); + auto const outputShape = ITensor::makeShape({batchSize, inputLength / split}); + + auto inputTensor = mManager->copyFrom(input, inputShape, MemoryType::kCPU); + auto outputTensorRank0 = mManager->cpu(outputShape, nvinfer1::DataType::kINT32); + auto outputTensorRank1 = mManager->cpu(outputShape, nvinfer1::DataType::kINT32); + mManager->setZero(*outputTensorRank0); + mManager->setZero(*outputTensorRank1); + + auto outputPtrRank0 = bufferCast(*outputTensorRank0); + auto outputPtrRank1 = bufferCast(*outputTensorRank1); + + LoraCache::splitTransposeCpu(*outputTensorRank0, *inputTensor, split, 0); + LoraCache::splitTransposeCpu(*outputTensorRank1, *inputTensor, split, 1); + + for (SizeType i = 0; i < static_cast(outputRank0.size()); ++i) + { + EXPECT_EQ(outputPtrRank0[i], outputRank0[i]); + EXPECT_EQ(outputPtrRank1[i], outputRank1[i]); + } + } + + { + SizeType const batchSize{3}; + auto const inputLength = static_cast(input.size() / batchSize); + auto const inputShape = ITensor::makeShape({batchSize, inputLength}); + auto const outputShape = ITensor::makeShape({batchSize, inputLength / split}); + + auto inputTensor = mManager->copyFrom(input, inputShape, MemoryType::kCPU); + auto outputTensorRank0 = mManager->cpu(outputShape, nvinfer1::DataType::kINT32); + auto outputTensorRank1 = mManager->cpu(outputShape, nvinfer1::DataType::kINT32); + mManager->setZero(*outputTensorRank0); + mManager->setZero(*outputTensorRank1); + + LoraCache::splitTransposeCpu(*outputTensorRank0, 
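The two cases in this `splitTransposeCpu` test pin down the operation's semantics: each row of the input is cut into `split` contiguous chunks and rank `r` keeps chunk `r`. The numpy sketch below reproduces the expected vectors from the test; the helper name is hypothetical and chosen only for illustration.

```
import numpy as np

def split_rows(x: np.ndarray, split: int, rank: int) -> np.ndarray:
    # Cut every row into `split` contiguous chunks and keep the chunk for `rank`.
    rows, cols = x.shape
    return x.reshape(rows, split, cols // split)[:, rank, :].copy()

tokens = np.array([28524, 287, 5093, 12, 23316, 4881, 11, 30022, 263, 8776, 355, 257],
                  dtype=np.int32)

# batchSize=6, inputLength=2: rank 0 keeps the first token of each row.
assert split_rows(tokens.reshape(6, 2), 2, 0).ravel().tolist() == [28524, 5093, 23316, 11, 263, 355]
assert split_rows(tokens.reshape(6, 2), 2, 1).ravel().tolist() == [287, 12, 4881, 30022, 8776, 257]

# batchSize=3, inputLength=4: rank 0 keeps the first half of each row.
assert split_rows(tokens.reshape(3, 4), 2, 0).ravel().tolist() == [28524, 287, 23316, 4881, 263, 8776]
assert split_rows(tokens.reshape(3, 4), 2, 1).ravel().tolist() == [5093, 12, 11, 30022, 355, 257]
```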
*inputTensor, split, 0); + LoraCache::splitTransposeCpu(*outputTensorRank1, *inputTensor, split, 1); + + auto outputPtrRank0 = bufferCast(*outputTensorRank0); + auto outputPtrRank1 = bufferCast(*outputTensorRank1); + + for (SizeType i = 0; i < static_cast(outputRank0.size()); ++i) + { + EXPECT_EQ(outputPtrRank0[i], output2Rank0[i]); + EXPECT_EQ(outputPtrRank1[i], output2Rank1[i]); + } + } +} + +TEST_F(LoraCacheTest, copyToPages_tp1) +{ + auto modelConfig = GptModelConfig(0, 2, 1, 16, nvinfer1::DataType::kFLOAT); + modelConfig.setMlpHiddenSize(32); + auto worldConfig = WorldConfig(1, 1, 0); + std::vector modules{ + LoraModule(LoraModule::ModuleType::kATTN_QKV, 16, 3 * 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_Q, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_K, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_V, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_DENSE, 16, 16, false, true, 1, -1), + LoraModule(LoraModule::ModuleType::kMLP_H_TO_4H, 16, 32, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kMLP_4H_TO_H, 32, 16, false, true, 1, -1), + LoraModule(LoraModule::ModuleType::kMLP_GATE, 16, 32, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_QKV, 16, 3 * 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_Q, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_K, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_V, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_DENSE, 16, 16, false, true, 1, -1), + }; + modelConfig.setLoraModules(modules); + std::unordered_map moduleIdToModule; + for (auto const& m : modelConfig.getLoraModules()) + { + moduleIdToModule[m.value()] = m; + } + + TensorPtr loraReqWeights = utils::loadNpy(*mManager, TEST_SOURCE_LORA_TP1, MemoryType::kCPU); + loraReqWeights->unsqueeze(0); + TensorPtr loraReqKeys = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP1, MemoryType::kCPU); + loraReqKeys->unsqueeze(0); + TensorPtr loraTargetTensors = utils::loadNpy(*mManager, TEST_DEST_LORA_TP1, MemoryType::kCPU); + + TensorPtr targetPageBlock = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP1_PAGES_RANK0, MemoryType::kCPU); + TensorPtr pageBlock = mManager->cpu(targetPageBlock->getShape(), targetPageBlock->getDataType()); + mManager->setZero(*pageBlock); + std::vector pages; + for (SizeType p = 0; p < pageBlock->getShape().d[0]; ++p) + { + pages.push_back(ITensor::view(ITensor::slice(pageBlock, p, 1), + ITensor::makeShape({pageBlock->getShape().d[1], pageBlock->getShape().d[2]}))); + } + std::vector pageIds{}; + pageIds.resize(pages.size()); + std::iota(pageIds.begin(), pageIds.end(), 0); + + auto locations = LoraCache::copyToPages( + loraReqWeights, loraReqKeys, modelConfig, worldConfig, moduleIdToModule, *mManager, pages, pageIds); + + auto pagePtr = bufferCast(*pageBlock); + auto targetPtr = bufferCast(*targetPageBlock); + + for (SizeType i = 0; i < pageBlock->getSize(); ++i) + { + EXPECT_FLOAT_EQ(pagePtr[i], targetPtr[i]); + } +} + +TEST_F(LoraCacheTest, copyToPages_tp2_rank0) +{ + auto modelConfig = GptModelConfig(0, 2, 1, 16, nvinfer1::DataType::kFLOAT); + modelConfig.setMlpHiddenSize(32); + auto worldConfig = WorldConfig(2, 1, 0); + std::vector modules{ + LoraModule(LoraModule::ModuleType::kATTN_QKV, 16, 3 * 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_Q, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_K, 16, 
16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_V, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_DENSE, 16, 16, false, true, 1, -1), + LoraModule(LoraModule::ModuleType::kMLP_H_TO_4H, 16, 32, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kMLP_4H_TO_H, 32, 16, false, true, 1, -1), + LoraModule(LoraModule::ModuleType::kMLP_GATE, 16, 32, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_QKV, 16, 3 * 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_Q, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_K, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_V, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_DENSE, 16, 16, false, true, 1, -1), + }; + modelConfig.setLoraModules(modules); + std::unordered_map moduleIdToModule; + for (auto const& m : modelConfig.getLoraModules()) + { + moduleIdToModule[m.value()] = m; + } + + TensorPtr loraReqWeights = utils::loadNpy(*mManager, TEST_SOURCE_LORA_TP2, MemoryType::kCPU); + loraReqWeights->unsqueeze(0); + TensorPtr loraReqKeys = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP2, MemoryType::kCPU); + loraReqKeys->unsqueeze(0); + + TensorPtr targetPageBlock = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP2_PAGES_RANK0, MemoryType::kCPU); + TensorPtr pageBlock = mManager->cpu(targetPageBlock->getShape(), targetPageBlock->getDataType()); + mManager->setZero(*pageBlock); + std::vector pages; + for (SizeType p = 0; p < pageBlock->getShape().d[0]; ++p) + { + pages.push_back(ITensor::view(ITensor::slice(pageBlock, p, 1), + ITensor::makeShape({pageBlock->getShape().d[1], pageBlock->getShape().d[2]}))); + } + std::vector pageIds{}; + pageIds.resize(pages.size()); + std::iota(pageIds.begin(), pageIds.end(), 0); + + auto locations = LoraCache::copyToPages( + loraReqWeights, loraReqKeys, modelConfig, worldConfig, moduleIdToModule, *mManager, pages, pageIds); + + auto pagePtr = bufferCast(*pageBlock); + auto targetPtr = bufferCast(*targetPageBlock); + + for (SizeType i = 0; i < pageBlock->getSize(); ++i) + { + EXPECT_FLOAT_EQ(pagePtr[i], targetPtr[i]); + } +} + +TEST_F(LoraCacheTest, copyToPages_tp2_rank1) +{ + auto modelConfig = GptModelConfig(0, 2, 1, 16, nvinfer1::DataType::kFLOAT); + modelConfig.setMlpHiddenSize(32); + auto worldConfig = WorldConfig(2, 1, 1); + std::vector modules{ + LoraModule(LoraModule::ModuleType::kATTN_QKV, 16, 3 * 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_Q, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_K, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_V, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_DENSE, 16, 16, false, true, 1, -1), + LoraModule(LoraModule::ModuleType::kMLP_H_TO_4H, 16, 32, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kMLP_4H_TO_H, 32, 16, false, true, 1, -1), + LoraModule(LoraModule::ModuleType::kMLP_GATE, 16, 32, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_QKV, 16, 3 * 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_Q, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_K, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_V, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kCROSS_ATTN_DENSE, 16, 16, false, true, 1, -1), + }; + modelConfig.setLoraModules(modules); + std::unordered_map moduleIdToModule; + for 
(auto const& m : modelConfig.getLoraModules()) + { + moduleIdToModule[m.value()] = m; + } + + TensorPtr loraReqWeights = utils::loadNpy(*mManager, TEST_SOURCE_LORA_TP2, MemoryType::kCPU); + loraReqWeights->unsqueeze(0); + TensorPtr loraReqKeys = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP2, MemoryType::kCPU); + loraReqKeys->unsqueeze(0); + + TensorPtr targetPageBlock = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP2_PAGES_RANK1, MemoryType::kCPU); + TensorPtr pageBlock = mManager->cpu(targetPageBlock->getShape(), targetPageBlock->getDataType()); + mManager->setZero(*pageBlock); + std::vector pages; + for (SizeType p = 0; p < pageBlock->getShape().d[0]; ++p) + { + pages.push_back(ITensor::view(ITensor::slice(pageBlock, p, 1), + ITensor::makeShape({pageBlock->getShape().d[1], pageBlock->getShape().d[2]}))); + } + std::vector pageIds{}; + pageIds.resize(pages.size()); + std::iota(pageIds.begin(), pageIds.end(), 0); + + auto locations = LoraCache::copyToPages( + loraReqWeights, loraReqKeys, modelConfig, worldConfig, moduleIdToModule, *mManager, pages, pageIds); + + auto pagePtr = bufferCast(*pageBlock); + auto targetPtr = bufferCast(*targetPageBlock); + + for (SizeType i = 0; i < pageBlock->getSize(); ++i) + { + EXPECT_FLOAT_EQ(pagePtr[i], targetPtr[i]); + } +} +} // namespace tensorrt_llm::runtime diff --git a/cpp/tests/runtime/loraManagerTest.cpp b/cpp/tests/runtime/loraManagerTest.cpp index 1176f3a31..a57be7b53 100644 --- a/cpp/tests/runtime/loraManagerTest.cpp +++ b/cpp/tests/runtime/loraManagerTest.cpp @@ -14,7 +14,6 @@ * limitations under the License. */ -#include #include #include "tensorrt_llm/common/memoryUtils.h" @@ -25,8 +24,10 @@ #include "tensorrt_llm/runtime/gptModelConfig.h" #include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/iTensor.h" +#include "tensorrt_llm/runtime/loraCache.h" #include "tensorrt_llm/runtime/loraManager.h" #include "tensorrt_llm/runtime/loraModule.h" +#include "tensorrt_llm/runtime/loraUtils.h" #include "tensorrt_llm/runtime/worldConfig.h" #include "tensorrt_llm/runtime/utils/numpyUtils.h" @@ -52,6 +53,7 @@ auto const TEST_MODEL_CONFIG = TEST_RESOURCE_PATH / "test_model_lora_config.json namespace tensorrt_llm::runtime { using TensorPtr = ITensor::SharedPtr; +using PeftTable = LoraManager::PeftTable; class LoraManagerTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init) { @@ -75,6 +77,36 @@ class LoraManagerTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-t BufferManager::CudaStreamPtr mStream; GptModelConfig mModelConfig; WorldConfig mWorldConfig; + + PeftTable getPeftTable(SizeType tpRank = 0) + { + auto modelConfig = GptModelConfig(0, 2, 1, 16, nvinfer1::DataType::kFLOAT); + modelConfig.setMlpHiddenSize(32); + auto worldConfig = WorldConfig(2, 2, 3); + std::vector modules{ + LoraModule(LoraModule::ModuleType::kATTN_QKV, 16, 3 * 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_Q, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_K, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_V, 16, 16, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kATTN_DENSE, 16, 16, false, true, 1, -1), + LoraModule(LoraModule::ModuleType::kMLP_H_TO_4H, 16, 32, false, true, -1, 0), + LoraModule(LoraModule::ModuleType::kMLP_4H_TO_H, 32, 16, false, true, 1, -1), + LoraModule(LoraModule::ModuleType::kMLP_GATE, 16, 32, false, true, -1, 0), + }; + modelConfig.setLoraModules(modules); + auto pageConfig = LoraCachePageManagerConfig( + runtime::MemoryType::kCPU, 
nvinfer1::DataType::kFLOAT, 2 * 8, 6, 64, 4 * 16, 1); + pageConfig.setInitToZero(true); + LoraCache loraCache(pageConfig, modelConfig, worldConfig, *mManager); + + TensorPtr loraReqWeights = utils::loadNpy(*mManager, TEST_SOURCE_LORA_TP2, MemoryType::kCPU); + TensorPtr loraReqKeys = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP2, MemoryType::kCPU); + + loraCache.put(1234, loraReqWeights, loraReqKeys); + PeftTable peftTable{}; + peftTable.try_emplace(1234, loraCache.get(1234)); + return peftTable; + } }; TEST_F(LoraManagerTest, moduleParsing) @@ -109,170 +141,10 @@ TEST_F(LoraManagerTest, moduleParsing) } } -TEST_F(LoraManagerTest, formatTensors_tp1) -{ - LoraManager loraManager; - auto modelConfig = GptModelConfig(0, 2, 1, 16, nvinfer1::DataType::kFLOAT); - modelConfig.setMlpHiddenSize(32); - auto worldConfig = WorldConfig(1, 1, 0); - std::vector modules{ - LoraModule(LoraModule::ModuleType::kATTN_QKV, 16, 3 * 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_Q, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_K, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_V, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_DENSE, 16, 16, false, true, 1, -1), - LoraModule(LoraModule::ModuleType::kMLP_H_TO_4H, 16, 32, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kMLP_4H_TO_H, 32, 16, false, true, 1, -1), - LoraModule(LoraModule::ModuleType::kMLP_GATE, 16, 32, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_QKV, 16, 3 * 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_Q, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_K, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_V, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_DENSE, 16, 16, false, true, 1, -1), - }; - modelConfig.setLoraModules(modules); - loraManager.create(modelConfig, worldConfig, *mManager); - - TensorPtr loraReqWeights = utils::loadNpy(*mManager, TEST_SOURCE_LORA_TP1.string(), MemoryType::kGPU); - loraReqWeights->unsqueeze(0); - TensorPtr loraReqKeys = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP1.string(), MemoryType::kCPU); - loraReqKeys->unsqueeze(0); - TensorPtr loraTargetTensors = utils::loadNpy(*mManager, TEST_DEST_LORA_TP1.string(), MemoryType::kCPU); - - loraManager.formatTaskTensors(loraReqWeights, loraReqKeys, modelConfig, worldConfig, *mManager); - TensorPtr hostWeights = mManager->copyFrom(*loraReqWeights, MemoryType::kCPU); - mManager->getStream().synchronize(); - - auto srcPtr = bufferCast(*hostWeights); - auto destPtr = bufferCast(*loraTargetTensors); - - for (SizeType i = 0; i < loraReqWeights->getSize(); ++i) - { - EXPECT_FLOAT_EQ(srcPtr[i], destPtr[i]); - } -} - -TEST_F(LoraManagerTest, formatTensors_tp2) -{ - LoraManager loraManager; - auto modelConfig = GptModelConfig(0, 2, 1, 16, nvinfer1::DataType::kFLOAT); - modelConfig.setMlpHiddenSize(32); - auto worldConfig = WorldConfig(2, 1, 0); - std::vector modules{ - LoraModule(LoraModule::ModuleType::kATTN_QKV, 16, 3 * 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_Q, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_K, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_V, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_DENSE, 16, 16, false, true, 1, -1), - LoraModule(LoraModule::ModuleType::kMLP_H_TO_4H, 16, 32, false, true, -1, 0), - 
LoraModule(LoraModule::ModuleType::kMLP_4H_TO_H, 32, 16, false, true, 1, -1), - LoraModule(LoraModule::ModuleType::kMLP_GATE, 16, 32, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_QKV, 16, 3 * 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_Q, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_K, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_V, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_DENSE, 16, 16, false, true, 1, -1), - }; - modelConfig.setLoraModules(modules); - loraManager.create(modelConfig, worldConfig, *mManager); - - TensorPtr loraReqWeights = utils::loadNpy(*mManager, TEST_SOURCE_LORA_TP2.string(), MemoryType::kGPU); - loraReqWeights->unsqueeze(0); - TensorPtr loraReqKeys = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP2.string(), MemoryType::kCPU); - loraReqKeys->unsqueeze(0); - TensorPtr loraTargetTensors = utils::loadNpy(*mManager, TEST_DEST_LORA_TP2.string(), MemoryType::kCPU); - - loraManager.formatTaskTensors(loraReqWeights, loraReqKeys, modelConfig, worldConfig, *mManager); - TensorPtr hostWeights = mManager->copyFrom(*loraReqWeights, MemoryType::kCPU); - mManager->getStream().synchronize(); - - auto srcPtr = bufferCast(*hostWeights); - auto destPtr = bufferCast(*loraTargetTensors); - - for (SizeType i = 0; i < loraReqWeights->getSize(); ++i) - { - EXPECT_FLOAT_EQ(srcPtr[i], destPtr[i]); - } -} - -TEST_F(LoraManagerTest, LoraManager_addTask) -{ - LoraManager manager; - manager.create(mModelConfig, mWorldConfig, *mManager); - - std::vector taskNLayers{4, 6}; - std::vector taskMod{0, 1}; - std::vector taskSizes{16, 8}; - - for (SizeType taskNum = 0; taskNum < static_cast(taskSizes.size()); ++taskNum) - { - - auto mod = taskMod[taskNum]; - auto nLayers = taskNLayers[taskNum]; - auto taskSize = taskSizes[taskNum]; - auto taskName = taskNum; - // bs=1 - // nbModules=1 - // nbLayers=4 - // adapterSize=16 - // Hi=128 - // Ho=3*128 - auto weightsShape = ITensor::makeShape({1, 1 * nLayers, taskSize * 128 + taskSize * 3 * 128}); - auto weights = mManager->cpu(weightsShape, nvinfer1::DataType::kFLOAT); - auto weightsPtr = bufferCast(*weights); - std::fill_n(weightsPtr, weights->getSize(), 1.f * taskNum); - - auto keysShape = ITensor::makeShape({1, 1 * nLayers, 3}); - - auto keys = mManager->cpu(keysShape, nvinfer1::DataType::kINT32); - auto keysPtr = bufferCast(*keys); - SizeType off = 0; - for (SizeType i = 0; i < nLayers; ++i) - { - keysPtr[off++] = mod; - keysPtr[off++] = i; - keysPtr[off++] = taskSize; - } - - weights->squeeze(0); - keys->squeeze(0); - - manager.addTask(taskName, std::move(weights), std::move(keys)); - } - - for (SizeType taskNum = 0; taskNum < static_cast(taskSizes.size()); ++taskNum) - { - auto mod = taskMod[taskNum]; - auto nLayers = taskNLayers[taskNum]; - auto taskSize = taskSizes[taskNum]; - auto taskName = taskNum; - auto modName = taskNum == 0 ? 
"attn_qkv" : "attn_q"; - - auto [taskWeights, taskKeys] = manager.getTask(taskName); - auto taskKeysPtr = bufferCast(*taskKeys); - - auto numWeights = static_cast(taskWeights->getSize()); - auto hostWeightsPtr = bufferCast(*taskWeights); - - for (SizeType i = 0; i < numWeights; ++i) - { - EXPECT_FLOAT_EQ(1.f * taskNum, hostWeightsPtr[i]); - } - - SizeType off = 0; - for (SizeType i = 0; i < taskNLayers[taskNum]; ++i) - { - EXPECT_EQ(taskKeysPtr[off++], taskMod[taskNum]); - EXPECT_EQ(taskKeysPtr[off++], i); - EXPECT_EQ(taskKeysPtr[off++], taskSizes[taskNum]); - } - } -} - static void checkLoraTensors(LoraManager const& loraManager, std::vector const& targetPtrs, TensorPtr weightsPtrs, std::vector const& targetAdapterSizes, TensorPtr adapterSizes, GptModelConfig const& modelConfig, WorldConfig const& worldConfig, std::vector const& modules, - SizeType numModules, SizeType numLayers, SizeType numSeqs) + SizeType numModules, SizeType numLayers, SizeType numSeqs, bool checkPointers = true) { auto adapterSizesPtr = bufferCast(*adapterSizes); auto weightsPtrsPtr = bufferCast(*weightsPtrs); @@ -306,8 +178,17 @@ static void checkLoraTensors(LoraManager const& loraManager, std::vector, std::vector, PeftTable> createFillInputTensorsTestsData( + std::vector const& configs, std::vector const& reqIds, + std::vector const& reqBeamWidth, std::vector const& modules, SizeType numLayers, + SizeType numSeq, std::vector& valuesWorkspace) { - LoraManager loraManager; - auto modelConfig = GptModelConfig(0, 2, 1, 16, nvinfer1::DataType::kFLOAT); - modelConfig.setMlpHiddenSize(32); - auto worldConfig = WorldConfig(1, 1, 0); - std::vector modules{ - LoraModule(LoraModule::ModuleType::kATTN_QKV, 16, 3 * 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_Q, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_K, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_V, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kATTN_DENSE, 16, 16, false, true, 1, -1), - LoraModule(LoraModule::ModuleType::kMLP_H_TO_4H, 16, 32, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kMLP_GATE, 16, 32, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kMLP_4H_TO_H, 32, 16, false, true, 1, -1), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_QKV, 16, 3 * 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_Q, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_K, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_V, 16, 16, false, true, -1, 0), - LoraModule(LoraModule::ModuleType::kCROSS_ATTN_DENSE, 16, 16, false, true, 1, -1), - }; - modelConfig.setLoraModules(modules); - loraManager.create(modelConfig, worldConfig, *mManager); - auto numModules = static_cast(modelConfig.getLoraModules().size()); - auto numLayers = static_cast(modelConfig.getNbLayers()); - SizeType numSeqs = 4; - TensorPtr weightsPtrs - = mManager->cpu(ITensor::makeShape({numModules, numLayers, numSeqs, 2}), nvinfer1::DataType::kINT64); - TensorPtr adapterSizes - = mManager->cpu(ITensor::makeShape({numModules, numLayers, numSeqs}), nvinfer1::DataType::kINT32); - - mManager->setZero(*weightsPtrs); - mManager->setZero(*adapterSizes); - - SizeType numContextRequests = 1; - std::vector reqIds{1, 2, 3}; - std::vector reqBeamWidth{1, 2, 1}; - std::vector loraEnabled{true, true, false}; - - TensorPtr loraReqKeys = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP1.string(), MemoryType::kCPU); - TensorPtr loraWeights = 
utils::loadNpy(*mManager, TEST_DEST_LORA_TP1.string(), MemoryType::kGPU); + std::map moduleOffset; + SizeType modOff = 0; + for (auto const& m : modules) + { + moduleOffset[m.value()] = modOff++; + } - loraManager.addTask(1, loraWeights, loraReqKeys); - loraManager.addTask(2, loraWeights, loraReqKeys); + SizeType batchSize = configs.size(); + SizeType numModules = modules.size(); - loraManager.fillInputTensors( - weightsPtrs, adapterSizes, reqIds, reqBeamWidth, loraEnabled, numContextRequests, modelConfig, worldConfig); + std::vector targetAdapterSizes(numModules * numLayers * numSeq, 0); + std::vector targetPointers(numModules * numLayers * numSeq * 2, 0); - // set in order litest in modelConfig - SizeType attnQkvOff = 1; - SizeType attnDense = 0; + PeftTable peftTable{}; - auto inputWeightsPtrs = bufferCast(*loraWeights); + int64_t pointerAddr = 777001; - auto adapterSizesPtr = bufferCast(*adapterSizes); - auto weightsPtrsPtr = bufferCast(*weightsPtrs); - - auto weightsRowSize = loraWeights->getShape().d[1]; - - std::vector targetAdapterSizes{ - 8, 8, 8, 0, // attn_qkv layer 0 - 8, 8, 8, 0, // attn_qkv layer 1 - 4, 4, 4, 0, // attn_q layer 0 - 4, 4, 4, 0, // attn_q layer 1 - 4, 4, 4, 0, // attn_k layer 0 - 4, 4, 4, 0, // attn_k layer 1 - 4, 4, 4, 0, // attn_v layer 0 - 4, 4, 4, 0, // attn_v layer 1 - 8, 8, 8, 0, // attn_dense layer 0 - 8, 8, 8, 0, // attn_dense layer 1 - 8, 8, 8, 0, // mlp_h_to_4h layer 0 - 8, 8, 8, 0, // mlp_h_to_4h layer 1 - 8, 8, 8, 0, // mlp_gate layer 0 - 8, 8, 8, 0, // mlp_gate layer 1 - 8, 8, 8, 0, // mlp_4h_to_h layer 0 - 8, 8, 8, 0, // mlp_4h_to_h layer 1 - 8, 8, 8, 0, // cross_attn_qkv layer 0 - 8, 8, 8, 0, // cross_attn_qkv layer 1 - 4, 4, 4, 0, // cross_attn_q layer 0 - 4, 4, 4, 0, // cross_attn_q layer 1 - 4, 4, 4, 0, // cross_attn_k layer 0 - 4, 4, 4, 0, // cross_attn_k layer 1 - 4, 4, 4, 0, // cross_attn_v layer 0 - 4, 4, 4, 0, // cross_attn_v layer 1 - 8, 8, 8, 0, // cross_attn_dense layer 0 - 8, 8, 8, 0, // cross_attn_dense layer 1 - }; - - std::vector targetPtrs{ - // attn_qkv layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(0, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(0, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(0, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(0, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(0, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(0, 8 * 16, weightsRowSize)), - 0, - 0, - - // attn_qkv layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, 8 * 16, weightsRowSize)), - 0, - 0, - - // attn_q layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(2, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(2, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(2, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(2, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(2, 0, 
weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(2, 4 * 16, weightsRowSize)), - 0, - 0, - - // attn_q layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, 4 * 16, weightsRowSize)), - 0, - 0, - - // attn_k layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(4, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(4, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(4, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(4, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(4, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(4, 4 * 16, weightsRowSize)), - 0, - 0, - - // attn_k layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, 4 * 16, weightsRowSize)), - 0, - 0, - - // attn_v layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(6, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(6, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(6, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(6, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(6, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(6, 4 * 16, weightsRowSize)), - 0, - 0, - - // attn_v layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, 4 * 16, weightsRowSize)), - 0, - 0, - - // attn_dense layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(8, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(8, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(8, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(8, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(8, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(8, 8 * 16, weightsRowSize)), - 0, - 0, - - // attn_dense layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 8 * 16, weightsRowSize)), - 
reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 8 * 16, weightsRowSize)), - 0, - 0, - - // mlp_h_to_4h layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(10, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(10, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(10, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(10, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(10, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(10, 8 * 16, weightsRowSize)), - 0, - 0, - - // mlp_h_to_4h layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, 8 * 16, weightsRowSize)), - 0, - 0, - - // mlp_gate layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(14, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(14, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(14, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(14, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(14, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(14, 8 * 16, weightsRowSize)), - 0, - 0, - - // mlp_gate layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, 8 * 16, weightsRowSize)), - 0, - 0, - - // mlp_4h_to_h layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(12, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(12, 8 * 32, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(12, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(12, 8 * 32, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(12, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(12, 8 * 32, weightsRowSize)), - 0, - 0, - - // mlp_4h_to_h layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, 8 * 32, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, 8 * 32, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, 0, weightsRowSize)), - 
reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, 8 * 32, weightsRowSize)), - 0, - 0, - - // cross_attn_qkv layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(16, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(16, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(16, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(16, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(16, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(16, 8 * 16, weightsRowSize)), - 0, - 0, - - // cross_attn_qkv layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, 8 * 16, weightsRowSize)), - 0, - 0, - - // cross_attn_q layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(18, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(18, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(18, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(18, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(18, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(18, 4 * 16, weightsRowSize)), - 0, - 0, - - // cross_attn_q layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, 4 * 16, weightsRowSize)), - 0, - 0, - - // cross_attn_k layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(20, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(20, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(20, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(20, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(20, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(20, 4 * 16, weightsRowSize)), - 0, - 0, - - // cross_attn_k layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, 4 * 16, weightsRowSize)), - 0, - 0, - - // cross_attn_v layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(22, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + 
common::flat_index2(22, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(22, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(22, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(22, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(22, 4 * 16, weightsRowSize)), - 0, - 0, - - // cross_attn_v layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, 4 * 16, weightsRowSize)), - 0, - 0, - - // cross_attn_dense layer 0 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(24, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(24, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(24, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(24, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(24, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(24, 8 * 16, weightsRowSize)), - 0, - 0, - - // cross_attn_dense layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 0, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 8 * 16, weightsRowSize)), - 0, - 0, - }; + for (size_t bid = 0; bid < configs.size(); ++bid) + { + valuesWorkspace.push_back(std::make_shared>()); + auto beamWidth = reqBeamWidth[bid]; + auto config = configs[bid]; + if (config == nullptr) + { + continue; + } + peftTable.try_emplace(reqIds[bid], valuesWorkspace[bid]); + if (config->getShape().nbDims == 3) + { + config->squeeze(0); + } + SizeType numRows = config->getShape().d[0]; + for (SizeType r = 0; r < numRows; ++r) + { + auto const* row = bufferCast(*ITensor::slice(config, r, 1)); + auto moduleId = row[lora::kLORA_CONFIG_MODULE_OFF]; + auto layerId = row[lora::kLORA_CONFIG_LAYER_OFF]; + auto adapterSize = row[lora::kLORA_CONFIG_ADAPTER_SIZE_OFF]; + auto modOff = moduleOffset.at(moduleId); + + auto inPointer = pointerAddr++; + auto outPointer = pointerAddr++; + valuesWorkspace[bid]->push_back( + LoraCache::TaskLayerModuleConfig{0, 0, 0, 0, moduleId, layerId, adapterSize, 0, inPointer, outPointer}); + + for (SizeType beamIdx = 0; beamIdx < beamWidth; ++beamIdx) + { + targetAdapterSizes[common::flat_index3(modOff, layerId, bid + beamIdx, numLayers, numSeq)] + = adapterSize; + targetPointers[common::flat_index4(modOff, layerId, bid + beamIdx, 0, numLayers, numSeq, 2)] + = inPointer; + targetPointers[common::flat_index4(modOff, layerId, bid + beamIdx, 1, numLayers, numSeq, 2)] + = outPointer; + } + } + } - checkLoraTensors(loraManager, targetPtrs, weightsPtrs, targetAdapterSizes, adapterSizes, modelConfig, worldConfig, - modules, numModules, numLayers, 
numSeqs); + return std::make_tuple(targetAdapterSizes, targetPointers, peftTable); } -TEST_F(LoraManagerTest, fillInputTensors_tp2_pp2) +TEST_F(LoraManagerTest, fillInputTensors) { LoraManager loraManager; auto modelConfig = GptModelConfig(0, 2, 1, 16, nvinfer1::DataType::kFLOAT); modelConfig.setMlpHiddenSize(32); - auto worldConfig = WorldConfig(2, 2, 3); // tpRank = 1, ppRank = 1 + auto worldConfig = WorldConfig(1, 1, 0); std::vector modules{ LoraModule(LoraModule::ModuleType::kATTN_QKV, 16, 3 * 16, false, true, -1, 0), LoraModule(LoraModule::ModuleType::kATTN_Q, 16, 16, false, true, -1, 0), @@ -721,7 +313,7 @@ TEST_F(LoraManagerTest, fillInputTensors_tp2_pp2) modelConfig.setLoraModules(modules); loraManager.create(modelConfig, worldConfig, *mManager); auto numModules = static_cast(modelConfig.getLoraModules().size()); - auto numLayers = static_cast(modelConfig.getNbLayers(2)); + auto numLayers = static_cast(modelConfig.getNbLayers()); SizeType numSeqs = 4; TensorPtr weightsPtrs = mManager->cpu(ITensor::makeShape({numModules, numLayers, numSeqs, 2}), nvinfer1::DataType::kINT64); @@ -734,189 +326,20 @@ TEST_F(LoraManagerTest, fillInputTensors_tp2_pp2) SizeType numContextRequests = 1; std::vector reqIds{1, 2, 3}; std::vector reqBeamWidth{1, 2, 1}; - std::vector loraEnabled{true, true, false}; - - TensorPtr loraReqKeys = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP2.string(), MemoryType::kCPU); - TensorPtr loraWeights = utils::loadNpy(*mManager, TEST_DEST_LORA_TP2.string(), MemoryType::kGPU); - - loraManager.addTask(1, loraWeights, loraReqKeys); - loraManager.addTask(2, loraWeights, loraReqKeys); - loraManager.fillInputTensors( - weightsPtrs, adapterSizes, reqIds, reqBeamWidth, loraEnabled, numContextRequests, modelConfig, worldConfig); + TensorPtr loraReqKeys = utils::loadNpy(*mManager, TEST_KEYS_LORA_TP1, MemoryType::kCPU); + std::vector loraConfigs{loraReqKeys, loraReqKeys, nullptr}; - // set in order litest in modelConfig - SizeType attnQkvOff = 1; - SizeType attnDense = 0; + std::vector valuesWorkspace; + auto [targetadapterSizes, targetPointers, peftTable] = createFillInputTensorsTestsData( + loraConfigs, reqIds, reqBeamWidth, modules, numLayers, numSeqs, valuesWorkspace); - auto inputWeightsPtrs = bufferCast(*loraWeights); + loraManager.fillInputTensors(weightsPtrs, adapterSizes, peftTable, reqIds, reqBeamWidth, modelConfig, worldConfig); auto adapterSizesPtr = bufferCast(*adapterSizes); auto weightsPtrsPtr = bufferCast(*weightsPtrs); - auto weightsRowSize = loraWeights->getShape().d[1]; - - std::vector targetAdapterSizes{ - 8, 8, 8, 0, // attn_qkv layer 1 - 4, 4, 4, 0, // attn_q layer 1 - 4, 4, 4, 0, // attn_k layer 1 - 4, 4, 4, 0, // attn_v layer 1 - 8, 8, 8, 0, // attn_dense layer 1 - 8, 8, 8, 0, // mlp_h_to_4h layer 1 - 8, 8, 8, 0, // mlp_gate layer 1 - 8, 8, 8, 0, // mlp_4h_to_h layer 1 - 8, 8, 8, 0, // cross_attn_qkv layer 1 - 4, 4, 4, 0, // cross_attn_q layer 1 - 4, 4, 4, 0, // cross_attn_k layer 1 - 4, 4, 4, 0, // cross_attn_v layer 1 - 8, 8, 8, 0, // cross_attn_dense layer 1 - }; - - SizeType attnQkvInRank1Off = 0; - SizeType attnQkvOutRank1Off = (8 * 16) + (4 * (3 * 16)); - - SizeType attnQInRank1Off = 0; - SizeType attnQOutRank1Off = (4 * 16) + (2 * 16); - - SizeType mlphto4hInRank1Off = 0; - SizeType mlphto4hOutRank1Off = (8 * 16) + (4 * 32); - - SizeType mlp4htohInRank1Off = (4 * 32); - SizeType mlp4htohOutRank1Off = (8 * 32); - - std::vector targetPtrs{ - // attn_qkv layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, attnQkvInRank1Off, weightsRowSize)), 
- reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, attnQkvOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, attnQkvInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, attnQkvOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, attnQkvInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(1, attnQkvOutRank1Off, weightsRowSize)), - 0, - 0, - - // attn_q layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, attnQOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, attnQOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(3, attnQOutRank1Off, weightsRowSize)), - 0, - 0, - - // attn_k layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, attnQOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, attnQOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(5, attnQOutRank1Off, weightsRowSize)), - 0, - 0, - - // attn_v layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, attnQOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, attnQOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(7, attnQOutRank1Off, weightsRowSize)), - 0, - 0, - - // attn_dense layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(9, 8 * 16, weightsRowSize)), - 0, - 0, - - // mlp_h_to_4h layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, mlphto4hInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, mlphto4hOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, mlphto4hInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, mlphto4hOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, mlphto4hInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(11, mlphto4hOutRank1Off, weightsRowSize)), - 0, - 0, - - // mlp_gate 
layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, mlphto4hInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, mlphto4hOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, mlphto4hInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, mlphto4hOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, mlphto4hInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(15, mlphto4hOutRank1Off, weightsRowSize)), - 0, - 0, - - // mlp_4h_to_h layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, mlp4htohInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, mlp4htohOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, mlp4htohInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, mlp4htohOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, mlp4htohInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(13, mlp4htohOutRank1Off, weightsRowSize)), - 0, - 0, - - // cross_attn_qkv layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, attnQkvInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, attnQkvOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, attnQkvInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, attnQkvOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, attnQkvInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(17, attnQkvOutRank1Off, weightsRowSize)), - 0, - 0, - - // cross_attn_q layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, attnQOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, attnQOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(19, attnQOutRank1Off, weightsRowSize)), - 0, - 0, - - // cross_attn_k layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, attnQOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, attnQOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(21, attnQOutRank1Off, weightsRowSize)), - 0, - 0, - - // cross_attn_v layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, attnQOutRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, attnQOutRank1Off, 
weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, attnQInRank1Off, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(23, attnQOutRank1Off, weightsRowSize)), - 0, - 0, - - // cross_attn_dense layer 1 - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 8 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 4 * 16, weightsRowSize)), - reinterpret_cast(inputWeightsPtrs + common::flat_index2(25, 8 * 16, weightsRowSize)), - 0, - 0, - }; - - checkLoraTensors(loraManager, targetPtrs, weightsPtrs, targetAdapterSizes, adapterSizes, modelConfig, worldConfig, - modules, numModules, numLayers, numSeqs); + checkLoraTensors(loraManager, targetPointers, weightsPtrs, targetadapterSizes, adapterSizes, modelConfig, + worldConfig, modules, numModules, numLayers, numSeqs); } } // namespace tensorrt_llm::runtime diff --git a/cpp/tests/runtime/loraUtilsTest.cpp b/cpp/tests/runtime/loraUtilsTest.cpp index 39d192127..d64fb9e0f 100644 --- a/cpp/tests/runtime/loraUtilsTest.cpp +++ b/cpp/tests/runtime/loraUtilsTest.cpp @@ -105,7 +105,8 @@ TEST_F(LoraUtilsTest, loraValidateRequestTensors) auto configPtr = bufferCast(*optReqLoraConfig.value()); std::copy_n(config.data(), config.size(), configPtr); - EXPECT_THAT([&]() { loraValidateRequestTensors(optReqLoraWeights, optReqLoraConfig, modelConfig, worldConfig); }, + EXPECT_THAT([&]() + { loraValidateRequestTensors(12345, optReqLoraWeights, optReqLoraConfig, modelConfig, worldConfig); }, testing::Throws()); std::vector modules{ @@ -113,7 +114,11 @@ TEST_F(LoraUtilsTest, loraValidateRequestTensors) }; modelConfig.setLoraModules(modules); - loraValidateRequestTensors(optReqLoraWeights, optReqLoraConfig, modelConfig, worldConfig); + loraValidateRequestTensors(12345, optReqLoraWeights, optReqLoraConfig, modelConfig, worldConfig); + + EXPECT_THAT([&]() + { loraValidateRequestTensors(std::nullopt, optReqLoraWeights, optReqLoraConfig, modelConfig, worldConfig); }, + testing::Throws()); } } // namespace tensorrt_llm::runtime::lora diff --git a/cpp/tests/runtime/workerPoolTest.cpp b/cpp/tests/runtime/workerPoolTest.cpp new file mode 100644 index 000000000..967c69361 --- /dev/null +++ b/cpp/tests/runtime/workerPoolTest.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/runtime/workerPool.h" + +#include + +namespace tensorrt_llm::runtime +{ + +TEST(WorkerPool, basic) +{ + WorkerPool pool(2); + + auto fn = []() { return 12345; }; + auto resultFuture = pool.enqueue, int>(std::move(fn)); + + auto fn2 = []() { return 12.345f; }; + auto f2 = pool.enqueue, float>(std::move(fn2)); + + auto fn3 = []() { return 40.78f; }; + auto f3 = pool.enqueue, float>(std::move(fn3)); + + auto r1 = resultFuture.get(); + auto r2 = f2.get(); + auto r3 = f3.get(); + + EXPECT_EQ(12345, r1); + EXPECT_FLOAT_EQ(12.345f, r2); + EXPECT_FLOAT_EQ(40.78f, r3); +} + +TEST(WorkerPool, voidReturn) +{ + WorkerPool pool(2); + + int32_t returnVal1 = 0; + int32_t returnVal2 = 0; + int32_t returnVal3 = 0; + + auto fn1 = [&returnVal1]() { returnVal1 = 10001; }; + auto f1 = pool.enqueue(fn1); + + auto fn2 = [&returnVal2]() { returnVal2 = 10002; }; + auto f2 = pool.enqueue(fn2); + + auto fn3 = [&returnVal3]() { returnVal3 = 10003; }; + auto f3 = pool.enqueue(fn3); + + f1.get(); + f2.get(); + f3.get(); + + EXPECT_EQ(returnVal1, 10001); + EXPECT_EQ(returnVal2, 10002); + EXPECT_EQ(returnVal3, 10003); +} +} // namespace tensorrt_llm::runtime diff --git a/docker/Dockerfile.multi b/docker/Dockerfile.multi index ba1ed4328..7f256b79b 100644 --- a/docker/Dockerfile.multi +++ b/docker/Dockerfile.multi @@ -1,6 +1,6 @@ # Multi-stage Dockerfile ARG BASE_IMAGE=nvcr.io/nvidia/pytorch -ARG BASE_TAG=24.01-py3 +ARG BASE_TAG=24.02-py3 ARG DEVEL_IMAGE=devel FROM ${BASE_IMAGE}:${BASE_TAG} as base diff --git a/docker/Makefile b/docker/Makefile index 237e16b31..e219a9df2 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -145,12 +145,12 @@ jenkins-aarch64_%: STAGE = devel centos7_%: IMAGE_WITH_TAG = $(shell grep 'LLM_CENTOS7_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"') centos7_%: STAGE = devel centos7_%: BASE_IMAGE = nvidia/cuda -centos7_%: BASE_TAG = 12.3.1-devel-centos7 +centos7_%: BASE_TAG = 12.3.2-devel-centos7 # For x86_64 and aarch64 ubuntu22_%: STAGE = devel ubuntu22_%: BASE_IMAGE = nvidia/cuda -ubuntu22_%: BASE_TAG = 12.3.1-devel-ubuntu22.04 +ubuntu22_%: BASE_TAG = 12.3.2-devel-ubuntu22.04 # For x86_64 old-cuda_%: IMAGE_WITH_TAG = $(shell grep 'LLM_OLD_CUDA_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"') diff --git a/docker/common/install_polygraphy.sh b/docker/common/install_polygraphy.sh index b8b404c8c..f3ae75a18 100644 --- a/docker/common/install_polygraphy.sh +++ b/docker/common/install_polygraphy.sh @@ -2,6 +2,4 @@ set -ex -RELEASE_URL_PG=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/9.0.1/tars/polygraphy-0.48.1-py2.py3-none-any.whl -pip3 uninstall -y polygraphy -pip3 install ${RELEASE_URL_PG} +pip3 install polygraphy==0.49.0 diff --git a/docker/common/install_pytorch.sh b/docker/common/install_pytorch.sh index 93fc8d3ea..1536bd514 100644 --- a/docker/common/install_pytorch.sh +++ b/docker/common/install_pytorch.sh @@ -4,8 +4,8 @@ set -ex # Use latest stable version from https://pypi.org/project/torch/#history # and closest to the version specified in -# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html#rel-24-01 -TORCH_VERSION="2.1.2" +# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html#rel-24-02 +TORCH_VERSION="2.2.1" SYSTEM_ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') prepare_environment() { diff --git a/docker/common/install_tensorrt.sh b/docker/common/install_tensorrt.sh index 
487106b08..20ea5c62c 100644 --- a/docker/common/install_tensorrt.sh +++ b/docker/common/install_tensorrt.sh @@ -2,13 +2,15 @@ set -ex -# Use https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html#rel-24-01 +# Use https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html#rel-24-02 TRT_VER="9.3.0.1" CUDA_VER="12.3" CUDNN_VER="8.9.7.29-1+cuda12.2" # v2.19.4 doesn't exist in https://developer.download.nvidia.cn/compute/cuda/repos/ NCCL_VER="2.19.3-1+cuda12.3" CUBLAS_VER="12.3.4.1-1" +# Align with the pre-installed CUDA / NVCC version. +# https://docs.nvidia.com/cuda/archive/12.3.2/cuda-toolkit-release-notes/index.html NVRTC_VER="12.3.107-1" for i in "$@"; do diff --git a/docs/Doxygen b/docs/Doxygen index 617416c16..0e4e19f68 100644 --- a/docs/Doxygen +++ b/docs/Doxygen @@ -864,7 +864,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../cpp/include/tensorrt_llm/runtime +INPUT = ../cpp/include/tensorrt_llm/runtime ../cpp/include/tensorrt_llm/executor # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/docs/source/2023-05-19-how-to-debug.md b/docs/source/2023-05-19-how-to-debug.md index 858d23994..07f51d1eb 100644 --- a/docs/source/2023-05-19-how-to-debug.md +++ b/docs/source/2023-05-19-how-to-debug.md @@ -83,119 +83,106 @@ Here is an example to print the values of the MLP output tensor in the GPT model hidden_states = residual + hidden_states ``` -2. In `examples/gpt/build.py`, we mark it as a TensorRT network output: - -```python - with net_guard(network): - network.set_named_parameters(tensorrt_llm_gpt.named_parameters()) - - inputs = tensorrt_llm_gpt.prepare_inputs(args.max_batch_size, - args.max_input_len, - args.max_output_len, True, - args.max_beam_width) - tensorrt_llm_gpt(*inputs) - - # mark as TRT network output - # ---------------------------------------------------------------- - for k, v in tensorrt_llm_gpt.named_network_outputs(): - network._mark_output(v, k, - tensorrt_llm.str_dtype_to_trt(args.dtype)) - # ---------------------------------------------------------------- -``` +2. Build the TensorRT engine of the model: - -3. Build the TensorRT engine of the model: +When building engines with `trtllm-build`, enable the `--enable_debug_output` option. ```bash +cd examples/gpt + +# Download hf gpt2 model rm -rf gpt2 && git clone https://huggingface.co/gpt2-medium gpt2 pushd gpt2 && rm pytorch_model.bin model.safetensors && wget -q https://huggingface.co/gpt2-medium/resolve/main/pytorch_model.bin && popd -python3 hf_gpt_convert.py -i gpt2 -o ./c-model/gpt2 --tensor-parallelism 1 --storage-type float16 -python3 build.py --model_dir=./c-model/gpt2/1-gpu --use_gpt_attention_plugin +# Convert to TensorRT-LLM checkpoint +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --output_dir gpt2/trt_ckpt/fp16/1-gpu + +# Build TensorRT-LLM engines with --enable_debug_output +trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/1-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --enable_debug_output \ + --output_dir gpt2/trt_engines/fp16/1-gpu ``` -4. Print the intermediate output tensors: - - -In `examples/gpt/run.py`, we open the debug mode: - -```python - decoder = tensorrt_llm.runtime.GenerationSession(model_config, - engine_buffer, - runtime_mapping, - debug_mode=True) -``` +3. 
Print the intermediate output tensors: In `tensorrt_llm/runtime/generation.py`, we print the debug info: ```python - if step == 0: - ... - ctx_shape, ctx_buffer = self._get_context_shape_buffer( - input_ids, max_input_length, step, - input_lengths, position_ids, last_token_ids, attention_mask, - this_src_cache_indirection) - self.runtime._set_shape(context, ctx_shape) - self.runtime._set_buffer(context, ctx_buffer) - # ------------------------------------------- - debug_buffer = ctx_buffer - # ------------------------------------------- - stream = torch.cuda.current_stream().cuda_stream - ok = self.runtime._run(context, stream) + instance_idx = step % 2 + if self.cuda_graph_mode and self.runtime.cuda_graph_instances[ + instance_idx] is not None: + # launch cuda graph + CUASSERT( + cudart.cudaGraphLaunch( + self.runtime.cuda_graph_instances[instance_idx], stream)) + ok = True + else: + ok = self.runtime._run(context, stream) + if not ok: - raise RuntimeError('Executing TRT engine failed!') + raise RuntimeError(f"Executing TRT engine failed step={step}!") if self.debug_mode: torch.cuda.synchronize() # ------------------------------------------- if step == 0: - print(debug_buffer.keys()) - print(step, debug_buffer['layers.6.mlp_output']) + print(self.debug_buffer.keys()) + print(f"Step: {step}") + print(self.debug_buffer['transformer.layers.6.mlp_output']) # ------------------------------------------- +``` - if not step == self.max_new_tokens - 1: - ... - next_step_shape, next_step_buffer = self._get_next_step_shape_buffer( - batch_size, scfg.num_beams, max_input_length, step, - input_lengths, position_ids, last_token_ids, - attention_mask, next_src_cache_indirection) - self.runtime._set_shape(next_context, next_step_shape) - self.runtime._set_buffer(next_context, next_step_buffer) - # ------------------------------------------- - debug_buffer = next_step_buffer - # ------------------------------------------- +Then, run `../run.py` with `--debug_mode` and `--use_py_session`: +```bash +python3 ../run.py --engine_dir gpt2/trt_engines/fp16/1-gpu \ + --tokenizer_dir gpt2 \ + --max_output_len 8 \ + --debug_mode \ + --use_py_session ``` -Then, we will see the tensor values: +We will see the tensor values: -```bash -python run.py --max_output_len=8 -dict_keys(['input_ids', 'logits', 'input_lengths', 'position_ids', 'last_token_ids', 'max_input_length', 'cache_indirection', 'past_key_0', 'past_value_0', 'present_key_0', 'present_value_0', 'past_key_1', 'past_value_1', 'present_key_1', 'present_value_1', 'past_key_2', 'past_value_2', 'present_key_2', 'present_value_2', 'past_key_3', 'past_value_3', 'present_key_3', 'present_value_3', 'past_key_4', 'past_value_4', 'present_key_4', 'present_value_4', 'past_key_5', 'past_value_5', 'present_key_5', 'present_value_5', 'past_key_6', 'past_value_6', 'present_key_6', 'present_value_6', 'past_key_7', 'past_value_7', 'present_key_7', 'present_value_7', 'past_key_8', 'past_value_8', 'present_key_8', 'present_value_8', 'past_key_9', 'past_value_9', 'present_key_9', 'present_value_9', 'past_key_10', 'past_value_10', 'present_key_10', 'present_value_10', 'past_key_11', 'past_value_11', 'present_key_11', 'present_value_11', 'past_key_12', 'past_value_12', 'present_key_12', 'present_value_12', 'past_key_13', 'past_value_13', 'present_key_13', 'present_value_13', 'past_key_14', 'past_value_14', 'present_key_14', 'present_value_14', 'past_key_15', 'past_value_15', 'present_key_15', 'present_value_15', 'past_key_16', 'past_value_16', 'present_key_16', 
'present_value_16', 'past_key_17', 'past_value_17', 'present_key_17', 'present_value_17', 'past_key_18', 'past_value_18', 'present_key_18', 'present_value_18', 'past_key_19', 'past_value_19', 'present_key_19', 'present_value_19', 'past_key_20', 'past_value_20', 'present_key_20', 'present_value_20', 'past_key_21', 'past_value_21', 'present_key_21', 'present_value_21', 'past_key_22', 'past_value_22', 'present_key_22', 'present_value_22', 'past_key_23', 'past_value_23', 'present_key_23', 'present_value_23', 'sequence_length', 'past_key_value_length', 'layers.0.mlp_output', 'layers.1.mlp_output', 'layers.2.mlp_output', 'layers.3.mlp_output', 'layers.4.mlp_output', 'layers.5.mlp_output', 'layers.6.mlp_output', 'layers.7.mlp_output', 'layers.8.mlp_output', 'layers.9.mlp_output', 'layers.10.mlp_output', 'layers.11.mlp_output', 'layers.12.mlp_output', 'layers.13.mlp_output', 'layers.14.mlp_output', 'layers.15.mlp_output', 'layers.16.mlp_output', 'layers.17.mlp_output', 'layers.18.mlp_output', 'layers.19.mlp_output', 'layers.20.mlp_output', 'layers.21.mlp_output', 'layers.22.mlp_output', 'layers.23.mlp_output']) -0 tensor([[[ 0.0295, -0.0256, -0.0780, ..., -0.0562, -0.0241, 0.0273], - [-0.0089, 0.5882, 0.1989, ..., -1.0464, -0.6305, 0.5967], - [-0.8793, 0.1056, 0.7083, ..., 0.0889, 1.0714, -0.2931], - ..., - [ 0.1209, -0.0886, -0.5927, ..., -0.1048, -0.3437, 0.1085], - [-1.0752, -0.0739, 0.6156, ..., 0.3454, 0.3014, 0.2653], - [-0.7126, 0.9685, -0.1145, ..., -0.0084, 0.9521, 0.1425]]], - device='cuda:0') -1 tensor([[[-0.2129, 0.5879, 0.8172, ..., 0.7892, -0.6887, 0.6063]]], - device='cuda:0') -2 tensor([[[ 0.4184, -0.0066, 1.3895, ..., -0.9023, -0.0686, -0.2831]]], - device='cuda:0') -3 tensor([[[-0.7935, -0.5085, -0.1696, ..., -0.5839, -0.1375, -0.0078]]], - device='cuda:0') -4 tensor([[[-0.0810, 0.1262, -0.6260, ..., -0.1065, -0.0529, 0.7143]]], - device='cuda:0') -5 tensor([[[-0.3322, -0.8835, 0.3427, ..., 0.8159, -0.0622, 1.2327]]], - device='cuda:0') -6 tensor([[[-0.2217, -0.2057, -0.1475, ..., -0.3545, -0.1673, 0.1131]]], - device='cuda:0') -7 tensor([[[ 0.1268, -0.1570, 0.3972, ..., -0.8213, -0.3282, -0.8672]]], - device='cuda:0') -Input: Born in north-east France, Soyer trained as a -Output: chef before moving to London in the early +``` +...... 
+dict_keys(['context_lengths', 'cache_indirection', 'position_ids', 'logits', 'last_token_ids', 'input_ids', 'kv_cache_block_pointers', 'host_kv_cache_block_pointers', 'sequence_length', 'host_past_key_value_lengths', 'host_sink_token_length', 'host_request_types', 'host_max_attention_window_sizes', 'host_context_lengths', 'transformer.layers.0.mlp_output', 'transformer.layers.1.mlp_output', 'transformer.layers.2.mlp_output', 'transformer.layers.3.mlp_output', 'transformer.layers.4.mlp_output', 'transformer.layers.5.mlp_output', 'transformer.layers.6.mlp_output', 'transformer.layers.7.mlp_output', 'transformer.layers.8.mlp_output', 'transformer.layers.9.mlp_output', 'transformer.layers.10.mlp_output', 'transformer.layers.11.mlp_output', 'transformer.layers.12.mlp_output', 'transformer.layers.13.mlp_output', 'transformer.layers.14.mlp_output', 'transformer.layers.15.mlp_output', 'transformer.layers.16.mlp_output', 'transformer.layers.17.mlp_output', 'transformer.layers.18.mlp_output', 'transformer.layers.19.mlp_output', 'transformer.layers.20.mlp_output', 'transformer.layers.21.mlp_output', 'transformer.layers.22.mlp_output', 'transformer.layers.23.mlp_output']) +Step: 0 +tensor([[ 0.0294, -0.0260, -0.0776, ..., -0.0560, -0.0235, 0.0273], + [-0.0071, 0.5879, 0.1993, ..., -1.0449, -0.6299, 0.5957], + [-0.8779, 0.1050, 0.7090, ..., 0.0910, 1.0713, -0.2939], + ..., + [ 0.1212, -0.0903, -0.5918, ..., -0.1045, -0.3445, 0.1082], + [-1.0723, -0.0732, 0.6157, ..., 0.3452, 0.2998, 0.2649], + [-0.7134, 0.9692, -0.1141, ..., -0.0096, 0.9521, 0.1437]], + device='cuda:0', dtype=torch.float16) +Step: 1 +tensor([[-0.2107, 0.5874, 0.8179, ..., 0.7900, -0.6890, 0.6064]], + device='cuda:0', dtype=torch.float16) +Step: 2 +tensor([[ 0.4192, -0.0047, 1.3887, ..., -0.9028, -0.0682, -0.2820]], + device='cuda:0', dtype=torch.float16) +Step: 3 +tensor([[-0.7949, -0.5073, -0.1721, ..., -0.5830, -0.1378, -0.0070]], + device='cuda:0', dtype=torch.float16) +Step: 4 +tensor([[-0.0804, 0.1272, -0.6255, ..., -0.1072, -0.0523, 0.7144]], + device='cuda:0', dtype=torch.float16) +Step: 5 +tensor([[-0.3328, -0.8828, 0.3442, ..., 0.8149, -0.0630, 1.2305]], + device='cuda:0', dtype=torch.float16) +Step: 6 +tensor([[-0.2225, -0.2079, -0.1459, ..., -0.3555, -0.1672, 0.1135]], + device='cuda:0', dtype=torch.float16) +Step: 7 +tensor([[ 0.1290, -0.1556, 0.3977, ..., -0.8218, -0.3291, -0.8672]], + device='cuda:0', dtype=torch.float16) +Input [Text 0]: "Born in north-east France, Soyer trained as a" +Output [Text 0 Beam 0]: " chef before moving to London in the early" ``` ## Debug execution errors diff --git a/docs/source/conf.py b/docs/source/conf.py index 1eb81c7be..511d76299 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -98,3 +98,16 @@ def gen_cpp_doc(ofile_name: str, header_dir: str, summary: str): subprocess.run(['mkdir', '-p', CPP_GEN_DIR]) gen_cpp_doc(CPP_GEN_DIR + '/runtime.rst', CPP_INCLUDE_DIR + '/runtime', runtime_summary) + +executor_summary = f""" +Executor +========== + +.. Here are files in the cpp/include/executor +.. We manually add subsection to enable detailed description in the future +.. 
It is also doable to automatically generate this file and list all the modules in the conf.py
+    """.strip()
+
+subprocess.run(['mkdir', '-p', CPP_GEN_DIR])
+gen_cpp_doc(CPP_GEN_DIR + '/executor.rst', CPP_INCLUDE_DIR + '/executor',
+            executor_summary)
diff --git a/docs/source/executor.md b/docs/source/executor.md
new file mode 100644
index 000000000..5c08c0df2
--- /dev/null
+++ b/docs/source/executor.md
@@ -0,0 +1,32 @@
+(executor)=
+
+# Executor API
+
+TensorRT-LLM includes a high-level API called the Executor API which allows you to execute requests
+asynchronously, with in-flight batching, and without the need to define callbacks.
+
+A software component (referred to as "the client" in the text that follows) can interact
+with the executor using the API defined in the [`executor.h`](source:cpp/include/tensorrt_llm/executor/executor.h) file.
+For details about the API, refer to the {ref}`_cpp_gen/executor.rst`. The following sections provide an overview of the main classes defined in the Executor API.
+
+### The Executor Class
+
+The `Executor` class is responsible for receiving requests from the client and providing responses for those requests. The executor is constructed from a path to a directory containing the TensorRT-LLM engine, or from buffers containing the engine and the model JSON configuration. The client can create requests and enqueue them for execution using the `enqueueRequest` or `enqueueRequests` methods of the `Executor` class. Enqueued requests are scheduled for execution by the executor, and multiple independent requests can be batched together at every iteration of the main execution loop (a process often referred to as continuous batching or iteration-level batching). Responses for a particular request can be awaited by calling the `awaitResponses` method with the request id. Alternatively, responses for any request can be awaited by omitting the request id when calling `awaitResponses`. The `Executor` class also allows the client to cancel requests using the `cancelRequest` method and to obtain per-iteration and per-request statistics using the `getLatestIterationStats` method.
+
+### The Request Class
+
+The `Request` class is used to define properties of the request, such as the input token ids and the maximum number of tokens to generate. The `streaming` parameter indicates whether the request should produce a response for each newly generated token (`streaming = true`) or only after all tokens have been generated (`streaming = false`). Other mandatory parameters of the request include the sampling configuration (defined by the `SamplingConfig` class), which contains parameters controlling the decoding process, and the output configuration (defined by the `OutputConfig` class), which controls what information should be included in the `Result` for a particular response.
+
+Optional parameters can also be provided when constructing a request, such as a list of bad words, a list of stop words, or configuration objects for prompt tuning, LoRA, or speculative decoding.
+
+### The Response Class
+
+The `awaitResponses` method of the `Executor` class returns a vector of responses. Each response contains the request id associated with this response, and also contains either an error or a `Result`. Check whether the response holds an error by using the `hasError` method before trying to obtain the `Result` associated with this response using the `getResult` method.
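To make the request/response flow described above concrete, here is a minimal sketch using the Python bindings (`tensorrt_llm.bindings.executor`), which mirror the C++ classes. The `Executor`, `Request`, `enqueue_request`, and `await_responses` calls follow the bindings example added elsewhere in this change; the `streaming` keyword and the `is_final`/`has_error`/`error_msg` accessors are assumed Python counterparts of the C++ members (`streaming`, `isFinal`, `hasError`) and may differ in the actual bindings.

```python
# Hypothetical sketch of a streaming request with the Python executor bindings.
# The engine path is a placeholder; is_final / has_error / error_msg are
# assumptions mirroring the C++ API described above.
import tensorrt_llm.bindings.executor as trtllm

executor = trtllm.Executor("/path/to/engine_dir",          # placeholder path
                           trtllm.ModelType.DECODER_ONLY,
                           trtllm.ExecutorConfig(1))

# With streaming enabled, a Result is produced for each newly generated token.
request = trtllm.Request(input_token_ids=[1, 2, 3, 4],
                         max_new_tokens=10,
                         streaming=True)
request_id = executor.enqueue_request(request)

finished = False
while not finished:
    for response in executor.await_responses(request_id):
        if response.has_error():
            raise RuntimeError(response.error_msg)
        result = response.result
        print(result.output_token_ids)   # incremental tokens for this step
        finished = result.is_final
```

In non-streaming mode the loop collapses to a single `await_responses` call, as in the bindings example later in this change.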
+
+### The Result Class
+
+The `Result` class holds the result for a given request. It contains a Boolean parameter called `isFinal` that indicates if this is the last `Result` that will be returned for the given request id. It also contains the generated tokens. If the request is configured with `streaming = false`, the `isFinal` Boolean will be set to `true` and all generated tokens will be included in the `outputTokenIds`. If `streaming = true` is used, a `Result` will only include 1 token and the `isFinal` flag will be set to `true` for the last result associated with this request.
+
+## Executor API Example
+
+An example is provided that shows how to use the Executor API to generate tokens for a single request in non-streaming mode. The example can be found in [`main.cpp`](source:examples/cpp_executor/main.cpp).
diff --git a/docs/source/gpt_runtime.md b/docs/source/gpt_runtime.md
index adaffab73..eecbbd25e 100644
--- a/docs/source/gpt_runtime.md
+++ b/docs/source/gpt_runtime.md
@@ -345,17 +345,17 @@ value for a given parameter, the vector can be limited to a single element
 ***General***
 * `temperature`, a vector of floating-point numbers to control the
-  modulation of logits when sampling new tokens. The default value is `1.0f`,
+  modulation of logits when sampling new tokens. It can have any value `> 0.0f`. The default value is `1.0f` (no modulation).
 * `minLength`, a vector of integers to set a lower-bound on the number of tokens
-  generated. The default value is 0,
+  generated. It can have any value. Values `< 1` have no effect; the first generated token can be EOS. The default value is `1` (at least one non-EOS token is generated).
 * `repetitionPenalty`, a vector of float-point numbers to penalize tokens
-  based on how often they appear in the sequence. The default value is `0.f`,
+  based on how often they appear in the sequence. It can have any value `> 0.0f`. A repetition penalty `< 1.0f` encourages repetition, `> 1.0f` discourages it. The default value is `1.0f` (no effect).
 * `presencePenalty`, a vector of float-point numbers to penalize tokens
   already present in the sequence (irrespective of the number of appearances).
-  The default value is `0.f`,
+  It can have any value; values `< 0.0f` encourage repetition, `> 0.0f` discourage it. The default value is `0.0f` (no effect).
 * `frequencyPenalty`, a vector of float-point numbers to penalize tokens
-  already present in the sequence (dependent on the number of appearances).
-  The default value is `0.f`,
+  already present in the sequence (dependent on the number of appearances). It can have any value; values `< 0.0f` encourage repetition, `> 0.0f` discourage it.
+  The default value is `0.0f` (no effect).
 The parameters `repetitionPenalty`, `presencePenalty`, and `frequencyPenalty` are not mutually exclusive.
diff --git a/docs/source/inference_request.md b/docs/source/inference_request.md
index 00f07314c..287a4b46e 100644
--- a/docs/source/inference_request.md
+++ b/docs/source/inference_request.md
@@ -29,6 +29,7 @@ Optional tensors that can be supplied to `InferenceRequest` are shown below. Def
 | `stop_words_list` | [2, num_stop_words] | `int32_t` | Stop words list |
 | `prompt_embedding_table` | [1] | `float16` | P-tuning prompt embedding table |
 | `prompt_vocab_size` | [1] | `int32_t` | P-tuning prompt vocab size |
+| `lora_task_id` | [1] | `uint64_t` | Task ID for the given lora_weights. This ID is expected to be globally unique. To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights`, and `lora_config` must all be given. The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`. If the cache is full, the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached. |
 | `lora_weights` | [ num_lora_modules_layers, D x Hi + Ho x D ] | `float` (model data type) | weights for a lora adapter. see [lora docs](lora.md#lora-tensor-format-details) for more details. |
 | `lora_config` | [3] | `int32_t` | lora configuration tensor. `[ module_id, layer_idx, adapter_size (D aka R value) ]` see [lora docs](lora.md#lora-tensor-format-details) for more details. |
 | `return_log_probs` | [1] | `bool` | When `true`, include log probs in the output |
diff --git a/docs/source/lora.md b/docs/source/lora.md
index bf79a23c0..3565b2d0c 100644
--- a/docs/source/lora.md
+++ b/docs/source/lora.md
@@ -8,13 +8,10 @@ git-lfs clone https://huggingface.co/kunishou/Japanese-Alpaca-LoRA-7b-v0
 BASE_MODEL=llama-7b-hf
 python examples/llama/convert_checkpoint.py --model_dir ${BASE_MODEL} \
-    --output_dir /tmp/llama_7b_with_lora_qkv/trt_ckpt/fp16/1-gpu/ \
-    --dtype float16 \
-    --hf_lora_dir Japanese-Alpaca-LoRA-7b-v0 \
-    --max_lora_rank 8 \
-    --lora_target_modules "attn_q" "attn_k" "attn_v"
+    --output_dir /tmp/llama_7b/trt_ckpt/fp16/1-gpu/ \
+    --dtype float16
-trtllm-build --checkpoint_dir /tmp/llama_7b_with_lora_qkv/trt_ckpt/fp16/1-gpu/ \
+trtllm-build --checkpoint_dir /tmp/llama_7b/trt_ckpt/fp16/1-gpu/ \
     --output_dir /tmp/llama_7b_with_lora_qkv/trt_engines/fp16/1-gpu/ \
     --remove_input_padding enable \
     --gpt_attention_plugin float16 \
@@ -25,6 +22,9 @@ trtllm-build --checkpoint_dir /tmp/llama_7b_with_lora_qkv/trt_ckpt/fp16/1-gpu/ \
     --max_batch_size 128 \
     --max_input_len 512 \
     --max_output_len 50 \
+    --lora_dir Japanese-Alpaca-LoRA-7b-v0 \
+    --max_lora_rank 8 \
+    --lora_target_modules "attn_q" "attn_k" "attn_v"
 ```
 To pass LoRAs into the cpp runtime they must be converted to the format below.
@@ -41,6 +41,12 @@ See tensorrtllm_backend [docs](https://github.com/triton-inference-server/tensor
 To run inference with LoRA weights using GptManager, InferenceRequests must have LoraWeights (lora_weights) and LoraConfig (lora_config) parameters.
+`LoraTaskId` is the unique task ID for the given LoRA.
+
+To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights`, and `lora_config` must all be given.
+The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
+If the cache is full, the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
+
 `LoraWeights` contains the weights for all the LoRAs. Currently this should include weight for all tp and pp ranks. The weights tensor has the shape `[ num_lora_modules_layers, D x Hi + Ho x D ]`. the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer. Each of the in / out tensors are first flattened and then concatenated together in the format above.
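The flattened in/out layout described above is easier to see in code. Below is a small, hypothetical NumPy sketch that packs per-module LoRA matrices into the `[ num_lora_modules_layers, D x Hi + Ho x D ]` weights tensor and the matching `[ module_id, layer_idx, adapter_size ]` config rows. The zero-padding of shorter rows to a common width and the module id used in the example are assumptions made for illustration only.

```python
# Hypothetical packing of LoRA adapters into the cpp-runtime format described
# above. Rows are zero-padded to the longest (D*Hi + Ho*D) entry so the result
# is rectangular -- an assumption for this sketch, not a statement about the
# runtime's exact padding rules.
import numpy as np

def pack_lora(entries):
    """entries: list of (module_id, layer_idx, w_in, w_out) with
    w_in of shape [D, Hi] and w_out of shape [Ho, D] for adapter size D."""
    rows, configs = [], []
    for module_id, layer_idx, w_in, w_out in entries:
        adapter_size = w_in.shape[0]                       # D (a.k.a. R)
        rows.append(np.concatenate([w_in.flatten(), w_out.flatten()]))
        configs.append([module_id, layer_idx, adapter_size])
    width = max(r.size for r in rows)
    weights = np.zeros((len(rows), width), dtype=np.float16)
    for i, r in enumerate(rows):
        weights[i, :r.size] = r
    return weights, np.asarray(configs, dtype=np.int32)

# Example: rank-8 adapters for a single attention module (module id 1 is an
# assumed value; use the LoraModule::ModuleType mapping table) on layers 0 and
# 1 of a model with hidden size 4096.
entries = [(1, 0, np.random.rand(8, 4096), np.random.rand(4096, 8)),
           (1, 1, np.random.rand(8, 4096), np.random.rand(4096, 8))]
weights, config = pack_lora(entries)
print(weights.shape, config.shape)  # (2, 65536) (2, 3)
```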
@@ -98,3 +104,9 @@ See LoraModule::ModuleType for model id mapping | mlp_h_to_4h | 5 | for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection | | mlp_4h_to_h | 6 | for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection | | mlp_gate | 7 | for llama2 adapter for gated mlp later after attention / RMSNorm: gate | + +#### LoraCache configuration + +The core idea is that we will have a fixed size, 2-level LoRA cache in TRT-LLM. The higher level cache resides on the host and the lower level is on GPU (distinct from the existing KV cache). Sizes of both are user configurable. +The CPU cache is configured to be a max size. The GPU cache is configured to a percentage of free GPU memory after engine load. As requests come in LoRAs are stored in the host cache. +As requests are scheduled for execution LoRAs are loaded into the GPU cache. diff --git a/docs/source/memory.md b/docs/source/memory.md index ceb079f32..0d87aaa98 100644 --- a/docs/source/memory.md +++ b/docs/source/memory.md @@ -31,7 +31,7 @@ Here some explanations on how these values affect the memory: 1. Reduce build time max input tokens - Most of the tensors inside a transformer network have a linear relationship with number of input tokens, so activation size will be close to `max number of input tokens * some constant factor`, the constant factor depends on the network structure and TRT internal optimization. The max number of input tokens is derived from build time arguments, one can change the parameters provided to the `prepare_inputs` function, like `GPTLMHeadModel.prepare_inputs` to affect the memory usage, or one can change the command line options of the `trtllm-build` command used in the examples. + Most of the tensors inside a transformer network have a linear relationship with number of input tokens, so activation size will be close to `max number of input tokens * some constant factor`, the constant factor depends on the network structure and TRT internal optimization. The max number of input tokens is derived from build time arguments, one can change the parameters provided to the `prepare_inputs` function, like `PretrainedModel.prepare_inputs` to affect the memory usage, or one can change the command line options of the `trtllm-build` command used in the examples. When using the [padded tensors](./gpt_attention.md#padded-and-packed-tensors) format, the max number of input tokens equals to `max_batch_size*max_input_len`, so reducing `max_batch_size` and `max_input_len` can almost linearly reduce the activation memory size. When using the [packed tensors](./gpt_attention.md#padded-and-packed-tensors) format and `max_num_tokens` is specified, reducing its value will also reduce activation memory size. If `max_num_tokens` is not specified, the max number of input tokens will be derived as `max_batch_size*max_input_len`. 
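The relationship between build-time limits and activation size described above can be estimated with a few lines. The sketch below is purely illustrative: the bytes-per-token factor is a placeholder, since the real constant depends on the network structure and on TensorRT's internal optimizations.

```python
# Rough, illustrative estimate of the max number of input tokens that
# activations must be sized for, following the padded/packed rules above.
def max_input_tokens(max_batch_size, max_input_len, max_num_tokens=None,
                     packed=True):
    if packed and max_num_tokens is not None:
        return max_num_tokens
    # Padded tensors, or packed tensors without an explicit max_num_tokens.
    return max_batch_size * max_input_len

BYTES_PER_TOKEN = 2 * 1024 * 1024   # placeholder constant factor

for cfg in [dict(max_batch_size=8, max_input_len=1024),
            dict(max_batch_size=8, max_input_len=1024, max_num_tokens=4096)]:
    tokens = max_input_tokens(**cfg)
    print(cfg, "->", tokens, "tokens,",
          f"~{tokens * BYTES_PER_TOKEN / 2**30:.1f} GiB activations (illustrative)")
```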
diff --git a/docs/source/new_workflow.md b/docs/source/new_workflow.md index e1587cb12..a78903e44 100644 --- a/docs/source/new_workflow.md +++ b/docs/source/new_workflow.md @@ -56,7 +56,6 @@ The different files will be loaded by different ranks in a multi-GPU (multi-proc | intermediate_size | int | null | | norm_epsilon | float | 1e-5 | | position_embedding_type | string | 'learned_absolute' | -| use_prompt_tuning | bool | false | | mapping.world_size | int | 1 | | mapping.tp_size | int | 1 | | mapping.pp_size | int | 1 | @@ -219,7 +218,6 @@ Here is the `config.json`: "embedding_sharding_dim": 0, "share_embedding_table": false, "do_layer_norm_before": true, - "use_prompt_tuning": false } ``` diff --git a/docs/source/perf_best_practices.md b/docs/source/perf_best_practices.md index 22d21425f..8c6b12d58 100644 --- a/docs/source/perf_best_practices.md +++ b/docs/source/perf_best_practices.md @@ -157,10 +157,20 @@ for details. ### Horizontal Fusion in Gated-MLP Horizontal fusion in Gated-MLP combines two Matmul operations into a single one -followed by a separate SwiGLU kernel. If both model and batch sizes are large, -it is recommended to enable the feature by using the `--use_fused_mlp` argument -with `trtllm-build`. When the workload is very small, it is not recommended to -enable that feature. +followed by a separate SwiGLU kernel. However, for FP8 PTQ, the +downside is slight reduction of accuracy because one of the quantization scaling +factors are discarded. + +If both model and batch sizes are large, it is recommended to enable the feature +by using the `--use_fused_mlp` argument with `trtllm-build`. When the workload +is very small, or if you're using FP8 PTQ and the accuracy after enabling it +does not satisfy your requirement, it is not recommended to enable that feature. + +### GEMM Plugin + +The GEMM plugin utilizes NVIDIA cuBLASLt to perform GEMM operations. On FP16 and +BF16, it's recommended to be enabled for better performance and smaller GPU +memory usage. On FP8, it's recommended to be disabled. ### BERT Attention Plugin and Context Fused Multi-Head Attention diff --git a/docs/source/precision.md b/docs/source/precision.md index c11c10f00..f9998e7de 100644 --- a/docs/source/precision.md +++ b/docs/source/precision.md @@ -137,10 +137,10 @@ This release of TensorRT-LLM contains the following examples: | OPT | Y | Y | Y | . | . | . | . | . | . | | Phi | Y | Y | Y | . | . | . | . | . | . | | Replit Code| Y | Y | Y | . | . | . | . | . | . | -| SantaCoder | Y | Y | Y | . | . | . | . | . | . | +| SantaCoder | Y | Y | Y | . | . | Y | Y | . | . | | Skywork | Y | Y | Y | . | . | . | . | . | . | -| StarCoder1 | Y | Y | Y | . | . | Y | . | . | . | -| StarCoder2 | Y | Y | Y | . | . | Y | . | . | . | +| StarCoder1 | Y | Y | Y | . | . | Y | Y | . | . | +| StarCoder2 | Y | Y | Y | . | . | Y | Y | . | . | | T5 | Y | Y | Y | . | . | . | . | . | . | | Whisper | Y | Y | Y | . | . | Y | Y | . | . | diff --git a/examples/apps/README.md b/examples/apps/README.md new file mode 100644 index 000000000..4fc6884d8 --- /dev/null +++ b/examples/apps/README.md @@ -0,0 +1,33 @@ +# Apps examples with GenerationExecutor / High-level API + +## Python chat + +[chat.py](./chat.py) provides a small examples to play around with your model. You can run it with + +`python3 examples/apps/chat.py ` +or +`mpirun -n python3 examples/apps/chat.py ` + +You can modify prompt setting by entering options starting with '!!'. Type '!!help' to see available commands. 
+ +## FastAPI server + +### Install the additional requirements + +` pip install -r examples/apps/requirements.txt` + +### Start the server + +Suppose you have build an engine with `trtllm-build`, you can now serve it with: + +`python3 -m examples.apps.fastapi_server &` +or +`mpirun -n python3 -m examples.server.server &` + +### Send requests + +You can pass request arguments like "max_new_tokens", "top_p", "top_k" in your JSON dict: +`curl http://localhost:8000/generate -d '{"prompt": "In this example,", "max_new_tokens": 8}'` + +You can also use the streaming interface with: +`curl http://localhost:8000/generate -d '{"prompt": "In this example,", "max_new_tokens": 8, "streaming": true}' --output -` diff --git a/examples/apps/chat.py b/examples/apps/chat.py new file mode 100644 index 000000000..7ccb6648e --- /dev/null +++ b/examples/apps/chat.py @@ -0,0 +1,60 @@ +#! /usr/bin/env python3 +import argparse +import code +import readline # NOQA +from argparse import ArgumentParser +from pathlib import Path + +from tensorrt_llm.executor import GenerationExecutorWorker + + +class LLMChat(code.InteractiveConsole): + + def __init__(self, executor): + super().__init__() + self.executor = executor + self.generation_kwargs = { + "max_new_tokens": 100, + "repetition_penalty": 1.05, + } + self.parser = ArgumentParser(prefix_chars="!") + self.parser.add_argument("!!max_new_tokens", type=int) + self.parser.add_argument("!!repetition_penalty", type=float) + + def runsource(self, + source: str, + filename: str = "", + symbol: str = "single") -> bool: + del filename, symbol + + if source.startswith("!"): + args = self.parser.parse_args(source.split(" ")) + for k, v in vars(args).items(): + if v is not None: + self.generation_kwargs[k] = v + return False + + future = self.executor.generate_async(source, + streaming=True, + **self.generation_kwargs) + for partial_result in future: + print(partial_result.text_diff, end="") + print("") + + return False + + +def main(model_dir: Path, tokenizer: Path | str): + + with GenerationExecutorWorker(model_dir, tokenizer, 1) as executor: + executor.block_subordinates() + repl = LLMChat(executor) + repl.interact(banner="", exitmsg="") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model_dir", type=Path) + parser.add_argument("tokenizer", type=Path) + args = parser.parse_args() + main(args.model_dir, args.tokenizer) diff --git a/examples/server/server.py b/examples/apps/fastapi_server.py similarity index 65% rename from examples/server/server.py rename to examples/apps/fastapi_server.py index 661fa1ee2..45b123c92 100644 --- a/examples/server/server.py +++ b/examples/apps/fastapi_server.py @@ -1,17 +1,18 @@ import argparse import asyncio import json -from typing import AsyncGenerator +from typing import AsyncGenerator, Optional import uvicorn -from executor import GenerationExecutor from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse +from tensorrt_llm.executor import GenerationExecutorWorker, GenerationResult + TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds. 
app = FastAPI() -executor: GenerationExecutor | None = None +executor: Optional[GenerationExecutorWorker] = None @app.get("/stats") @@ -38,34 +39,35 @@ async def generate(request: Request) -> Response: """ request_dict = await request.json() + prompt = request_dict.pop("prompt", "") streaming = request_dict.pop("streaming", False) - promise = executor.generate_async(request_dict.pop("prompt"), - request_dict.pop("max_num_tokens", 8), - streaming) + promise = executor.generate_async(prompt, streaming, **request_dict) + assert isinstance(promise, GenerationResult) async def stream_results() -> AsyncGenerator[bytes, None]: async for output in promise: - yield (json.dumps(output.text) + "\0").encode("utf-8") + yield (json.dumps(output.text_diff) + "\0").encode("utf-8") if streaming: return StreamingResponse(stream_results()) # Non-streaming case - await promise.await_completion() + await promise.aresult() return JSONResponse({"text": promise.text}) async def main(args): global executor - executor = GenerationExecutor(args.model_dir, args.tokenizer_type, - args.max_beam_width) - config = uvicorn.Config(app, - host=args.host, - port=args.port, - log_level="info", - timeout_keep_alive=TIMEOUT_KEEP_ALIVE) - await uvicorn.Server(config).serve() + with GenerationExecutorWorker(args.model_dir, args.tokenizer_type, + args.max_beam_width) as executor: + executor.block_subordinates() + config = uvicorn.Config(app, + host=args.host, + port=args.port, + log_level="info", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE) + await uvicorn.Server(config).serve() if __name__ == "__main__": diff --git a/examples/apps/requirements.txt b/examples/apps/requirements.txt new file mode 100644 index 000000000..97dc7cd8c --- /dev/null +++ b/examples/apps/requirements.txt @@ -0,0 +1,2 @@ +fastapi +uvicorn diff --git a/examples/baichuan/convert_checkpoint.py b/examples/baichuan/convert_checkpoint.py index 68193c017..3c388d472 100644 --- a/examples/baichuan/convert_checkpoint.py +++ b/examples/baichuan/convert_checkpoint.py @@ -62,11 +62,6 @@ def parse_arguments(): type=int, default=1, help='The number of workers for converting checkpoint in parallel') - parser.add_argument( - '--max_prompt_embedding_table_size', - type=int, - default=0, - help='Setting to a value > 0 enables support for prompt tuning.') parser.add_argument( '--per_channel', default=False, @@ -1223,7 +1218,6 @@ def process_and_assign_weight(prefix, v, tp_dim=-1): 'tp_size': args.tp_size, 'pp_size': args.pp_size, }, - 'use_prompt_tuning': args.max_prompt_embedding_table_size > 0, } if args.use_weight_only and args.weight_only_precision == 'int4_gptq': config['quantization'].update({ diff --git a/examples/baichuan/requirements.txt b/examples/baichuan/requirements.txt index d849018f9..5c6fb8e1e 100644 --- a/examples/baichuan/requirements.txt +++ b/examples/baichuan/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets~=2.15.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/bindings/README.md b/examples/bindings/README.md new file mode 100644 index 000000000..799963472 --- /dev/null +++ b/examples/bindings/README.md @@ -0,0 +1,16 @@ +# Python Bindings Example + +This example shows how to use the python bindings interface to generate tokens using a TensorRT engine. + +## Setup + +Build a TensorRT engine for one of the supported TensorRT-LLM model following instructions in the corresponding `examples` folder. 
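As a counterpart to the curl commands in the FastAPI server section above, the following is a small client sketch using the `requests` package. The `/generate` endpoint, the `prompt`/`max_new_tokens`/`streaming` fields, and the `\0`-terminated streaming chunks follow `fastapi_server.py` in this change; the host and port match the README example, while the `requests` dependency itself is an assumption (it is not listed in `examples/apps/requirements.txt`).

```python
# Hypothetical client for the FastAPI example server above.
import json
import requests

URL = "http://localhost:8000/generate"

# Non-streaming: the server returns a JSON body with the full text.
resp = requests.post(URL, json={"prompt": "In this example,", "max_new_tokens": 8})
print(resp.json()["text"])

# Streaming: chunks are JSON-encoded strings terminated by "\0".
with requests.post(URL,
                   json={"prompt": "In this example,", "max_new_tokens": 8,
                         "streaming": True},
                   stream=True) as stream_resp:
    buffer = b""
    for chunk in stream_resp.iter_content(chunk_size=None):
        buffer += chunk
        while b"\0" in buffer:
            piece, buffer = buffer.split(b"\0", 1)
            print(json.loads(piece.decode("utf-8")), end="", flush=True)
    print()
```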
+ +## Usage + +Run `example.py`, passing in the directory where the TensorRT engine was generated. For example: + +``` +cd examples/bindings +python3 example.py --model_path=../llama/tmp/7B/trt_engines/fp16/1-gpu/ +``` diff --git a/examples/bindings/example.py b/examples/bindings/example.py new file mode 100644 index 000000000..454c69142 --- /dev/null +++ b/examples/bindings/example.py @@ -0,0 +1,33 @@ +import argparse + +import tensorrt_llm.bindings.executor as trtllm + +# This example hows to use the python bindings to create an executor, enqueue a +# request, and get the generated tokens. + +# First, follow the steps in README.md to generate the engines. + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Executor Bindings Example") + parser.add_argument("--model_path", + type=str, + required=True, + help="Directory containing model engine") + args = parser.parse_args() + + # Create the executor. + executor = trtllm.Executor(args.model_path, trtllm.ModelType.DECODER_ONLY, + trtllm.ExecutorConfig(1)) + + # Create the request. + request = trtllm.Request(input_token_ids=[1, 2, 3, 4], max_new_tokens=10) + + # Enqueue the request. + request_id = executor.enqueue_request(request) + + # Wait for the new tokens. + responses = executor.await_responses(request_id) + output_tokens = responses[0].result.output_token_ids + + # Print tokens. + print(output_tokens) diff --git a/examples/bloom/requirements.txt b/examples/bloom/requirements.txt index 756b9635c..9e749148b 100644 --- a/examples/bloom/requirements.txt +++ b/examples/bloom/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/chatglm/.gitignore b/examples/chatglm/.gitignore index ecc3168f6..1dca99a83 100644 --- a/examples/chatglm/.gitignore +++ b/examples/chatglm/.gitignore @@ -1,14 +1,11 @@ __pycache__/ .vscode/ -*.pt -awq/ chatglm*/ dataset/ -engine_outputs/ glm*/ -model.cache -output_*/ -sq/ +trt_ckpt/ +trt_engine/ -*.pt model.cache +*.safetensors +*.json diff --git a/examples/chatglm/requirements.txt b/examples/chatglm/requirements.txt index ce3d5bf47..4d3d2b1f0 100644 --- a/examples/chatglm/requirements.txt +++ b/examples/chatglm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets~=2.14.5 evaluate~=0.4.1 protobuf diff --git a/examples/cpp_executor/CMakeLists.txt b/examples/cpp_executor/CMakeLists.txt new file mode 100644 index 000000000..44a00cfb0 --- /dev/null +++ b/examples/cpp_executor/CMakeLists.txt @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & +# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
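+#
+# Typical out-of-source build of this example (illustrative; it assumes
+# libtensorrt_llm.so and libnvinfer_plugin_tensorrt_llm.so have already been
+# built under cpp/build, as referenced by the imported targets below):
+#
+#   cmake -S . -B build
+#   cmake --build build
+#   ./build/tensorrt_llm_executor <engine_dir>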
cmake needs this line + +cmake_minimum_required(VERSION 3.1) + +set(TRTLLM_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../") +include(${TRTLLM_DIR}/cpp/cmake/modules/set_ifndef.cmake) +include(${TRTLLM_DIR}/cpp/cmake/modules/find_library_create_target.cmake) + +add_compile_options("-D_GLIBCXX_USE_CXX11_ABI=0") + +# Enable C++11 +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED TRUE) + +# Define project name +set(TARGET_NAME tensorrt_llm_executor) +project(${TARGET_NAME}) + +set(CMAKE_VERBOSE_MAKEFILE 1) + +# Compile options +set(CMAKE_CXX_FLAGS "-Wall -pthread ") +set(CMAKE_CXX_FLAGS_RELEASE "-O3") +set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -lstdc++") + +set(CMAKE_BUILD_TYPE release) + +find_package(CUDA REQUIRED) +message(STATUS "CUDA library status:") +message(STATUS " config: ${CUDA_DIR}") +message(STATUS " version: ${CUDA_VERSION}") +message(STATUS " libraries: ${CUDA_LIBRARIES}") +message(STATUS " include path: ${CUDA_INCLUDE_DIRS}") + +if(${CUDA_VERSION} VERSION_GREATER_EQUAL "11") + add_definitions("-DENABLE_BF16") + message( + STATUS + "CUDA_VERSION ${CUDA_VERSION} is greater or equal than 11.0, enable -DENABLE_BF16 flag" + ) +endif() + +if(${CUDA_VERSION} VERSION_GREATER_EQUAL "11.8") + add_definitions("-DENABLE_FP8") + message( + STATUS + "CUDA_VERSION ${CUDA_VERSION} is greater or equal than 11.8, enable -DENABLE_FP8 flag" + ) +endif() + +# Declare the executable target built from your sources +add_executable(${TARGET_NAME} main.cpp) + +set_ifndef(TRT_LIB_DIR + /usr/local/tensorrt/targets/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu/lib) +set_ifndef( + TRT_INCLUDE_DIR + /usr/local/tensorrt/targets/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu/include) + +set(TRT_LIB nvinfer) +find_library_create_target(${TRT_LIB} nvinfer SHARED ${TRT_LIB_DIR}) + +# +# tensorrt_llm shared lib +add_library(tensorrt_llm SHARED IMPORTED) +set_property( + TARGET tensorrt_llm + PROPERTY IMPORTED_LOCATION + "${TRTLLM_DIR}/cpp/build/tensorrt_llm/libtensorrt_llm.so") + +# nvinfer_plugin_tensorrt_llm shared lib +add_library(nvinfer_plugin_tensorrt_llm SHARED IMPORTED) +set_property( + TARGET nvinfer_plugin_tensorrt_llm + PROPERTY + IMPORTED_LOCATION + "${TRTLLM_DIR}/cpp/build/tensorrt_llm/plugins/libnvinfer_plugin_tensorrt_llm.so" +) + +target_link_libraries(${TARGET_NAME} ${CUDA_LIBRARIES} nvinfer + nvinfer_plugin_tensorrt_llm tensorrt_llm) + +# Set include folders +target_include_directories(${TARGET_NAME} PUBLIC /usr/local/cuda/include) +target_include_directories(${TARGET_NAME} PUBLIC ${TRTLLM_DIR}/cpp/include/) diff --git a/examples/cpp_executor/main.cpp b/examples/cpp_executor/main.cpp new file mode 100644 index 000000000..09d29abe6 --- /dev/null +++ b/examples/cpp_executor/main.cpp @@ -0,0 +1,66 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
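+ *
+ * This example creates an Executor for a prebuilt decoder-only engine,
+ * enqueues a single request with a few hard-coded input tokens, waits for
+ * the response and prints the generated token ids.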
+ */ +#include +#include +#include + +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/plugins/api/tllmPlugin.h" +#include "tensorrt_llm/runtime/tllmLogger.h" + +namespace tle = tensorrt_llm::executor; + +int main(int argc, char* argv[]) +{ + // Register the TRT-LLM plugins and the logger + auto logger = std::make_shared(); + initTrtLlmPlugins(logger.get()); + + if (argc != 2) + { + logger->log(nvinfer1::ILogger::Severity::kERROR, "Usage: ./tensorrt_llm_executor "); + return 1; + } + + // Create the executor for this engine + tle::SizeType beamWidth = 1; + auto executorConfig = tle::ExecutorConfig(beamWidth); + auto trtEnginePath = argv[1]; + auto executor = tle::Executor(trtEnginePath, tle::ModelType::kDECODER_ONLY, executorConfig); + + // Create the request + tle::SizeType maxNewTokens = 5; + tle::VecTokens inputTokens{1, 2, 3, 4}; + auto request = tle::Request(inputTokens, maxNewTokens); + + // Enqueue the request + auto requestId = executor.enqueueRequest(std::move(request)); + + // Wait for the response + auto responses = executor.awaitResponses(requestId); + + // Get outputTokens + auto outputTokens = responses.at(0).getResult().outputTokenIds.at(beamWidth - 1); + + logger->log(nvinfer1::ILogger::Severity::kINFO, "Output tokens: "); + for (auto& outputToken : outputTokens) + { + logger->log(nvinfer1::ILogger::Severity::kINFO, std::to_string(outputToken).c_str()); + } + + return 0; +} diff --git a/examples/enc_dec/README.md b/examples/enc_dec/README.md index fa7b2777c..77eceec69 100644 --- a/examples/enc_dec/README.md +++ b/examples/enc_dec/README.md @@ -208,7 +208,7 @@ python build.py --model_type bart \ --max_beam_width 1 ``` -* Run the engine, setting `--lora_dir` and `--lora_task_uids`. `--lora_task_uids` should be set as a list of uids which length equals to batch size. The following example is for batch size = 2: +* Run the engine, setting `--lora_dir` and `--lora_task_uids`. `--lora_task_uids` should be set as a list of uids which length equals to batch size. The following example is for batch size = 3: ```bash python run.py \ @@ -218,7 +218,7 @@ python run.py \ --max_new_token=64 \ --num_beams=1 \ --lora_dir tmp/hf_models/bart-large-cnn-samsum-lora/ \ - --lora_task_uids 0 0 + --lora_task_uids 0 0 0 ``` * Run with multi-loRA, append `--lora_dir` with other lora directories and set `--lora_task_uids` according to the index of the lora directories. 
Set to "-1" to run with the base model: diff --git a/examples/enc_dec/run.py b/examples/enc_dec/run.py index c130550f1..6f60bb78e 100644 --- a/examples/enc_dec/run.py +++ b/examples/enc_dec/run.py @@ -118,8 +118,6 @@ def read_config(config_path: Path): max_prompt_embedding_table_size=max_prompt_embedding_table_size, lora_plugin=use_lora_plugin, lora_target_modules=builder_config.get('lora_target_modules'), - hf_modules_to_trtllm_modules=builder_config.get( - 'hf_modules_to_trtllm_modules'), trtllm_modules_to_hf_modules=builder_config.get( 'trtllm_modules_to_hf_modules'), skip_cross_qkv=skip_cross_qkv, diff --git a/examples/falcon/requirements.txt b/examples/falcon/requirements.txt index 59eb000e3..604df7546 100644 --- a/examples/falcon/requirements.txt +++ b/examples/falcon/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 transformers>=4.31.0 datasets~=2.14.5 evaluate~=0.4.1 diff --git a/examples/gemma/requirements.txt b/examples/gemma/requirements.txt index 3661de075..8897034f4 100644 --- a/examples/gemma/requirements.txt +++ b/examples/gemma/requirements.txt @@ -1,6 +1,6 @@ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 flax~=0.8.0 jax[cuda12_pip]~=0.4.19; platform_system != "Windows" jax~=0.4.19; platform_system == "Windows" diff --git a/examples/generate_checkpoint_config.py b/examples/generate_checkpoint_config.py new file mode 100644 index 000000000..c36e3e408 --- /dev/null +++ b/examples/generate_checkpoint_config.py @@ -0,0 +1,187 @@ +import argparse +import json +import os + +from tensorrt_llm.quantization.mode import ( + FP8, INT8, W4A8_AWQ, W4A16, W4A16_AWQ, W4A16_GPTQ, W8A8_SQ_PER_CHANNEL, + W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN, W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN, + W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN, W8A8_SQ_PER_TENSOR_PLUGIN, W8A16) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument( + '--output_path', + type=str, + default='config.json', + help='The path to save the TensorRT-LLM checkpoint config.json file') + parser.add_argument('--architecture', type=str, default='GPTForCausalLM') + parser.add_argument('--dtype', + type=str, + default='float16', + choices=['float32', 'bfloat16', 'float16']) + parser.add_argument('--vocab_size', type=int, default=32000) + parser.add_argument('--max_position_embeddings', type=int, default=1024) + parser.add_argument('--hidden_size', type=int, default=768) + parser.add_argument('--intermediate_size', type=int, default=None) + parser.add_argument('--num_hidden_layers', type=int, default=12) + parser.add_argument('--num_attention_heads', type=int, default=12) + parser.add_argument('--num_key_value_heads', type=int, default=None) + parser.add_argument('--hidden_act', type=str, default='gelu') + parser.add_argument('--norm_epsilon', type=float, default=1e-5) + parser.add_argument('--position_embedding_type', + type=str, + default='learned_absolute') + parser.add_argument( + '--use_parallel_embedding', + action='store_true', + default=False, + help= + 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help= + 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). 
' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument( + '--share_embedding_table', + action='store_true', + default=False, + help= + 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' + 'Note: the flag might not take effect when the criteria are not met.') + + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + + parser.add_argument('--quant_algo', + type=str, + default=None, + choices=[ + None, W8A16, W4A16, W4A16_AWQ, W4A8_AWQ, W4A16_GPTQ, + W8A8_SQ_PER_CHANNEL, W8A8_SQ_PER_TENSOR_PLUGIN, + W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN, + W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN, + W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN + ]) + parser.add_argument('--kv_cache_quant_algo', + type=str, + default=None, + choices=[None, FP8, INT8]) + parser.add_argument('--group_size', type=int, default=64) + parser.add_argument('--has_zero_point', default=False, action='store_true') + parser.add_argument('--pre_quant_scale', default=False, action='store_true') + parser.add_argument('--exclude_modules', nargs='+', default=None) + + parser.add_argument('--bias', default=False, action='store_true') + parser.add_argument('--apply_query_key_layer_scaling', + default=False, + action='store_true') + parser.add_argument('--rotary_pct', type=float, default=1.0) + parser.add_argument('--rotary_base', type=float, default=10000.0) + parser.add_argument('--rotary_scaling', nargs=2, type=str, default=None) + + parser.add_argument( + '--max_lora_rank', + type=int, + default=64, + help='maximum lora rank for different lora modules. ' + 'It is used to compute the workspace size of lora plugin.') + parser.add_argument( + '--lora_target_modules', + nargs='+', + default=None, + choices=[ + "attn_qkv", + "attn_q", + "attn_k", + "attn_v", + "attn_dense", + "mlp_h_to_4h", + "mlp_gate", + "mlp_4h_to_h", + ], + help= + "Add lora in which modules. Only be activated when use_lora_plugin is enabled." 
+ ) + + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_arguments() + world_size = args.tp_size * args.pp_size + + assert args.output_path.endswith('.json') + output_dir = os.path.dirname(args.output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + + config = { + 'architecture': args.architecture, + 'dtype': args.dtype, + 'vocab_size': args.vocab_size, + 'max_position_embeddings': args.max_position_embeddings, + 'hidden_size': args.hidden_size, + 'intermediate_size': args.intermediate_size, + 'num_hidden_layers': args.num_hidden_layers, + 'num_attention_heads': args.num_attention_heads, + 'num_key_value_heads': args.num_key_value_heads, + 'hidden_act': args.hidden_act, + 'norm_epsilon': args.norm_epsilon, + 'position_embedding_type': args.position_embedding_type, + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'share_embedding_table': args.share_embedding_table, + 'quantization': { + 'quant_algo': args.quant_algo, + 'kv_cache_quant_algo': args.kv_cache_quant_algo, + }, + 'mapping': { + 'world_size': world_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + }, + 'bias': args.bias, + 'apply_query_key_layer_scaling': args.apply_query_key_layer_scaling, + 'rotary_pct': args.rotary_pct, + 'rotary_base': args.rotary_base, + 'rotary_scaling': args.rotary_scaling, + 'max_lora_rank': args.max_lora_rank, + 'lora_target_modules': args.lora_target_modules, + } + + if args.intermediate_size is None: + config['intermediate_size'] = args.hidden_size * 4 + if args.num_key_value_heads is None: + config['num_key_value_heads'] = args.num_attention_heads + + if args.quant_algo is not None: + if 'AWQ' in args.quant_algo or 'GPTQ' in args.quant_algo: + config['quantization'].update({ + 'group_size': + args.group_size, + 'has_zero_point': + args.has_zero_point, + 'pre_quant_scale': + args.pre_quant_scale, + 'exclude_modules': + args.exclude_modules, + }) + + with open(args.output_path, 'w') as f: + json.dump(config, f, indent=4) diff --git a/examples/gpt/README.md b/examples/gpt/README.md index d7670f8fc..9d18c8e99 100644 --- a/examples/gpt/README.md +++ b/examples/gpt/README.md @@ -1,15 +1,12 @@ # GPT -This document explains how to build the [GPT](https://huggingface.co/gpt2) model using TensorRT-LLM and run on a single GPU, a single node with -multiple GPUs or multiple nodes with multiple GPUs. +This document explains how to build the [GPT](https://huggingface.co/gpt2) model using TensorRT-LLM and run on a single GPU, a single node with multiple GPUs or multiple nodes with multiple GPUs. ## Overview -The TensorRT-LLM GPT implementation can be found in [`tensorrt_llm/models/gpt/model.py`](../../tensorrt_llm/models/gpt/model.py). The TensorRT-LLM GPT example code is located in [`examples/gpt`](./). There are two main files: +The TensorRT-LLM GPT implementation can be found in [`tensorrt_llm/models/gpt/model.py`](../../tensorrt_llm/models/gpt/model.py). The TensorRT-LLM GPT example code is located in [`examples/gpt`](./). There is one main file: -* [`hf_gpt_convert.py`](./hf_gpt_convert.py) to convert a checkpoint from the [HuggingFace (HF) Transformers](https://github.com/huggingface/transformers) - format to the [FasterTransformer (FT)](https://github.com/NVIDIA/FasterTransformer) format, -* [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the GPT model. 
+* [`convert_checkpoint.py`](./convert_checkpoint.py) to convert a checkpoint from the [HuggingFace (HF) Transformers](https://github.com/huggingface/transformers) format to the TensorRT-LLM format. In addition, there are two shared files in the parent folder [`examples`](../) for inference and evaluation: @@ -28,128 +25,201 @@ In addition, there are two shared files in the parent folder [`examples`](../) f ## Usage The next two sections describe how to convert the weights from the [HuggingFace (HF) Transformers](https://github.com/huggingface/transformers) -format to the FT format. You can skip those two sections if you already have weights in the -FT format. - -Note, also, that if your weights are neither in HF Transformers nor in FT formats, you will need to convert to the FT format. The script like -[`hf_gpt_convert.py`](./hf_gpt_convert.py) can serve as a starting point. +format to the TensorRT-LLM format. ### 1. Download weights from HuggingFace Transformers ```bash -# Weights & config +# Download hf gpt2 model rm -rf gpt2 && git clone https://huggingface.co/gpt2-medium gpt2 pushd gpt2 && rm pytorch_model.bin model.safetensors && wget -q https://huggingface.co/gpt2-medium/resolve/main/pytorch_model.bin && popd ``` -### 2. Convert weights from HF Transformers to FT format - -TensorRT-LLM can directly load weights from FT. The [`hf_gpt_convert.py`](./hf_gpt_convert.py) script allows you to convert weights from HF Transformers -format to FT format. +### 2. Convert weights from HF Transformers to TensorRT-LLM format +The [`convert_checkpoint.py`](./convert_checkpoint.py) script converts HF weights to TensorRT-LLM checkpoints. The number of checkpoint files (in .safetensors format) is same to the number of GPUs used to run inference. ```bash -python3 hf_gpt_convert.py -i gpt2 -o ./c-model/gpt2 --tensor-parallelism 1 --storage-type float16 +# single gpu, dtype float16 +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --output_dir gpt2/trt_ckpt/fp16/1-gpu + +# 2-way tensor parallelism +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --tp_size 2 \ + --output_dir gpt2/trt_ckpt/fp16/2-gpu ``` -This script uses multiple processes to speed-up writing the model to disk. This may saturate your RAM depending on the model you are exporting. -In case that happens, you can reduce the number of processes with `--processes `. Set it to 1 for minimal RAM usage. - ### 3. Build TensorRT engine(s) +The `trtllm-build` command builds TensorRT-LLM engines from TensorRT-LLM checkpoints. The checkpoint directory provides the model's weights and architecture configuration. The number of engine files is also same to the number of GPUs used to run inference. -TensorRT-LLM builds TensorRT engine(s) using a checkpoint in FT format. The checkpoint directory provides the model's weights, architecture configuration -and custom tokenizer if specified. If no checkpoint directories are specified, TensorRT-LLM will build engine(s) using random weights. When building with -random weights, you can use command-line arguments to modify the architecture: `--n_layer, --n_head, --n_embd, --hidden_act, --no_bias, ...` -Also, note that the number of TensorRT engines depends on the number of GPUs that will be used to run inference. - -The [`build.py`](./build.py) script requires a single GPU to build the TensorRT engine(s). However, if you have more than one GPU in your system (of the same -model), you can enable parallel builds to accelerate the engine building process. 
For that, add the `--parallel_build` argument to the build command. Please -note that for the moment, the `parallel_build` feature cannot take advantage of more than a single node. - -Examples of build invocations: +Normally, the `trtllm-build` command only requires a single GPU, but you can enable parallel building by passing the number of GPUs to the `--workers` argument. ```bash -# Build a single-GPU float16 engine using FT weights. -# Enable the special TensorRT-LLM GPT Attention plugin (--use_gpt_attention_plugin) to increase runtime performance. -# It is recommend to use --remove_input_padding along with --use_gpt_attention_plugin for better performance -python3 build.py --model_dir=./c-model/gpt2/1-gpu --use_gpt_attention_plugin --remove_input_padding - -# Build 8-GPU GPT-175B float16 engines using dummy weights, useful for performance tests. -# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time. -python3 build.py --world_size=8 \ - --log_level=verbose \ - --n_layer=96 \ - --n_embd=12288 \ - --n_head=96 \ - --max_batch_size=256 \ - --dtype float16 \ - --remove_input_padding \ - --use_gpt_attention_plugin \ - --enable_context_fmha \ - --use_gemm_plugin \ - --output_dir=gpt_175b 2>&1 | tee build.log +# Build a single-GPU float16 engine from TensorRT-LLM checkpoint. +# Enable the special TensorRT-LLM GPT Attention plugin (--gpt_attention_plugin) to increase runtime performance. +# It is recommend to use --remove_input_padding along with --gpt_attention_plugin for better performance +trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/1-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --output_dir gpt2/trt_engines/fp16/1-gpu + +# Build 2-way tensor parallelism engines from TensorRT-LLM checkpoint. +trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/2-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --output_dir gpt2/trt_engines/fp16/2-gpu +``` -# Build 16-GPU GPT-530B float16 engines using dummy weights, useful for performance tests. -# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time. -python3 build.py --world_size=16 \ - --log_level=info \ - --n_layer=105 \ - --n_embd=20480 \ - --n_head=128 \ - --max_batch_size=128 \ - --max_input_len=128 \ - --max_output_len=20 \ - --dtype float16 \ - --remove_input_padding \ - --use_gpt_attention_plugin \ - --enable_context_fmha \ - --use_gemm_plugin \ - --output_dir=gpt_530b 2>&1 | tee build.log +If the engines are built successfully, you will see output like: +``` +...... +[03/12/2024-10:21:08] [TRT] [I] Engine generation completed in 35.9738 seconds. +[03/12/2024-10:21:08] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 212 MiB, GPU 775 MiB +[03/12/2024-10:21:08] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +775, now: CPU 0, GPU 775 (MiB) +[03/12/2024-10:21:09] [TRT] [I] [MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 6600 MiB +[03/12/2024-10:21:09] [TRT-LLM] [I] Total time of building Unnamed Network 0: 00:00:36 +[03/12/2024-10:21:09] [TRT-LLM] [I] Serializing engine to gpt2/trt_engines/fp16/1-gpu/rank0.engine... +[03/12/2024-10:21:11] [TRT-LLM] [I] Engine serialized. 
Total time: 00:00:02 +[03/12/2024-10:21:11] [TRT-LLM] [I] Total time of building all engines: 00:00:41 ``` #### Fused MultiHead Attention (FMHA) -You can enable the FMHA kernels for GPT by adding `--enable_context_fmha` to the invocation of `build.py`. +You can enable the FMHA kernels by adding `--context_fmha enable` to the invocation of `trtllm-build`. -If you find that the default fp16 accumulation (`--enable_context_fmha`) cannot meet the requirement, you can try to enable fp32 accumulation by adding `--enable_context_fmha_fp32_acc`. However, it is expected to see performance drop. +If you find that the default fp16 accumulation (`--context_fmha enable`) cannot meet the requirement, you can try to enable fp32 accumulation by adding `--context_fmha_fp32_acc enable`. However, it is expected to see performance drop. -Note `--enable_context_fmha` / `--enable_context_fmha_fp32_acc` has to be used together with `--use_gpt_attention_plugin float16`. +Note that the FMHA kernels have to be used together with `--gpt_attention_plugin float16`. #### In-flight batching and paged KV cache -If one wants to use [in-flight batching in C++ runtime](../../docs/in_flight_batching.md), the engine must be built accordingly. -In-flight batching is enabled by adding `--use_inflight_batching` to the invocation of `build.py`. -Note that in-flight batching in C++ runtime works only with attention plugin `--use_gpt_attention_plugin=float16`, paged KV cache `--paged_kv_cache` and with packed data `--remove_input_padding`. -Adding `--use_inflight_batching` will enable these three flags if not already enabled. It is possible to choose a different precision for `--use_gpt_attention_plugin` if the flag is provided separately. -One can additionally control the size of the block in paged KV cache using `--tokens_per_block=N`. +If one wants to use [in-flight batching in C++ runtime](../../docs/in_flight_batching.md), the engine must be built accordingly. In-flight batching in C++ runtime works only with attention plugin, paged KV cache and with packed data. Hence, the `trtllm-build` should be called with `--gpt_attention_plugin float16`, `--paged_kv_cache enable`, `--remove_input_padding enable`. It is possible to choose a different precision for `--gpt_attention_plugin` if the flag is provided separately. One can additionally control the size of the block in paged KV cache using `--tokens_per_block=N`. + +### 4. Build TensorRT engine(s) with Random Weights +You can build engine(s) using random weights, which is useful for benchmarking. First, the [`../generate_checkpoint_config.py`](../generate_checkpoint_config.py) script can be used to generate a TensorRT-LLM checkpoint config file: -### 4. Run +```bash +# Generate an 8-GPU GPT-175B float16 checkpoint config file. +python3 ../generate_checkpoint_config.py --architecture GPTForCausalLM \ + --vocab_size 51200 \ + --hidden_size 12288 \ + --num_hidden_layers 96 \ + --num_attention_heads 96 \ + --dtype float16 \ + --tp_size 8 \ + --output_path gpt_175b/trt_ckpt/fp16/8-gpu/config.json + + +# Generate a 16-GPU GPT-530B float16 checkpoint config file. +python3 ../generate_checkpoint_config.py --architecture GPTForCausalLM \ + --vocab_size 51200 \ + --hidden_size 20480 \ + --num_hidden_layers 105 \ + --num_attention_heads 128 \ + --dtype float16 \ + --tp_size 16 \ + --output_path gpt_530b/trt_ckpt/fp16/16-gpu/config.json +``` + +Then, use `trtllm-build` command to build engine(s) with random weights and the model architecture specified by the generated config file. 
+ +```bash +# Build 8-GPU GPT-175B float16 engines using dummy weights, useful for performance tests. +# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time. +trtllm-build --model_config gpt_175b/trt_ckpt/fp16/8-gpu/config.json \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --context_fmha enable \ + --gemm_plugin float16 \ + --max_batch_size 256 \ + --output_dir gpt_175b/trt_engines/fp16/8-gpu \ + --workers 8 +# Build 16-GPU GPT-530B float16 engines using dummy weights, useful for performance tests. +# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time. +trtllm-build --model_config gpt_530b/trt_ckpt/fp16/16-gpu/config.json \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --context_fmha enable \ + --gemm_plugin float16 \ + --max_batch_size 128 \ + --max_input_len 128 \ + --max_output_len 20 \ + --output_dir gpt_530b/trt_engines/fp16/16-gpu \ + --workers 8 +``` +### 5. Run inference #### Single node, single GPU -To run a TensorRT-LLM GPT model on a single GPU, you can use `python3`: +The [`../run.py`](../run.py) script can be used to run inference with the built engine(s). ```bash -# Run the GPT-350M model on a single GPU. -python3 ../run.py --max_output_len=8 --no_add_special_tokens +python3 ../run.py --engine_dir gpt2/trt_engines/fp16/1-gpu \ + --tokenizer_dir gpt2 \ + --max_output_len 8 +``` + +If the engines are run successfully, you will see output like: +``` +...... +Input [Text 0]: "Born in north-east France, Soyer trained as a" +Output [Text 0 Beam 0]: " chef before moving to London in the early" +``` + +The [`../summarize.py`](../summarize.py) script can run the built engines to summarize the articles from the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset. +For each summary, the script can compute the +[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) scores and use the `ROUGE-1` score to validate the implementation. +By passing `--test_trt_llm` flag, the script will evaluate TensorRT-LLM engines. You may also pass `--test_hf` flag to evaluate the HF model. + +```bash +python3 ../summarize.py --engine_dir gpt2/trt_engines/fp16/1-gpu \ + --hf_model_dir gpt2 \ + --test_trt_llm \ + --test_hf +``` +If the engines are run successfully, you will see output like: +``` +...... 
+[03/13/2024-05:43:18] [TRT-LLM] [I] TensorRT-LLM (total latency: 1.520904541015625 sec) +[03/13/2024-05:43:18] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 0) +[03/13/2024-05:43:18] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 0.0) +[03/13/2024-05:43:18] [TRT-LLM] [I] TensorRT-LLM beam 0 result +[03/13/2024-05:43:18] [TRT-LLM] [I] rouge1 : 21.13474087351942 +[03/13/2024-05:43:18] [TRT-LLM] [I] rouge2 : 6.2641616526063775 +[03/13/2024-05:43:18] [TRT-LLM] [I] rougeL : 16.693574311238077 +[03/13/2024-05:43:18] [TRT-LLM] [I] rougeLsum : 18.477384201634088 +[03/13/2024-05:43:18] [TRT-LLM] [I] Hugging Face (total latency: 8.76440143585205 sec) +[03/13/2024-05:43:18] [TRT-LLM] [I] HF beam 0 result +[03/13/2024-05:43:18] [TRT-LLM] [I] rouge1 : 20.834898522466 +[03/13/2024-05:43:18] [TRT-LLM] [I] rouge2 : 5.6914719275508805 +[03/13/2024-05:43:18] [TRT-LLM] [I] rougeL : 16.297064309934132 +[03/13/2024-05:43:18] [TRT-LLM] [I] rougeLsum : 18.018627021792142 ``` #### Single node, multiple GPUs -To run a model using multiple GPUs on a single node, you can use `mpirun` as: +To run engines using multiple GPUs on a single node, you can use `mpirun` as: ```bash -# Run the GPT-175B model on a single node using multiple GPUs. -mpirun -np 8 python3 ../run.py --max_output_len=8 --engine_dir=gpt_175b --no_add_special_tokens +mpirun -np 2 \ + python3 ../run.py --engine_dir gpt2/trt_engines/fp16/2-gpu \ + --tokenizer_dir gpt2 \ + --max_output_len 8 + +# Note that GPT-175B is built with random weights, so the output will also be random +mpirun -np 8 \ + python3 ../run.py --engine_dir gpt_175b/trt_engines/fp16/8-gpu \ + --max_output_len 8 ``` #### Multiple nodes, multiple GPUs using [Slurm](https://slurm.schedmd.com) -To run a model using multiple nodes, you should use a cluster manager like `Slurm`. The following section shows how to configure -TensorRT-LLM to execute on two nodes using Slurm. +To run engines using multiple nodes, you should use a cluster manager like `Slurm`. The following section shows how to configure TensorRT-LLM to execute on two nodes using Slurm. -We start by preparing an `sbatch` script called `tensorrt_llm_run.sub`. That script contains the following code (you must replace -the `` strings with your own values): +We start by preparing an `sbatch` script called `tensorrt_llm_run.sub`. That script contains the following code (you must replace the `` strings with your own values): ```bash #!/bin/bash @@ -165,11 +235,12 @@ the `` strings with your own values): sudo nvidia-smi -lgc 1410,1410 srun --mpi=pmix \ - --container-image \ - --container-mounts : \ - --container-workdir \ - --output logs/tensorrt_llm_%t.out \ - --error logs/tensorrt_llm_%t.error python3 -u ../run.py --max_output_len=8 --engine_dir --no_add_special_tokens + --container-image \ + --container-mounts : \ + --container-workdir \ + --output logs/tensorrt_llm_%t.out \ + --error logs/tensorrt_llm_%t.error \ + python3 -u ../run.py --engine_dir --max_output_len 8 ``` Then, submit the job using: @@ -180,110 +251,13 @@ sbatch tensorrt_llm_run.sub You might have to contact your cluster's administrator to help you customize the above script. -## GPT Variant - SantaCoder -The SantaCoder extends the existing GPT model with multi-query attention mechanism. The following example shows building a 4-GPU engine and running simple prompt to generate the implementation of `hello_world()`. - -The main differences in this example are: -1. 
In model conversion `hf_gpt_convert.py` where extra option `--model santacoder` is required to allow converting checkpoint correctly -2. In engine execution `../run.py` where `--tokenizer_dir ./santacoder` needs to be specified to decode the output ids correctly. - -```bash -git clone https://huggingface.co/bigcode/santacoder +## Quantization -python3 hf_gpt_convert.py -p 8 --model santacoder -i ./santacoder -o ./c-model/santacoder --tensor-parallelism 4 --storage-type float16 - -python3 build.py \ - --model_dir ./c-model/santacoder/4-gpu \ - --remove_input_padding \ - --use_gpt_attention_plugin \ - --enable_context_fmha \ - --use_gemm_plugin \ - --parallel_build \ - --output_dir santacoder_outputs_tp4 \ - --world_size 4 - -mpirun -np 4 python3 ../run.py --engine_dir santacoder_outputs_tp4 --tokenizer_dir ./santacoder --input_text "def print_hello_world():" --max_output_len 20 --no_add_special_tokens -``` - -## GPT Variant - StarCoder (v1 and v2) - -For StarCoder, the steps are similar except that `santacoder` is swapped with `starcoder`. - -```bash -git clone https://huggingface.co/bigcode/starcoder - -python3 hf_gpt_convert.py -p 8 --model starcoder -i ./starcoder -o ./c-model/starcoder --tensor-parallelism 4 --storage-type float16 - -python3 build.py \ - --model_dir ./c-model/starcoder/4-gpu \ - --remove_input_padding \ - --use_gpt_attention_plugin \ - --enable_context_fmha \ - --use_gemm_plugin \ - --parallel_build \ - --output_dir starcoder_outputs_tp4 \ - --world_size 4 - -mpirun -np 4 python3 ../run.py --engine_dir starcoder_outputs_tp4 --tokenizer_dir ./starcoder --input_text "def print_hello_world():" --max_output_len 20 --no_add_special_tokens -``` - -For StarCoder2, you can use almost the same steps as shown above by just setting `--model starcoder2` when converting the huggingface models. - - Note that StarCoder2 hasn't been merged to the official releases of transformers package yet, so remember using the [main branch of transformers repo](https://github.com/huggingface/transformers). - - Add `--max_attention_window_size 4096` when running with run.py or summarization, which enables the sliding window attention. - - the sliding window size comes from the hf model [config.json](https://huggingface.co/bigcode/starcoder2-15b/blob/main/config.json#L23). - -## Summarization using the GPT model - -The following section describes how to run a TensorRT-LLM GPT model to summarize the articles from the -[cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset. For each summary, the script can compute the -[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) scores and use the `ROUGE-1` score to validate the implementation. -The script can also perform the same summarization using the HF GPT model. - -As previously explained, the first step is to convert from an HF checkpoint and build the TensorRT engines. - -```bash -# Load the GPT2 weights from the HF hub. -pip install -r requirements.txt -rm -rf gpt2 && git clone https://huggingface.co/gpt2 -pushd gpt2 && rm pytorch_model.bin model.safetensors && wget -q https://huggingface.co/gpt2/resolve/main/pytorch_model.bin && popd - -# Convert the weights to FT format. -python3 hf_gpt_convert.py -i gpt2 -o ./c-model/gpt2/fp16 --tensor-parallelism 1 --storage-type float16 - -# Build the model. 
-python3 build.py --model_dir=./c-model/gpt2/fp16/1-gpu \ - --remove_input_padding \ - --use_gpt_attention_plugin \ - --enable_context_fmha \ - --use_gemm_plugin \ - --max_batch_size 8 \ - --max_input_len 924 \ - --max_output_len 100 \ - --output_dir trt_engine/gpt2/fp16/1-gpu/ \ - --hidden_act gelu -``` - -The summarization can be done using the [`../summarize.py`](../summarize.py) script as follows: - -```bash -# Run the summarization task. -python3 ../summarize.py --engine_dir trt_engine/gpt2/fp16/1-gpu \ - --hf_model_dir gpt2 \ - --test_trt_llm \ - --test_hf \ - --batch_size 1 \ - --check_accuracy \ - --tensorrt_llm_rouge1_threshold=14 \ - --no_add_special_tokens -``` - -## SmoothQuant +### SmoothQuant This section explains how to use SmoothQuant on GPT models with TensorRT-LLM. -### Overview - [SmoothQuant](https://arxiv.org/abs/2211.10438) is a post-training quantization (PTQ) method to quantize LLM models to INT8 for faster inference. As explained in the article, SmoothQuant modifies a model to enable INT8 quantization @@ -346,251 +320,383 @@ TensorRT-LLM also supports more elaborate modes: Users can mix-and-match per-channel and per-token options. Both tend to increase the accuracy of the model at the cost of a slightly increased latency. -### Usage - -#### SmoothQuant a HF model, export weights & scales for TensorRT-LLM - -For SmoothQuant, [`hf_gpt_convert.py`](./hf_gpt_convert.py) features a -`--smoothquant, -sq` option. It must be set to a decimal value in `[0, 1]` and +#### Usage +[`convert_checkpoint.py`](./convert_checkpoint.py) features a +`--smoothquant` option. It must be set to a decimal value in `[0, 1]` and corresponds to the `alpha` parameter in the [SmoothQuant -paper](https://arxiv.org/abs/2211.10438). Setting `-sq` will smooth the model +paper](https://arxiv.org/abs/2211.10438). Setting `--smoothquant` will smooth the model as explained in [model transformation](#model-transformation) and export the scaling factors needed for INT8 inference. -Example: +By default, it will run the model in the per-tensor mode, as explained in [INT8 +inference](#int8-inference). You can add any combination of `--per_token` and `--per_channel` to get the corresponding behaviors. + ```bash -python3 hf_gpt_convert.py -i gpt2 -o ./c-model/gpt2-smooth --smoothquant 0.5 -t float16 +# Per-tensor SmoothQuant +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --smoothquant 0.5 \ + --output_dir gpt2/trt_ckpt/int8-sq/1-gpu + +# Per-token per-channel SmoothQuant +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --smoothquant 0.5 \ + --per_token \ + --per_channel \ + --output_dir gpt2/trt_ckpt/int8-sq-ptpc/1-gpu ``` -#### Build TensorRT engine(s) +Then, use `trtllm-build` to build engine(s). + +```bash +# Per-tensor SmoothQuant +trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8-sq/1-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --output_dir gpt2/trt_engines/int8-sq/1-gpu + +# Per-token per-channel SmoothQuant +trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8-sq-ptpc/1-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --output_dir gpt2/trt_engines/int8-sq-ptpc/1-gpu +``` -[`build.py`](./build.py) add new options for the support of INT8 inference of SmoothQuant models. +Note that GPT attention plugin is required to be enabled for SmoothQuant for now. -`--use_smooth_quant` is the starting point of INT8 inference. 
By default, it -will run the model in the _per-tensor_ mode, as explained in [INT8 -inference](#int8-inference). -Then, you can add any combination of `--per-token` and `--per-channel` to get the corresponding behaviors. +### INT8 KV Cache -Examples of build invocations: +[`convert_checkpoint.py`](./convert_checkpoint.py) features a +`--int8_kv_cache` option. Setting `--int8_kv_cache` will calibrate the model and export the +scaling factors needed for INT8 KV cache inference. ```bash -# Build model for SmoothQuant in the _per_tensor_ mode. -python3 build.py --model_dir=./c-model/gpt2-smooth/1-gpu \ - --use_gpt_attention_plugin \ - --use_smooth_quant - -# Build model for SmoothQuant in the _per_token_ + _per_channel_ mode -python3 build.py --model_dir=./c-model/gpt2-smooth/1-gpu \ - --use_gpt_attention_plugin \ - --use_smooth_quant \ - --per_token \ - --per_channel +# Int8 KV cache +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --int8_kv_cache \ + --output_dir gpt2/trt_ckpt/int8kv/1-gpu + +trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8kv/1-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --strongly_typed \ + --output_dir gpt2/trt_engines/int8kv/1-gpu ``` -Note that GPT attention plugin is required to be enabled for SmoothQuant for now. -### INT8 KV Cache, export weights & scales for TensorRT-LLM +INT8 KV cache can be used with or without gpt attention plugin. -For int8 kv cache, [`hf_gpt_convert.py`](./hf_gpt_convert.py) features a -`--calibrate-kv-cache, -kv` option. Setting `-kv` will calibrate the model as -explained in [model transformation](#model-transformation) and export the -scaling factors needed for INT8 KV cache inference. +### Weight Only Quantization -Example: +[`convert_checkpoint.py`](./convert_checkpoint.py) features a `--use_weight_only` option that can enable weight-only quantization. You can further set the weight-only precision by passing `int8` or `int4` to the `--weight_only_precision` flag. ```bash -python3 hf_gpt_convert.py -i gpt2 -o ./c-model/gpt2 --calibrate-kv-cache -t float16 +# Int8 weight-only quantization +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --use_weight_only \ + --weight_only_precision int8 \ + --output_dir gpt2/trt_ckpt/int8-wo/1-gpu + +# Int4 weight-only quantization +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --use_weight_only \ + --weight_only_precision int4 \ + --output_dir gpt2/trt_ckpt/int4-wo/1-gpu ``` -#### Build TensorRT engine(s) - -[`build.py`](./build.py) add new options for the support of INT8 kv cache for models. -`--int8_kv_cache` forces KV cache to int8. INT8 KV cache can be used with or without gpt attention plugin. -Examples of build invocations: +Then, use `trtllm-build` to build engine(s). ```bash -# Build model for GPT with int8 kv cache. 
-python3 build.py --model_dir=./c-model/gpt2/1-gpu \ - --int8_kv_cache --remove_input_padding --use_gpt_attention_plugin float16 +# Int8 weight-only quantization +trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8-wo/1-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --output_dir gpt2/trt_engines/int8-wo/1-gpu + +# Int4 weight-only quantization +trtllm-build --checkpoint_dir gpt2/trt_ckpt/int4-wo/1-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --output_dir gpt2/trt_engines/int4-wo/1-gpu ``` -Example of build invocations without gpt attention plugin +### FP8 Quantization + +[`../quantization/quantize.py`](../quantization/quantize.py) can do FP8 quantization and/or FP8 kv cache quantization, and export TensorRT-LLM checkpoint. + ```bash -python3 build.py --model_dir=./c-model/gpt2/1-gpu --int8_kv_cache +# FP8 quantization with FP8 kv cache +python3 ../quantization/quantize.py --model_dir gpt2 \ + --dtype float16 \ + --qformat fp8 \ + --kv_cache_dtype fp8 \ + --output_dir gpt2/trt_ckpt/fp8/1-gpu + +trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp8/1-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --strongly_typed \ + --output_dir gpt2/trt_engines/fp8/1-gpu ``` -## GPT-Next +## Embedding Parallelism and Sharing +Since the embedding lookup table can be several gigabytes in size. We can distribute this weight across multiple GPUs in order to reduce the memory consumption per GPU. -NVIDIA has released a GPT-like model with some architectural improvements, that you can find here: [https://huggingface.co/nvidia/GPT-2B-001](https://huggingface.co/nvidia/GPT-2B-001) -This architecture is also supported by TensorRT-LLM +### 1. Embedding parallelism +To enable this feature, add the flag `--use_parallel_embedding` to `convert_checkpoint.py`. -### 1. Download weights from HuggingFace Transformers +### 2. The sharding dimension for embedding parallelism + +Assume the size of embedding lookup table is (vocab\_size \* hidden\_size), we can shard it along the vocab\_size (`--embedding_sharding_dim 0`) or hidden\_size (`--embedding_sharding_dim 1`) dimension. + +2.1 To shard the embedding lookup table along the hidden\_size dimension, set the flag `--use_parallel_embedding --embedding_sharding_dim 1`. Here is an example: ```bash -wget https://huggingface.co/nvidia/GPT-2B-001/resolve/main/GPT-2B-001_bf16_tp1.nemo +# 2-way tensor parallelism with embedding parallelism along hidden dimension +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --tp_size 2 \ + --use_parallel_embedding \ + --embedding_sharding_dim 1 \ + --output_dir gpt2/trt_ckpt/fp16/2-gpu + +trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/2-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --output_dir gpt2/trt_engines/fp16/2-gpu ``` -### 2. Convert weights from NeMo Checkpoint to FT format - -TensorRT-LLM can convert `.nemo` to generic binary files with [`nemo_ckpt_convert.py`](./nemo_ckpt_convert.py) script. For example: +2.2 To shard the embedding lookup table along the vocab\_size dimension, set the flag `--use_parallel_embedding --embedding_sharding_dim 0`. In this case, you can optionally enable the lookup plugin when building the engines. 
```bash -python3 nemo_ckpt_convert.py -i GPT-2B-001_bf16_tp1.nemo -o ./c-model/gpt-next-2B --tensor-parallelism 1 --storage-type bfloat16 +# 2-way tensor parallelism with embedding parallelism along vocab dimension +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --tp_size 2 \ + --use_parallel_embedding \ + --embedding_sharding_dim 0 \ + --output_dir gpt2/trt_ckpt/fp16/2-gpu + +# It is optional to add --lookup_plugin +trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/2-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --lookup_plugin float16 \ + --output_dir gpt2/trt_engines/fp16/2-gpu ``` -### 3. Build TensorRT engine(s) +### 3. Embedding sharing +In some models, the embedding weight is used in both the embedding layer and lm_head (language modeling head) layer. In this case, sharing the embedding weight can reduce memory consumption. + +With flag `--use_embedding_sharing` for `convert_checkpoint.py`, we will try to enable this feature. However it only takes effect when the following criteria are met: +- The embedding weight is shared between the embedding and lm_head layers. If not, we should not enable this feature. +- For tensor parallelism cases, `--use_parallel_embedding --embedding_sharding_dim 0` must be set. In other words, we must enable embedding parallelism along the vocab dimension, which minimizes the overall communication cost. +- For TensorRT 9.0 version, the engine size is expected to be reduced when the lookup and gemm plugin are enabled. + +Here is an example for using embedding parallelism and sharing feature: -```bash -# Build a single-GPU bfloat16 engine using FT weights. -# --use_gpt_attention_plugin must be set for GPT-Next since Rotary positional embeddings (RoPE) is only supported by the gpt attention plugin at this time. -python3 build.py --model_dir=./c-model/gpt-next-2B/1-gpu \ - --dtype bfloat16 \ - --remove_input_padding \ - --use_gpt_attention_plugin - -# Build GPT-Next architecture engines using dummy weights, useful for performance tests. -# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time. -python3 build.py --vocab_size=256000 \ - --n_layer=24 \ - --n_embd=2048 \ - --n_head=16 \ - --max_batch_size=256 \ - --dtype float16 \ - --no_bias \ - --hidden_act swiglu \ - --rotary_pct 0.5 \ - --remove_input_padding \ - --use_gpt_attention_plugin \ - --use_gemm_plugin \ - --output_dir=gpt-next-2B -``` - -### 4. Run ```bash -# Run the GPT-Next model on a single GPU. Use custom tokenizer. -python3 ../run.py --max_output_len=8 \ - --vocab_file=./c-model/gpt-next-2B/1-gpu/tokenizer.model \ - --no_add_special_tokens +# 2-way tensor parallelism with embedding sharing +# It requires enabling embedding parallelism along vocab dimension +python3 convert_checkpoint.py --model_dir gpt2 \ + --dtype float16 \ + --tp_size 2 \ + --use_embedding_sharing \ + --use_parallel_embedding \ + --embedding_sharding_dim 0 \ + --output_dir gpt2/trt_ckpt/fp16/2-gpu + +# It is recommended to add --lookup_plugin and --gemm_plugin +trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/2-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --lookup_plugin float16 \ + --gemm_plugin float16 \ + --output_dir gpt2/trt_engines/fp16/2-gpu ``` -## Prompt-tuning -For efficient fine-tuning, the NeMo framework allows you to learn virtual tokens to accomplish a downstream task. 
For more details, please read the -NeMo documentation [here](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/nemo_megatron/prompt_learning.html). +## GPT Variant - SantaCoder -TensorRT-LLM supports inference with those virtual tokens. To enable it, pass the prompt embedding table's maximum size at build time with -`--max_prompt_embedding_table_size N`. For example: -```bash -# Build a GPT-Next model with prompt-tuning enabled -python3 build.py --model_dir=./c-model/gpt-next-8B/1-gpu --remove_input_padding --use_gpt_attention_plugin --max_prompt_embedding_table_size 100 -``` +The SantaCoder extends the existing GPT model with multi-query attention mechanism. The following example shows building a 4-GPU engine and running simple prompt to generate the implementation of `print_hello_world()`. -You can now export the learned embedding table with: ```bash -python3 nemo_prompt_convert.py -i email_composition.nemo -o email_composition.npy -``` -It'll give you a summary of the different tasks in the table, that you can specify at runtime. +# Download hf santacoder model +git clone https://huggingface.co/bigcode/santacoder -Finally, you can run inference on pre-defined tokens: -```bash -python3 ../run.py --input_file input.csv --prompt_table email_composition.npy --tasks 0 --max_output_len=8 --vocab_file=./c-model/gpt-next-8B/1-gpu/tokenizer.model --no_add_special_tokens +# Convert to TensorRT-LLM checkpoint +python3 convert_checkpoint.py --model_dir santacoder \ + --dtype float16 \ + --tp_size 4 \ + --output_dir santacoder/trt_ckpt/fp16/4-gpu + +# Build TensorRT-LLM engines +trtllm-build --checkpoint_dir santacoder/trt_ckpt/fp16/4-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --context_fmha enable \ + --gemm_plugin float16 \ + --output_dir santacoder/trt_engines/fp16/4-gpu + +# Run inference +mpirun -np 4 \ + python3 ../run.py --engine_dir santacoder/trt_engines/fp16/4-gpu \ + --tokenizer_dir santacoder \ + --input_text "def print_hello_world():" \ + --max_output_len 20 ``` -## Tensor Parallelism for Embedding Lookup Table. -Since the embedding lookup table can be several gigabytes in size. We can distribute this weight across multiple GPUs in order to reduce the memory consumption per GPU. -### 1. Enable this feature -To enable this feature, add the flag `--use_parallel_embedding` to `build.py`. +## GPT Variant - StarCoder (v1 and v2) -### 2. Choose the dimension for tensor parallelism +For StarCoder, the steps are similar to SantaCoder. -Assume the size of embedding lookup table is (vocab\_size \* hidden\_size), we can shard it along the vocab\_size (`--embedding_sharding_dim 0`) or hidden\_size (`--embedding_sharding_dim 1`) dimension. +```bash +# Download hf starcoder model +git clone https://huggingface.co/bigcode/starcoder -2.1 To shard the embedding lookup table along the hidden\_size dimension, set the flag `--use_parallel_embedding --embedding_sharding_dim 1`. 
Here is an example: -```Bash -python3 build.py --model_dir=./c-model/gpt2/2-gpu --dtype float16 --world_size=2 --remove_input_padding --use_gpt_attention_plugin float16 --parallel_build --max_input_len 1000 \ - --use_parallel_embedding --embedding_sharding_dim 1 \ - --output_dir=trt_engine/gpt2/float16/2-gpu +# Convert to TensorRT-LLM checkpoint +python3 convert_checkpoint.py --model_dir starcoder \ + --dtype float16 \ + --tp_size 4 \ + --output_dir starcoder/trt_ckpt/fp16/4-gpu + +# Build TensorRT-LLM engines +trtllm-build --checkpoint_dir starcoder/trt_ckpt/fp16/4-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --context_fmha enable \ + --gemm_plugin float16 \ + --output_dir starcoder/trt_engines/fp16/4-gpu + +# Run inference +mpirun -np 4 \ + python3 ../run.py --engine_dir starcoder/trt_engines/fp16/4-gpu \ + --tokenizer_dir starcoder \ + --input_text "def print_hello_world():" \ + --max_output_len 20 ``` -2.2 To shard the embedding lookup table along the vocab\_size dimension, set the flag `--use_parallel_embedding --embedding_sharding_dim 0`. -Meanwhile, we provide a lookup plugin to support tensor parallelism on vocab\_size dimension. +For StarCoder2, you can use almost the same steps as shown above. + - Note that StarCoder2 hasn't been merged to the official releases of transformers package yet, so remember using the [main branch of transformers repo](https://github.com/huggingface/transformers). + - Add `--max_attention_window_size 4096` when running with run.py or summarization, which enables the sliding window attention. + - the sliding window size comes from the hf model [config.json](https://huggingface.co/bigcode/starcoder2-15b/blob/main/config.json#L23). -- An example of sharing along vocab\_size dimension with lookup plugin: -```Bash -python3 build.py --model_dir=./c-model/gpt2/2-gpu --dtype float16 --world_size=2 --remove_input_padding --use_gpt_attention_plugin float16 --parallel_build --max_input_len 1000 \ - --use_parallel_embedding --embedding_sharding_dim 0 --use_lookup_plugin float16 \ - --output_dir=trt_engine/gpt2/float16/2-gpu -``` -- An example of sharing along vocab\_size dimension without lookup plugin: -```Bash -python3 build.py --model_dir=./c-model/gpt2/2-gpu --dtype float16 --world_size=2 --remove_input_padding --use_gpt_attention_plugin float16 --parallel_build --max_input_len 1000 \ - --use_parallel_embedding --embedding_sharding_dim 0 \ - --output_dir=trt_engine/gpt2/float16/2-gpu +## GPT-Next + +NVIDIA has released a GPT-like model with some architectural improvements, that you can find here: [https://huggingface.co/nvidia/GPT-2B-001](https://huggingface.co/nvidia/GPT-2B-001). This architecture is also supported by TensorRT-LLM. + +Different from Huggingface's checkpoint, you should specify the NeMo checkpoint path using `--nemo_ckpt_path` for `convert_checkpoint.py`. The script also extracts the tokenizer file from the NeMo checkpoint and saves it to the TensorRT-LLM checkpoint folder, which can be used in the inference scripts. 
+ +```bash +# Download NeMo checkpoint +wget https://huggingface.co/nvidia/GPT-2B-001/resolve/main/GPT-2B-001_bf16_tp1.nemo + +# Convert to TensorRT-LLM checkpoint +# It also extracts the tokenizer file and saves to the TensorRT-LLM checkpoint folder +python3 convert_checkpoint.py --nemo_ckpt_path GPT-2B-001_bf16_tp1.nemo \ + --dtype bfloat16 \ + --output_dir gpt-next-2B/trt_ckpt/bf16/1-gpu + +# Build TensorRT-LLM engines +# --gpt_attention_plugin must be set for GPT-Next since Rotary positional embeddings (RoPE) is only supported by the gpt attention plugin at this time. +trtllm-build --checkpoint_dir gpt-next-2B/trt_ckpt/bf16/1-gpu \ + --gpt_attention_plugin bfloat16 \ + --remove_input_padding enable \ + --output_dir gpt-next-2B/trt_engines/bf16/1-gpu + +# Run inference +python3 ../run.py --engine_dir gpt-next-2B/trt_engines/bf16/1-gpu \ + --vocab_file gpt-next-2B/trt_ckpt/bf16/1-gpu/tokenizer.model \ + --no_add_special_tokens \ + --max_output_len 8 ``` -### 3. Embedding sharing -In some examples, the embedding lookup table is used both in embedding() and lm_head() layers. Sharing the embedding lookup table can reduce memory consumption. -With flag `--use_embedding_sharing` for `build.py`, we will try to enable this feature. However it only takes effect when the following criteria are met: -- The weight is shared between two layers. If we found the weight for lm_head() layer, we cannot enable it. -- For multiple processes case, `--use_parallel_embedding` must be set. And we only support sharing when the embedding lookup table is sharded along the vocab dimension (`--embedding_sharding_dim 0`, as is the default value), which minimizes the overall communication cost. -- For TensorRT 9.0 version, the engine size is expected to be reduced when the lookup and gemm plugin are enabled. +### Prompt-tuning -Here is an example for using embedding parallelism and sharing feature: -```Bash -python3 hf_gpt_convert.py -i gpt2 -o ./c-model/gpt2 --tensor-parallelism 2 --storage-type bfloat16 +For efficient fine-tuning, the NeMo framework allows you to learn virtual tokens to accomplish a downstream task. For more details, please read the +NeMo documentation [here](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/nemo_megatron/prompt_learning.html). + +TensorRT-LLM supports inference with those virtual tokens. To enable it, pass the prompt embedding table's maximum size at build time with `--max_prompt_embedding_table_size N`. 
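+Conceptually, the virtual tokens behave like extra token ids appended after the regular vocabulary: ordinary ids are looked up in the vocabulary embedding, while ids at or beyond `vocab_size` select rows of the prompt embedding table supplied at runtime. The snippet below is an illustrative sketch of that idea with made-up sizes, not the actual TensorRT-LLM implementation.
+
+```python
+import torch
+
+vocab_size, hidden_size, table_rows = 256000, 2048, 100  # table_rows <= --max_prompt_embedding_table_size
+vocab_embedding = torch.randn(vocab_size, hidden_size)
+prompt_table = torch.randn(table_rows, hidden_size)      # learned virtual-token embeddings (the .npy file)
+
+def embed(token_ids: torch.Tensor) -> torch.Tensor:
+    # Ids below vocab_size hit the normal embedding; the rest index the prompt table.
+    virtual = token_ids >= vocab_size
+    out = torch.empty(*token_ids.shape, hidden_size)
+    out[~virtual] = vocab_embedding[token_ids[~virtual]]
+    out[virtual] = prompt_table[token_ids[virtual] - vocab_size]
+    return out
+```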
For example: -python3 build.py --model_dir=./c-model/gpt2/2-gpu --dtype bfloat16 --world_size=2 --remove_input_padding --use_gpt_attention_plugin --use_gemm_plugin --parallel_build --max_input_len 1000 --use_parallel_embedding --embedding_sharding_dim 0 --use_lookup_plugin --use_embedding_sharing --output_dir=trt_engine/gpt2/bfloat16/2-gpu +```bash +# Convert to TensorRT-LLM checkpoint +python3 convert_checkpoint.py --nemo_ckpt_path megatron_converted_8b_tp4_pp1.nemo \ + --dtype float16 \ + --output_dir gpt-next-8B/trt_ckpt/fp16/1-gpu + +# Build TensorRT-LLM engines with prompt-tuning enabled +trtllm-build --checkpoint_dir gpt-next-8B/trt_ckpt/fp16/1-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --max_prompt_embedding_table_size 100 \ + --output_dir gpt-next-8B/trt_engines/fp16/1-gpu +``` + +You can now export the learned embedding table with: +```bash +python3 nemo_prompt_convert.py -i email_composition.nemo -o email_composition.npy +``` +It'll give you a summary of the different tasks in the table, that you can specify at runtime. -mpirun -np 2 python3 ../summarize.py --engine_dir trt_engine/gpt2/bfloat16/2-gpu --hf_model_dir gpt2 --batch_size 10 --test_trt_llm --check_accuracy --tensorrt_llm_rouge1_threshold=14 --dataset_path ./dataset --no_add_special_tokens +Finally, you can run inference on pre-defined tokens: +```bash +python3 ../run.py --engine_dir gpt-next-8B/trt_engines/fp16/1-gpu \ + --vocab_file gpt-next-8B/trt_ckpt/fp16/1-gpu/tokenizer.model \ + --no_add_special_tokens \ + --prompt_table_path email_composition.npy \ + --prompt_tasks 0 \ + --max_output_len 8 ``` -### Run MultiLoRA with the Nemo checkpoint + +### MultiLoRA with the Nemo checkpoint ```bash -git clone https://huggingface.co/nvidia/GPT-2B-001 - -python3 examples/gpt/nemo_ckpt_convert.py -i GPT-2B-001_bf16_tp1.nemo -o gpt-2b-fp16-weights-tp1-pp1 -tp 1 -p 4 -t float16 - -python3 examples/gpt/build.py --model_dir=gpt-2b-fp16-weights-tp1-pp1/1-gpu \ - --dtype float16 \ - --remove_input_padding \ - --use_gpt_attention_plugin float16 \ - --use_inflight_batching \ - --paged_kv_cache \ - --output_dir gpt-2b-trt-fp16-tp1-pp1-test \ - --use_lora_plugin \ - --lora_target_modules attn_qkv \ - --max_batch_size 4 \ - --max_beam_width 2 \ - --max_input_len 512 \ - --max_output_len 50 \ - --log_level verbose +# Download NeMo checkpoint +wget https://huggingface.co/nvidia/GPT-2B-001/resolve/main/GPT-2B-001_bf16_tp1.nemo + +# Convert to TensorRT-LLM checkpoint +python3 convert_checkpoint.py --nemo_ckpt_path GPT-2B-001_bf16_tp1.nemo \ + --dtype float16 \ + --output_dir gpt-next-2B/trt_ckpt/fp16/1-gpu + +# Build TensorRT-LLM engines +trtllm-build --checkpoint_dir gpt-next-2B/trt_ckpt/fp16/1-gpu \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --lora_plugin float16 \ + --lora_dir gpt2b_lora-900.nemo gpt2b_lora-stories.nemo \ + --lora_ckpt_source "nemo" \ + --lora_target_modules attn_qkv \ + --max_batch_size 4 \ + --max_beam_width 2 \ + --max_input_len 512 \ + --max_output_len 50 \ + --output_dir gpt-next-2B/trt_engines/fp16/1-gpu # Run inference directly from NeMo LoRA checkpoint # --lora_task_ids correspond to the index of the models given with --lora_dir. 
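+# Task-uid mapping for the run command below, given the order of --lora_dir at build time:
+#   uid 0  -> gpt2b_lora-900.nemo      (first checkpoint passed to --lora_dir)
+#   uid 1  -> gpt2b_lora-stories.nemo  (second checkpoint)
+#   uid -1 -> no LoRA adapter for that input
+# so "--lora_task_uids 0 -1 1" applies a different adapter to each of the three --input_text prompts.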
-1 means no LoRA - -python3 examples/run.py --max_output_len=20 \ - --use_py_session \ - --vocab_file=gpt-2b-fp16-weights-tp1-pp1/1-gpu/tokenizer.model \ - --engine_dir gpt-2b-trt-fp16-tp1-pp1-test/ \ - --lora_dir gpt2b_lora-900.nemo gpt2b_lora-stories.nemo \ - --lora_task_uids 0 -1 1 \ - --lora_ckpt_source "nemo" \ - --no_add_special_tokens \ - --input_text "After Washington had returned to Williamsburg, Dinwiddie ordered him to lead a larger force to assist Trent in his work. While en route, Washington learned of Trent's retreat. Since Tanaghrisson had promised support to the British, Washington continued toward Fort Duquesne and met with the Mingo leader. Learning of a French scouting party in the area, Washington, with Tanaghrisson and his party, surprised the Canadians on May 28 in what became known as the Battle of Jumonville Glen. They killed many of the Canadians, including their commanding officer, Joseph Coulon de Jumonville, whose head was reportedly split open by Tanaghrisson with a tomahawk. The historian Fred Anderson suggests that Tanaghrisson was acting to gain the support of the British and regain authority over his own people. They had been inclined to support the French, with whom they had long trading relationships. One of Tanaghrisson's men told Contrecoeur that Jumonville had been killed by British musket fire. Question: Upon learning of a French scounting party in the area, what did Washington do? Answer:" "You hold the job title in the Wizarding World of Harry Potter where you say random words looking for spells" "You hold the job title in the Wizarding World of Harry Potter where you say random words looking for spells" +python3 ../run.py --engine_dir gpt-next-2B/trt_engines/fp16/1-gpu \ + --vocab_file gpt-next-2B/trt_ckpt/fp16/1-gpu/tokenizer.model \ + --no_add_special_tokens \ + --max_output_len 20 \ + --use_py_session \ + --lora_task_uids 0 -1 1 \ + --input_text "After Washington had returned to Williamsburg, Dinwiddie ordered him to lead a larger force to assist Trent in his work. While en route, Washington learned of Trent's retreat. Since Tanaghrisson had promised support to the British, Washington continued toward Fort Duquesne and met with the Mingo leader. Learning of a French scouting party in the area, Washington, with Tanaghrisson and his party, surprised the Canadians on May 28 in what became known as the Battle of Jumonville Glen. They killed many of the Canadians, including their commanding officer, Joseph Coulon de Jumonville, whose head was reportedly split open by Tanaghrisson with a tomahawk. The historian Fred Anderson suggests that Tanaghrisson was acting to gain the support of the British and regain authority over his own people. They had been inclined to support the French, with whom they had long trading relationships. One of Tanaghrisson's men told Contrecoeur that Jumonville had been killed by British musket fire. Question: Upon learning of a French scounting party in the area, what did Washington do? Answer:" "You hold the job title in the Wizarding World of Harry Potter where you say random words looking for spells" "You hold the job title in the Wizarding World of Harry Potter where you say random words looking for spells" ``` -#### Example output - -* Note that in this case the adapters have only been trained for a few epochs, so the result quality is poor. 
+The output would look like (Note that in this case the adapters have only been trained for a few epochs, so the result quality is poor): ``` -Input [Text 0]: "After Washington had returned to Williamsburg, Dinwiddie ordered him to lead a larger force to assist Trent in his work. While en route, Washington learned of Trent's retreat. Since Tanaghrisson had promised support to the British, Washington continued toward Fort Duquesne and met with the Mingo leader. Learning of a French scouting party in the area, Washington, with Tanaghrisson and his party, surprise -d the Canadians on May 28 in what became known as the Battle of Jumonville Glen. They killed many of the Canadians, including their commanding officer, Joseph Coulon de Jumonville, whose head was reportedly split open by Tanaghrisson with a tomahawk. The historian Fred Anderson suggests that Tanaghrisson was acting to gain the support of the British and regain authority over his own people. They had been inclined to support the French, with whom they had long trading relati -onships. One of Tanaghrisson's men told Contrecoeur that Jumonville had been killed by British musket fire. Question: Upon learning of a French scounting party in the area, what did Washington do? Answer:" +...... +Input [Text 0]: "After Washington had returned to Williamsburg, Dinwiddie ordered him to lead a larger force to assist Trent in his work. While en route, Washington learned of Trent's retreat. Since Tanaghrisson had promised support to the British, Washington continued toward Fort Duquesne and met with the Mingo leader. Learning of a French scouting party in the area, Washington, with Tanaghrisson and his party, surprised the Canadians on May 28 in what became known as the Battle of Jumonville Glen. They killed many of the Canadians, including their commanding officer, Joseph Coulon de Jumonville, whose head was reportedly split open by Tanaghrisson with a tomahawk. The historian Fred Anderson suggests that Tanaghrisson was acting to gain the support of the British and regain authority over his own people. They had been inclined to support the French, with whom they had long trading relationships. One of Tanaghrisson's men told Contrecoeur that Jumonville had been killed by British musket fire. Question: Upon learning of a French scounting party in the area, what did Washington do? Answer:" Output [Text 0 Beam 0]: "He surprised the Canadians on May 28 in what became known as the Battle of Jumonville" Input [Text 1]: "You hold the job title in the Wizarding World of Harry Potter where you say random words looking for spells" Output [Text 1 Beam 0]: ". diff --git a/examples/gpt/build.py b/examples/gpt/build.py deleted file mode 100644 index 6f651e98d..000000000 --- a/examples/gpt/build.py +++ /dev/null @@ -1,816 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import argparse -import math -import time -from pathlib import Path -from typing import List - -import torch -import torch.multiprocessing as mp - -import tensorrt_llm -from tensorrt_llm._common import check_max_num_tokens -from tensorrt_llm._utils import str_dtype_to_trt -from tensorrt_llm.builder import Builder -from tensorrt_llm.layers import MoeConfig, PositionEmbeddingType -from tensorrt_llm.logger import logger -from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models import quantize_model -from tensorrt_llm.network import net_guard -from tensorrt_llm.plugin.plugin import ContextFMHAType -from tensorrt_llm.profiler import check_gpt_mem_usage -from tensorrt_llm.quantization import QuantMode - -from weight import load_from_ft, parse_ft_config, check_embedding_share # isort:skip - -MODEL_NAME = "gpt" - - -def get_engine_name(model, dtype, tp_size, rank): - return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank) - - -def find_engines(dir: Path, - model_name: str = "*", - dtype: str = "*", - tp_size: str = "*", - rank: str = "*") -> List[Path]: - template = f"{model_name}_{dtype}_tp{tp_size}_rank{rank}.engine" - return list(dir.glob(template)) - - -def serialize_engine(engine, path): - logger.info(f'Serializing engine to {path}...') - tik = time.time() - with open(path, 'wb') as f: - f.write(engine) - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - logger.info(f'Engine serialized. Total time: {t}') - - -def override_args_from_model_dir(args: argparse.Namespace) -> None: - if args.model_dir is not None: - logger.info(f"Setting model configuration from {args.model_dir}.") - parsed_params = parse_ft_config(Path(args.model_dir) / "config.ini") - args.n_embd = parsed_params["n_embd"] - args.n_head = parsed_params["n_head"] - args.n_kv_head = parsed_params["n_kv_head"] - args.n_layer = parsed_params["n_layer"] - args.n_positions = parsed_params["n_positions"] - args.vocab_size = parsed_params["vocab_size"] - args.hidden_act = parsed_params["hidden_act"] - if parsed_params["rotary_pct"] is not None: - args.rotary_pct = parsed_params["rotary_pct"] - if parsed_params["rotary_base"] is not None: - args.rotary_base = parsed_params["rotary_base"] - if parsed_params["rotary_scaling"] is not None: - args.rotary_scaling = parsed_params["rotary_scaling"] - args.bias = parsed_params["bias"] - args.dtype = parsed_params["dtype"] - args.inter_size = parsed_params["inter_size"] - args.multi_query_mode = parsed_params["multi_query_mode"] - else: - args.n_kv_head = 1 if args.multi_query_mode else args.n_head - - -def parse_arguments(args): - parser = argparse.ArgumentParser() - parser.add_argument('--world_size', - type=int, - default=1, - help='world size, only support tensor parallelism now') - parser.add_argument('--model_dir', type=str, default=None) - parser.add_argument('--dtype', - type=str, - default='float16', - choices=['float16', 'float32', 'bfloat16']) - parser.add_argument('--logits_dtype', - type=str, - default='float32', - choices=['float16', 'float32']) - parser.add_argument( - '--timing_cache', - type=str, - default='model.cache', - help= - 'The path of to read timing cache from, will be ignored if the file does not exist' - ) - parser.add_argument( - '--profiling_verbosity', - type=str, - default='layer_names_only', - choices=['layer_names_only', 'detailed', 'none'], - help= - 'The profiling verbosity for the generated TRT engine. Set to detailed can inspect tactic choices and kernel parameters.' 
- ) - parser.add_argument('--log_level', type=str, default='info') - parser.add_argument('--vocab_size', type=int, default=51200) - parser.add_argument('--n_layer', type=int, default=24) - parser.add_argument('--n_positions', type=int, default=1024) - parser.add_argument('--n_embd', type=int, default=1024) - parser.add_argument('--n_head', type=int, default=16) - parser.add_argument('--hidden_act', type=str, default='gelu') - parser.add_argument('--rotary_base', type=float, default=10000.0) - parser.add_argument('--rotary_scaling', nargs=2, type=str, default=None) - parser.add_argument( - '--rotary_pct', - type=float, - default=0.0, - help="Setting this to a value > 0.0 (and <= 1.0) activates RoPE.") - parser.add_argument('--inter_size', type=int, default=None) - parser.add_argument('--no_bias', action="store_false") - parser.add_argument('--max_batch_size', type=int, default=256) - parser.add_argument('--max_input_len', type=int, default=200) - parser.add_argument('--max_output_len', type=int, default=200) - parser.add_argument('--max_beam_width', type=int, default=1) - parser.add_argument( - '--use_gpt_attention_plugin', - nargs='?', - const=None, - type=str, - default=False, - choices=['float16', 'float32', 'bfloat16'], - help= - "Activates attention plugin. You can specify the plugin dtype or leave blank to use the model dtype." - ) - parser.add_argument( - '--use_gemm_plugin', - nargs='?', - const=None, - type=str, - default=False, - choices=['float16', 'float32', 'bfloat16'], - help= - "Activates GEMM plugin. You can specify the plugin dtype or leave blank to use the model dtype." - ) - parser.add_argument('--parallel_build', default=False, action='store_true') - parser.add_argument('--enable_context_fmha', - default=False, - action='store_true') - parser.add_argument('--enable_context_fmha_fp32_acc', - default=False, - action='store_true') - parser.add_argument( - '--multi_block_mode', - default=False, - action='store_true', - help= - 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ - It is beneficial when batch x num_heads cannot fully utilize GPU.' - ) - parser.add_argument('--gpus_per_node', type=int, default=8) - parser.add_argument('--builder_opt', type=int, default=None) - parser.add_argument( - '--output_dir', - type=Path, - default='engine_outputs', - help= - 'The path to save the serialized engine files, timing cache file and model configs' - ) - parser.add_argument( - "--multi_query_mode", - "-mq", - default=False, - action='store_true', - help= - "Whether this model uses multi-query attention mechanism (default: False)" - ) - parser.add_argument('--remove_input_padding', - default=False, - action='store_true') - - # Arguments related to the quantization of the model. - parser.add_argument( - '--use_smooth_quant', - default=False, - action="store_true", - help= - 'Use the SmoothQuant method to quantize activations and weights for the various GEMMs.' - 'See --per_channel and --per_token for finer-grained quantization options.' - ) - parser.add_argument( - '--use_weight_only', - default=False, - action="store_true", - help='Quantize weights for the various GEMMs to INT4/INT8.' - 'See --weight_only_precision to set the precision') - parser.add_argument( - '--weight_only_precision', - const='int8', - type=str, - nargs='?', - default='int8', - choices=['int8', 'int4'], - help= - 'Define the precision for the weights when using weight-only quantization.' - 'You must also use --use_weight_only for that argument to have an impact.' 
- ) - parser.add_argument( - '--per_channel', - default=False, - action="store_true", - help= - 'By default, we use a single static scaling factor for the GEMM\'s result. ' - 'per_channel instead uses a different static scaling factor for each channel. ' - 'The latter is usually more accurate, but a little slower.') - parser.add_argument( - '--per_token', - default=False, - action="store_true", - help= - 'By default, we use a single static scaling factor to scale activations in the int8 range. ' - 'per_token chooses at run time, and for each token, a custom scaling factor. ' - 'The latter is usually more accurate, but a little slower.') - parser.add_argument( - '--int8_kv_cache', - default=False, - action="store_true", - help= - 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' - ) - parser.add_argument( - '--random_seed', - type=int, - default=None, - help= - 'Seed to use when initializing the random number generator for torch.') - parser.add_argument( - '--paged_kv_cache', - action="store_true", - default=False, - help= - 'By default we use contiguous KV cache. By setting this flag you enable paged KV cache' - ) - parser.add_argument('--tokens_per_block', - type=int, - default=128, - help='Number of tokens per block in paged KV cache') - parser.add_argument( - '--max_prompt_embedding_table_size', - type=int, - default=0, - help='Setting to a value > 0 enables support for prompt tuning.') - parser.add_argument( - '--use_inflight_batching', - action="store_true", - default=False, - help="Activates inflight batching mode of gptAttentionPlugin.") - parser.add_argument( - '--use_parallel_embedding', - action="store_true", - default=False, - help= - 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' - ) - parser.add_argument( - '--embedding_sharding_dim', - type=int, - default=0, - choices=[0, 1], - help= - 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' - 'To shard it along hidden dimension, set embedding_sharding_dim=1' - 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' - ) - parser.add_argument( - '--use_embedding_sharing', - action="store_true", - default=False, - help= - 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' - 'Note: the flag might not take effect when the criteria are not met.') - parser.add_argument( - '--use_lookup_plugin', - nargs='?', - const=None, - default=False, - choices=['float16', 'float32', 'bfloat16'], - help="Activates the lookup plugin which enables embedding sharing.") - parser.add_argument( - '--gather_all_token_logits', - action='store_true', - default=False, - help='Enable both gather_context_logits and gather_generation_logits') - parser.add_argument('--gather_context_logits', - action='store_true', - default=False, - help='Gather context logits') - parser.add_argument('--gather_generation_logits', - action='store_true', - default=False, - help='Gather generation logits') - - parser.add_argument('--enable_fp8', default=False, action='store_true') - parser.add_argument( - '--fp8_kv_cache', - default=False, - action="store_true", - help= - 'By default, we use dtype for KV cache. 
fp8_kv_cache chooses fp8 quantization for KV' - ) - parser.add_argument( - '--max_num_tokens', - type=int, - default=None, - help= - 'Define the max number of tokens supported by the engine, note that it takes no effect if --remove_input_padding is not set' - ) - parser.add_argument( - '--strongly_typed', - default=False, - action="store_true", - help= - 'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.' - ) - parser.add_argument( - '--use_custom_all_reduce', - action='store_true', - help= - 'Activates latency-optimized algorithm for all-reduce instead of NCCL.') - parser.add_argument( - '--use_lora_plugin', - nargs='?', - const=None, - default=False, - choices=['float16', 'float32', 'bfloat16'], - help="Activates the lora plugin which enables embedding sharing.") - parser.add_argument( - '--max_draft_len', - type=int, - default=0, - help= - 'Maximum lengths of draft tokens for speculative decoding target model.' - ) - parser.add_argument( - '--use_paged_context_fmha', - action='store_true', - help= - 'Activates paged context FMHA. This mode of the context FMHA is required for chunked context, speculative decoding and reuse of KV cache blocks. Context FMHA performance is worse when this mode is on.' - ) - parser.add_argument( - '--use_context_fmha_for_generation', - action='store_true', - help= - 'Activates context FMHA for generation phase instead of MMHA. Use only for testing and debug.' - ) - parser.add_argument( - '--lora_target_modules', - nargs='+', - default=None, - choices=[ - "attn_qkv", - "attn_q", - "attn_k", - "attn_v", - "attn_dense", - "mlp_h_to_4h", - "mlp_gate", - "mlp_4h_to_h", - ], - help= - "Add lora in which modules. Only be activated when use_lora_plugin is enabled." - ) - parser.add_argument( - '--max_lora_rank', - type=int, - default=64, - help='maximum lora rank for different lora modules. ' - 'It is used to compute the workspace size of lora plugin.') - parser.add_argument( - '--moe_num_experts', - default=0, - type=int, - help='Specify the number of experts to use for MOE layers') - parser.add_argument( - '--moe_top_k', - default=0, - type=int, - help= - 'Specify the top_k value to use for MOE layers. Default to 1 if --moe_num_experts is set' - ) - parser.add_argument( - '--moe_tp_mode', - default=MoeConfig.ParallelismMode.TENSOR_PARALLEL, - type=int, - help= - 'Controls how to distribute experts in TP. Check layers/moe.py for accepted values', - ) - parser.add_argument( - '--moe_renorm_mode', - default=MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE, - type=int, - help= - 'Controls renormalization after gate logits. Check layers/moe.py for accepted values', - ) - args = parser.parse_args(args) - logger.set_level(args.log_level) - - if not args.remove_input_padding: - if args.use_gpt_attention_plugin: - logger.warning( - f"It is recommended to specify --remove_input_padding when using GPT attention plugin" - ) - - args.bias = not args.no_bias - if args.inter_size is None: - args.inter_size = 4 * args.n_embd - - override_args_from_model_dir(args) - plugins_args = [ - 'use_gpt_attention_plugin', 'use_gemm_plugin', 'use_lookup_plugin', - 'use_lora_plugin' - ] - for plugin_arg in plugins_args: - if getattr(args, plugin_arg) is None: - logger.info( - f"{plugin_arg} set, without specifying a value. Using {args.dtype} automatically." - ) - setattr(args, plugin_arg, args.dtype) - - assert not ( - args.use_smooth_quant and args.use_weight_only - ), "You cannot enable both SmoothQuant and INT8 weight-only together." 
- - if args.use_inflight_batching: - if not args.use_gpt_attention_plugin: - args.use_gpt_attention_plugin = 'float16' - logger.info( - f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'" - ) - if not args.remove_input_padding: - args.remove_input_padding = True - logger.info( - "Using remove input padding for inflight batching mode.") - if not args.paged_kv_cache: - args.paged_kv_cache = True - logger.info("Using paged KV cache for inflight batching mode.") - - assert (math.log2(args.tokens_per_block).is_integer() - ), "tokens_per_block must be power of 2" - if args.enable_context_fmha or args.enable_context_fmha_fp32_acc: - assert (args.tokens_per_block >= - 128), "Context fMHA requires >= 128 tokens per block" - - if args.use_smooth_quant: - args.quant_mode = QuantMode.use_smooth_quant(args.per_token, - args.per_channel) - elif args.use_weight_only: - args.quant_mode = QuantMode.use_weight_only( - use_int4_weights=(args.weight_only_precision == 'int4')) - else: - args.quant_mode = QuantMode(0) - - if args.int8_kv_cache: - args.quant_mode = args.quant_mode.set_int8_kv_cache() - if args.fp8_kv_cache: - assert ( - args.use_gpt_attention_plugin or args.use_inflight_batching - ), "You have to use GPT attention plugin when fp8 KV cache is set" - args.quant_mode = args.quant_mode.set_fp8_kv_cache() - - if args.enable_fp8: - args.quant_mode = args.quant_mode.set_fp8_qdq() - - if args.rotary_scaling is not None: - assert args.use_gpt_attention_plugin, "RoPE scaling is only supported through GPT attention plugin." - rotary_scaling = { - "type": args.rotary_scaling[0], - "factor": float(args.rotary_scaling[1]) - } - assert rotary_scaling["type"] in ["linear", "dynamic"] - assert rotary_scaling["factor"] > 1.0 - args.rotary_scaling = rotary_scaling - - args.max_num_tokens = check_max_num_tokens( - max_num_tokens=args.max_num_tokens, - max_batch_size=args.max_batch_size, - max_input_len=args.max_input_len, - remove_input_padding=args.remove_input_padding, - enable_context_fmha=args.enable_context_fmha, - tokens_per_block=args.tokens_per_block) - - if args.moe_num_experts and args.moe_top_k == 0: - args.moe_top_k = 1 - args.moe_config = MoeConfig(args.moe_num_experts, args.moe_top_k, - args.moe_tp_mode, - args.moe_renorm_mode).validate() - - if args.gather_all_token_logits: - args.gather_context_logits = True - args.gather_generation_logits = True - - return args - - -def build_rank_engine(builder: Builder, - builder_config: tensorrt_llm.builder.BuilderConfig, - engine_name, rank, args): - ''' - @brief: Build the engine on the given rank. - @param rank: The rank to build the engine. - @param args: The cmd line arguments. - @return: The built engine. - ''' - kv_dtype = str_dtype_to_trt(args.dtype) - - # Share_embedding_table can be set True only when: - # 1) the weight for lm_head() does not exist while other weights exist - # 2) For multiple-processes, use_parallel_embedding=True and embedding_sharding_dim == 0. - # Besides, for TensorRT 9.0, we can observe the engine size reduction when the lookup and gemm plugin are enabled. 
- share_embedding_table = False - if args.use_embedding_sharing: - if args.world_size > 1: - if args.model_dir is not None and args.embedding_sharding_dim == 0 and args.use_parallel_embedding: - share_embedding_table = check_embedding_share(args.model_dir) - else: - if args.model_dir is not None: - share_embedding_table = check_embedding_share(args.model_dir) - - if not share_embedding_table: - logger.warning(f'Cannot share the embedding lookup table.') - - if share_embedding_table: - logger.info( - 'Engine will try to share embedding and language modeling weights. Note: Flag --use_lookup_plugin and --use_gemm_plugin are also needed for now.' - ) - - # Initialize Module - tensorrt_llm_gpt = tensorrt_llm.models.GPTLMHeadModel( - num_layers=args.n_layer, - num_heads=args.n_head, - num_kv_heads=args.n_kv_head, - hidden_size=args.n_embd, - inter_size=args.inter_size, - vocab_size=args.vocab_size, - hidden_act=args.hidden_act, - max_position_embeddings=args.n_positions, - position_embedding_type=PositionEmbeddingType.learned_absolute - if args.rotary_pct == 0.0 else PositionEmbeddingType.rope_gpt_neox, - rotary_embedding_percentage=args.rotary_pct, - rotary_base=args.rotary_base, - rotary_scaling=args.rotary_scaling, - dtype=kv_dtype, - logits_dtype=args.logits_dtype, - mapping=Mapping(world_size=args.world_size, - rank=rank, - tp_size=args.world_size), # TP only - apply_query_key_layer_scaling=builder_config. - apply_query_key_layer_scaling, - quant_mode=args.quant_mode, - bias=args.bias, - use_prompt_tuning=args.max_prompt_embedding_table_size > 0, - use_parallel_embedding=args.use_parallel_embedding, - embedding_sharding_dim=args.embedding_sharding_dim, - share_embedding_table=share_embedding_table, - moe_config=args.moe_config, - max_lora_rank=args.max_lora_rank, - ) - - if args.use_smooth_quant or args.use_weight_only: - tensorrt_llm_gpt = quantize_model(tensorrt_llm_gpt, args.quant_mode) - - if args.model_dir is not None: - gpt_dummy_fp8_scaling_factors = { - 'fc_act': [0.5 for _ in range(args.n_layer)], - 'fc_weights': [0.5 for _ in range(args.n_layer)], - 'proj_act': [0.5 for _ in range(args.n_layer)], - 'proj_weights': [0.5 for _ in range(args.n_layer)], - 'qkv_act': [0.5 for _ in range(args.n_layer)], - 'qkv_weights': [0.5 for _ in range(args.n_layer)], - 'qkv_output': [0.5 for _ in range(args.n_layer)], - 'dense_act': [0.5 for _ in range(args.n_layer)], - 'dense_weights': [0.5 for _ in range(args.n_layer)], - } - - load_from_ft(tensorrt_llm_gpt, - args.model_dir, - rank, - args.world_size, - args.dtype, - args.use_parallel_embedding, - args.embedding_sharding_dim, - share_embedding_table, - scaling_factors=gpt_dummy_fp8_scaling_factors - if args.enable_fp8 else None) - - # Module -> Network - network = builder.create_network() - network.trt_network.name = engine_name - network.plugin_config.to_legacy_setting() - if args.use_gpt_attention_plugin: - network.plugin_config.set_gpt_attention_plugin( - dtype=args.use_gpt_attention_plugin) - if args.use_gemm_plugin: - if not args.enable_fp8: - network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin) - else: - logger.info( - "Gemm plugin does not support FP8. 
Disabled Gemm plugin.") - assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc) - if args.enable_context_fmha: - network.plugin_config.set_context_fmha(ContextFMHAType.enabled) - if args.enable_context_fmha_fp32_acc: - network.plugin_config.set_context_fmha( - ContextFMHAType.enabled_with_fp32_acc) - if args.multi_block_mode: - network.plugin_config.enable_mmha_multi_block_mode() - if args.remove_input_padding: - network.plugin_config.enable_remove_input_padding() - if args.paged_kv_cache: - network.plugin_config.enable_paged_kv_cache(args.tokens_per_block) - if args.use_lora_plugin: - network.plugin_config.set_lora_plugin(dtype=args.use_lora_plugin) - - # Quantization plugins. - if args.use_smooth_quant: - network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype) - network.plugin_config.set_layernorm_quantization_plugin( - dtype=args.dtype) - - network.plugin_config.set_quantize_tensor_plugin() - network.plugin_config.set_quantize_per_token_plugin() - elif args.use_weight_only: - network.plugin_config.set_weight_only_quant_matmul_plugin( - dtype=args.dtype) - - if args.world_size > 1: - network.plugin_config.set_nccl_plugin(args.dtype, - args.use_custom_all_reduce) - - if args.use_lookup_plugin: - # Use the plugin for the embedding parallelism and sharing - network.plugin_config.set_lookup_plugin(dtype=args.dtype) - - if args.use_paged_context_fmha or args.max_draft_len > 0: - assert args.enable_context_fmha or args.enable_context_fmha_fp32_acc, "context fmha must be enabled" - network.plugin_config.set_paged_context_fmha() - - if args.use_context_fmha_for_generation: - logger.warning( - f'use_context_fmha_for_generation is set. This flag must be used only for testing' - ) - assert args.use_gpt_attention_plugin and args.paged_kv_cache and args.use_paged_context_fmha, "use_context_fmha_for_generation must be used with paged KV cache and attention." 
- network.plugin_config.set_context_fmha_for_generation() - - with net_guard(network): - # Prepare - network.set_named_parameters(tensorrt_llm_gpt.named_parameters()) - - # Forward - inputs = tensorrt_llm_gpt.prepare_inputs( - max_batch_size=args.max_batch_size, - max_input_len=args.max_input_len, - max_seq_len=args.max_input_len + args.max_output_len, - use_cache=True, - max_beam_width=args.max_beam_width, - max_num_tokens=args.max_num_tokens, - prompt_embedding_table_size=args.max_prompt_embedding_table_size, - gather_context_logits=args.gather_context_logits, - gather_generation_logits=args.gather_generation_logits, - max_draft_len=args.max_draft_len, - lora_target_modules=args.lora_target_modules) - tensorrt_llm_gpt(*inputs) - - tensorrt_llm.graph_rewriting.optimize(network) - - engine = None - - # Network -> Engine - engine = builder.build_engine(network, builder_config) - if rank == 0: - config_path = args.output_dir / 'config.json' - builder.save_config(builder_config, config_path) - - return engine - - -def build(rank, args): - torch.cuda.set_device(rank % args.gpus_per_node) - tensorrt_llm.logger.set_level(args.log_level) - args.output_dir.mkdir(parents=True, exist_ok=True) - timing_cache_file = args.timing_cache if args.timing_cache else args.output_dir / "model.cache" - timing_cache = timing_cache_file - - builder = Builder() - apply_query_key_layer_scaling = False - for cur_rank in range(args.world_size): - # skip other ranks if parallel_build is enabled - if args.parallel_build and cur_rank != rank: - continue - # NOTE: when only int8 kv cache is used together with paged kv cache no int8 tensors are exposed to TRT - int8_trt_flag = args.quant_mode.has_act_or_weight_quant() or ( - args.paged_kv_cache == False - and args.quant_mode.has_int8_kv_cache()) - builder_config = builder.create_builder_config( - name=MODEL_NAME, - precision=args.dtype, - timing_cache=timing_cache, - profiling_verbosity=args.profiling_verbosity, - tensor_parallel=args.world_size, # TP only - parallel_build=args.parallel_build, - num_layers=args.n_layer, - num_heads=args.n_head, - num_kv_heads=args.n_kv_head, - hidden_size=args.n_embd, - vocab_size=args.vocab_size, - hidden_act=args.hidden_act, - max_position_embeddings=args.n_positions, - apply_query_key_layer_scaling=apply_query_key_layer_scaling, - max_batch_size=args.max_batch_size, - max_beam_width=args.max_beam_width, - max_input_len=args.max_input_len, - max_output_len=args.max_output_len, - max_num_tokens=args.max_num_tokens, - max_draft_len=args.max_draft_len, - int8=int8_trt_flag, - opt_level=args.builder_opt, - strongly_typed=args.strongly_typed, - max_prompt_embedding_table_size=args. 
- max_prompt_embedding_table_size, - gather_context_logits=args.gather_context_logits, - gather_generation_logits=args.gather_generation_logits, - quant_mode=args.quant_mode, - use_parallel_embedding=args.use_parallel_embedding, - lora_target_modules=args.lora_target_modules, - max_lora_rank=args.max_lora_rank, - ) - - engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size, - cur_rank) - engine = build_rank_engine(builder, builder_config, engine_name, - cur_rank, args) - assert engine is not None, f'Failed to build engine for rank {cur_rank}' - - local_num_kv_heads = (args.n_kv_head + args.world_size - - 1) // args.world_size - kv_dtype = str_dtype_to_trt(args.dtype) - if args.quant_mode.has_int8_kv_cache(): - kv_dtype = str_dtype_to_trt('int8') - elif args.quant_mode.has_fp8_kv_cache(): - kv_dtype = str_dtype_to_trt('fp8') - check_gpt_mem_usage( - engine=engine, - kv_dtype=kv_dtype, - use_gpt_attention_plugin=args.use_gpt_attention_plugin, - paged_kv_cache=args.paged_kv_cache, - max_batch_size=args.max_batch_size, - max_beam_width=args.max_beam_width, - max_seq_len=args.max_input_len + args.max_output_len, - local_num_kv_heads=local_num_kv_heads, - head_size=args.n_embd / args.n_head, - num_layers=args.n_layer) - - if cur_rank == 0: - # Use in-memory timing cache for multiple builder passes. - if not args.parallel_build: - timing_cache = builder_config.trt_builder_config.get_timing_cache( - ) - - serialize_engine(engine, args.output_dir / engine_name) - del engine - - if rank == 0: - ok = builder.save_timing_cache(builder_config, timing_cache_file) - assert ok, "Failed to save timing cache." - - -def run_build(args=None): - args = parse_arguments(args) - - if args.random_seed is not None: - torch.manual_seed(args.random_seed) - - logger.set_level(args.log_level) - tik = time.time() - if args.parallel_build and args.world_size > 1 and \ - torch.cuda.device_count() >= args.world_size: - logger.warning( - f'Parallel build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.' 
- ) - mp.spawn(build, nprocs=args.world_size, args=(args, )) - else: - args.parallel_build = False - logger.info('Serially build TensorRT engines.') - build(0, args) - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - logger.info(f'Total time of building all {args.world_size} engines: {t}') - - -if __name__ == '__main__': - run_build() diff --git a/examples/gpt/convert_checkpoint.py b/examples/gpt/convert_checkpoint.py new file mode 100644 index 000000000..6645236c5 --- /dev/null +++ b/examples/gpt/convert_checkpoint.py @@ -0,0 +1,2034 @@ +import argparse +import configparser +import functools +import json +import logging +import os +import shutil +import tarfile +import time +import traceback +from collections import defaultdict, namedtuple +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import safetensors +import torch +import torch.nn as nn +import yaml +from datasets import load_dataset +from tqdm import tqdm +from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, + GPT2Config, GPT2Tokenizer, T5Tokenizer) +from transformers.models.gpt2.modeling_gpt2 import GPT2Block +from transformers.pytorch_utils import Conv1D + +import tensorrt_llm +from tensorrt_llm._utils import pad_vocab_size, str_dtype_to_torch +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.llama.utils import retrieved_layer_index_from_name + +LOGGER = logging.getLogger(__name__) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_dir', type=str, default=None) + parser.add_argument('--nemo_ckpt_path', type=str, default=None) + parser.add_argument('--load_nemo_on_gpu', + default=False, + action="store_true", + help="Whether to load NeMo checkpoint on GPU") + parser.add_argument( + '--gpt_variant', + default=None, + choices=[None, 'gpt2', 'santacoder', 'starcoder', 'starcoder2'], + help= + "By default the script will try to infer the gpt_variant from model_dir. " + "Or users may overwrite gpt_variant by explicitly passing the variant.") + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + parser.add_argument('--dtype', + type=str, + default='float16', + choices=['float32', 'bfloat16', 'float16']) + parser.add_argument( + '--use_parallel_embedding', + action="store_true", + default=False, + help= + 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help= + 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument( + '--use_embedding_sharing', + action="store_true", + default=False, + help= + 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' + 'Note: the flag might not take effect when the criteria are not met.') + + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' 
+ 'See --weight_only_precision to set the precision') + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8', 'int4'], + help= + 'Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + + parser.add_argument( + '--int8_kv_cache', + default=False, + action="store_true", + help= + 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' + ) + parser.add_argument( + '--per_channel', + default=False, + action="store_true", + help= + 'By default, we use a single static scaling factor for the GEMM\'s result. ' + 'per_channel instead uses a different static scaling factor for each channel. ' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + '--per_token', + default=False, + action="store_true", + help= + 'By default, we use a single static scaling factor to scale activations in the int8 range. ' + 'per_token chooses at run time, and for each token, a custom scaling factor. ' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + "--smoothquant", + "-sq", + type=float, + default=None, + help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)" + " to Smoothquant the model, and output int8 weights." + " A good first try is 0.5. Must be in [0, 1]") + parser.add_argument("--dataset_cache_dir", + type=str, + default=None, + help="cache dir to load the hugging face dataset") + + parser.add_argument( + '--lora_target_modules', + nargs='+', + default=None, + choices=[ + "attn_qkv", + "attn_q", + "attn_k", + "attn_v", + "attn_dense", + "mlp_h_to_4h", + "mlp_gate", + "mlp_4h_to_h", + ], + help= + "Add lora in which modules. Only be activated when use_lora_plugin is enabled." + ) + parser.add_argument( + '--max_lora_rank', + type=int, + default=64, + help='maximum lora rank for different lora modules. 
' + 'It is used to compute the workspace size of lora plugin.') + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument( + '--workers', + type=int, + default=1, + help='The number of workers for converting checkpoint in parallel') + parser.add_argument('--log_level', type=str, default='info') + args = parser.parse_args() + + tensorrt_llm.logger.set_level(args.log_level) + return args + + +def load_gpt_config(model_dir: str, + gpt_variant: Optional[str] = None) -> GPT2Config: + config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + + if gpt_variant is None: + print("Inferring gpt variant from path...") + for v in ['starcoder2', 'starcoder', 'santacoder', 'gpt2']: + if v in config._name_or_path: + gpt_variant = v + break + assert gpt_variant in ['gpt2', 'santacoder', 'starcoder', 'starcoder2'] + print(f"Gpt variant: {gpt_variant}") + + if gpt_variant == 'starcoder2': + config.n_embd = config.hidden_size + config.n_inner = config.intermediate_size + config.n_head = config.num_attention_heads + config.n_kv_head = config.num_key_value_heads + config.n_layer = config.num_hidden_layers + config.n_positions = config.max_position_embeddings + config.activation_function = 'gelu' + config.layer_norm_epsilon = config.norm_epsilon + config.bias = config.use_bias + config.position_embedding_type = 'rope_gpt_neox' + config.rotary_base = config.rope_theta + config.rotary_pct = 1.0 + else: + if config.n_inner is None: + config.n_inner = config.n_embd * 4 + if gpt_variant in ['santacoder', 'starcoder']: + config.n_kv_head = 1 + else: + config.n_kv_head = config.n_head + return config, gpt_variant + + +def split(param: torch.Tensor, + tp_rank: int, + tp_size: int, + is_column: bool = True) -> torch.Tensor: + """Split linear layer's weight, bias or scaling factors for tensor parallelism.""" + if param is None: + return None + assert param.ndim in [1, 2] + if tp_size == 1: + return param + if param.numel() == 1: + return param + if param.ndim == 1 and not is_column: + return param + split_dim = 0 if (param.ndim == 1 or is_column) else 1 + return torch.chunk(param, tp_size, dim=split_dim)[tp_rank].contiguous() + + +def split_qkv( + param: torch.Tensor, + tp_rank: int, + tp_size: int, + hidden_size: int, + num_heads: int, + num_kv_heads: Optional[int] = None, +) -> torch.Tensor: + """Split qkv layer's weight, bias or scaling factors for tensor parallelism. 
+ + param: (num_heads*head_dim + 2*num_kv_heads*head_dim, in_dim) + """ + if param is None: + return None + assert hidden_size % num_heads == 0 + head_dim = hidden_size // num_heads + num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + assert num_heads % num_kv_heads == 0 + assert num_heads % tp_size == 0 + + q_param, k_param, v_param = torch.split( + param, [hidden_size, num_kv_heads * head_dim, num_kv_heads * head_dim], + dim=0) + + if num_kv_heads < tp_size: + assert tp_size % num_kv_heads == 0 + num_dups = tp_size // num_kv_heads + remain_shape = k_param.shape[1:] + k_param = k_param.view( + num_kv_heads, head_dim, + *remain_shape).repeat_interleave(num_dups, dim=0).view( + num_kv_heads * head_dim * num_dups, *remain_shape) + v_param = v_param.view( + num_kv_heads, head_dim, + *remain_shape).repeat_interleave(num_dups, dim=0).view( + num_kv_heads * head_dim * num_dups, *remain_shape) + else: + assert num_kv_heads % tp_size == 0 + + q_param = split(q_param, tp_rank, tp_size, is_column=True) + k_param = split(k_param, tp_rank, tp_size, is_column=True) + v_param = split(v_param, tp_rank, tp_size, is_column=True) + return torch.cat([q_param, k_param, v_param], dim=0) + + +def get_weight(params: Dict[str, torch.Tensor], prefix: str, + dtype: torch.dtype) -> torch.Tensor: + if f'{prefix}.weight' not in params: + return None + return params[f'{prefix}.weight'].to(dtype).detach().cpu() + + +def get_bias(params: Dict[str, torch.Tensor], prefix: str, + dtype: torch.dtype) -> torch.Tensor: + if f'{prefix}.bias' not in params: + return None + return params[f'{prefix}.bias'].to(dtype).detach().cpu() + + +def get_weight_and_bias(params: Dict[str, torch.Tensor], prefix: str, + dtype: torch.dtype) -> Tuple[torch.Tensor]: + return get_weight(params, prefix, dtype), get_bias(params, prefix, dtype) + + +def get_tllm_linear_weight( + weight: torch.Tensor, + prefix: str, + bias: Optional[torch.Tensor] = None, + use_weight_only: bool = False, + plugin_weight_only_quant_type: torch.dtype = torch.int8 +) -> Dict[str, torch.Tensor]: + results = {} + if use_weight_only: + v = weight.t().contiguous() + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + v, plugin_weight_only_quant_type) + results[f'{prefix}.weight'] = processed_torch_weights + results[f'{prefix}.per_channel_scale'] = torch_weight_scales + else: + results[f'{prefix}.weight'] = weight + + if bias is not None: + results[f'{prefix}.bias'] = bias + + return results + + +def convert_hf_gpt(hf_model: AutoModelForCausalLM, + hf_config: AutoConfig, + gpt_variant: str, + mapping: Mapping, + dtype: str = 'float32', + use_parallel_embedding: bool = False, + sharding_dim: int = 0, + share_embedding_table: bool = False, + use_weight_only: bool = False, + plugin_weight_only_quant_type: torch.dtype = torch.int8): + weights = {} + tik = time.time() + + model_params = dict(hf_model.named_parameters()) + dtype = getattr(torch, dtype) + num_attention_heads = hf_config.n_head + hidden_size = hf_config.n_embd + vocab_size = hf_config.vocab_size + num_kv_heads = hf_config.n_kv_head + num_hidden_layers = hf_config.n_layer + + layers_range = mapping.pp_layers(num_hidden_layers) + for l in layers_range: + if gpt_variant == 'starcoder2': + prefix = f'model.layers.{l}' + else: + prefix = f'transformer.h.{l}' + tllm_prex = f'transformer.layers.{l-layers_range[0]}' + if gpt_variant == 'santacoder': + q_w, q_b = get_weight_and_bias(model_params, + f'{prefix}.attn.q_attn', dtype) + kv_w, 
kv_b = get_weight_and_bias(model_params, + f'{prefix}.attn.kv_attn', dtype) + qkv_w = torch.cat([q_w, kv_w], dim=-1) + qkv_b = torch.cat([q_b, kv_b], dim=-1) + elif gpt_variant == 'starcoder2': + q_w, q_b = get_weight_and_bias(model_params, + f'{prefix}.self_attn.q_proj', dtype) + k_w, k_b = get_weight_and_bias(model_params, + f'{prefix}.self_attn.k_proj', dtype) + v_w, v_b = get_weight_and_bias(model_params, + f'{prefix}.self_attn.v_proj', dtype) + qkv_w = torch.cat([q_w, k_w, v_w], dim=0) + qkv_b = torch.cat([q_b, k_b, v_b], dim=0) + else: + qkv_w, qkv_b = get_weight_and_bias(model_params, + f'{prefix}.attn.c_attn', dtype) + if gpt_variant in ['gpt2', 'santacoder']: + qkv_w = qkv_w.t().contiguous() # transpose for Conv1D + qkv_w = split_qkv(qkv_w, mapping.tp_rank, mapping.tp_size, hidden_size, + num_attention_heads, num_kv_heads) + qkv_b = split_qkv(qkv_b, mapping.tp_rank, mapping.tp_size, hidden_size, + num_attention_heads, num_kv_heads) + + weights.update( + get_tllm_linear_weight(qkv_w, f'{tllm_prex}.attention.qkv', qkv_b, + use_weight_only, + plugin_weight_only_quant_type)) + + if gpt_variant == 'starcoder2': + attn_dense_w, attn_dense_b = get_weight_and_bias( + model_params, f'{prefix}.self_attn.o_proj', dtype) + else: + attn_dense_w, attn_dense_b = get_weight_and_bias( + model_params, f'{prefix}.attn.c_proj', dtype) + if gpt_variant in ['gpt2', 'santacoder']: + attn_dense_w = attn_dense_w.t().contiguous() # transpose for Conv1D + attn_dense_w = split(attn_dense_w, + mapping.tp_rank, + mapping.tp_size, + is_column=False) + weights.update( + get_tllm_linear_weight(attn_dense_w, f'{tllm_prex}.attention.dense', + attn_dense_b, use_weight_only, + plugin_weight_only_quant_type)) + + mlp_fc_w, mlp_fc_b = get_weight_and_bias(model_params, + f'{prefix}.mlp.c_fc', dtype) + if gpt_variant in ['gpt2', 'santacoder']: + mlp_fc_w = mlp_fc_w.t().contiguous() # transpose for Conv1D + mlp_fc_w = split(mlp_fc_w, + mapping.tp_rank, + mapping.tp_size, + is_column=True) + mlp_fc_b = split(mlp_fc_b, + mapping.tp_rank, + mapping.tp_size, + is_column=True) + weights.update( + get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.fc', mlp_fc_b, + use_weight_only, + plugin_weight_only_quant_type)) + + mlp_proj_w, mlp_proj_b = get_weight_and_bias(model_params, + f'{prefix}.mlp.c_proj', + dtype) + if gpt_variant in ['gpt2', 'santacoder']: + mlp_proj_w = mlp_proj_w.t().contiguous() # transpose for Conv1D + mlp_proj_w = split(mlp_proj_w, + mapping.tp_rank, + mapping.tp_size, + is_column=False) + weights.update( + get_tllm_linear_weight(mlp_proj_w, f'{tllm_prex}.mlp.proj', + mlp_proj_b, use_weight_only, + plugin_weight_only_quant_type)) + + if gpt_variant == 'starcoder2': + input_ln_w, input_ln_b = get_weight_and_bias( + model_params, f'{prefix}.input_layernorm', dtype) + else: + input_ln_w, input_ln_b = get_weight_and_bias( + model_params, f'{prefix}.ln_1', dtype) + weights[f'{tllm_prex}.input_layernorm.weight'] = input_ln_w + if input_ln_b is not None: + weights[f'{tllm_prex}.input_layernorm.bias'] = input_ln_b + + if gpt_variant == 'starcoder2': + post_ln_w, post_ln_b = get_weight_and_bias( + model_params, f'{prefix}.post_attention_layernorm', dtype) + else: + post_ln_w, post_ln_b = get_weight_and_bias(model_params, + f'{prefix}.ln_2', dtype) + weights[f'{tllm_prex}.post_layernorm.weight'] = post_ln_w + if post_ln_b is not None: + weights[f'{tllm_prex}.post_layernorm.bias'] = post_ln_b + + if mapping.is_first_pp_rank(): + if gpt_variant == 'starcoder2': + embed_w = get_weight(model_params, 'model.embed_tokens', 
dtype) + else: + embed_w = get_weight(model_params, 'transformer.wte', dtype) + if not use_parallel_embedding: + weights['transformer.vocab_embedding.weight'] = embed_w + else: + if sharding_dim == 0: + if vocab_size % mapping.tp_size != 0: + vocab_size_padded = pad_vocab_size(vocab_size, + mapping.tp_size) + pad_width = vocab_size_padded - vocab_size + embed_w = torch.nn.functional.pad(embed_w, + (0, 0, 0, pad_width), + value=0) + else: + assert hidden_size % mapping.tp_size == 0 + weights['transformer.vocab_embedding.weight'] = split( + embed_w, + mapping.tp_rank, + mapping.tp_size, + is_column=(sharding_dim == 0)) + + pos_embed_w = get_weight(model_params, 'transformer.wpe', dtype) + if pos_embed_w is not None: + weights['transformer.position_embedding.weight'] = pos_embed_w + + if mapping.is_last_pp_rank(): + if gpt_variant == 'starcoder2': + embed_w = get_weight(model_params, 'lm_head', dtype) + if embed_w is None: + embed_w = get_weight(model_params, 'model.embed_tokens', dtype) + else: + embed_w = get_weight(model_params, 'transformer.wte', dtype) + if not share_embedding_table: + if vocab_size % mapping.tp_size != 0: + vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) + pad_width = vocab_size_padded - vocab_size + embed_w = torch.nn.functional.pad(embed_w, (0, 0, 0, pad_width), + value=0) + weights['lm_head.weight'] = split(embed_w.clone(), + mapping.tp_rank, + mapping.tp_size, + is_column=True) + if gpt_variant == 'starcoder2': + ln_f_w, ln_f_b = get_weight_and_bias(model_params, 'model.norm', + dtype) + else: + ln_f_w, ln_f_b = get_weight_and_bias(model_params, + 'transformer.ln_f', dtype) + weights['transformer.ln_f.weight'] = ln_f_w + if ln_f_b is not None: + weights['transformer.ln_f.bias'] = ln_f_b + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Weights loaded. Total time: {t}') + return weights + + +def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): + """ + This function has two purposes: + - compute quantized weights, scaled either per-tensor or per-column + - compute scaling factors + + Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ. + CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W. + CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor. + + Here is the list of what we need (T means per-tensor, C per-column): + - scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T) + - scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T) + - scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C) + - scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32) + to quant range (int8) (used for CUBLAS) (T, C) + + Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too, + but then the model would change depending on the number of GPUs used. + + For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it + as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V. + """ + + # compute weight scaling factors for fp->int8 and int8->fp + if is_qkv and not multi_query_mode: + scale_w_orig_quant_t = 127. 
/ act_range["w"].reshape(3, -1).max( + dim=-1, keepdims=True)[0].cpu().numpy() + scale_w_orig_quant_c = 127. / act_range["w"].reshape(3, + -1).cpu().numpy() + elif is_qkv and multi_query_mode: + raise ValueError( + f"Multi-query w/ int8 quant has not been supported yet") + else: + scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy() + scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() + scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t + scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c + + # compute the rest of needed scaling factors + scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item()) + scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item()) + scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.) + scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t * + scale_w_orig_quant_t) + scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t * + scale_w_orig_quant_c) + if is_qkv: + scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t, + scale_w_orig_quant_c.shape) + scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t, + scale_w_orig_quant_c.shape) + + to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8) + return { + "weight.int8": to_i8(weights * scale_w_orig_quant_t), + "weight.int8.col": to_i8(weights * scale_w_orig_quant_c), + "scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32), + "scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32), + "scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32), + "scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32), + "scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32), + "scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32), + } + + +@torch.no_grad() +def apply_smoothing(scales, + gemm_weights, + layernorm_weights=None, + layernorm_bias=None, + dtype=torch.float32, + layernorm_1p=False): + if not isinstance(gemm_weights, list): + gemm_weights = [gemm_weights] + + if layernorm_weights is not None: + assert layernorm_weights.numel() == scales.numel() + layernorm_weights.div_(scales).to(dtype) + if layernorm_bias is not None: + assert layernorm_bias.numel() == scales.numel() + layernorm_bias.div_(scales).to(dtype) + if layernorm_1p: + layernorm_weights += (1 / scales) - 1 + + for gemm in gemm_weights: + gemm.mul_(scales.view(1, -1)).to(dtype) + + +@torch.no_grad() +def smooth_gemm(gemm_weights, + act_scales, + layernorm_weights=None, + layernorm_bias=None, + alpha=0.5, + weight_scales=None): + if not isinstance(gemm_weights, list): + gemm_weights = [gemm_weights] + orig_dtype = gemm_weights[0].dtype + + for gemm in gemm_weights: + # gemm_weights are expected to be transposed + assert gemm.shape[1] == act_scales.numel() + + if weight_scales is None: + weight_scales = torch.cat( + [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], + dim=0) + weight_scales = weight_scales.max(dim=0)[0] + weight_scales.to(float).clamp(min=1e-5) + scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / + weight_scales.pow(1 - alpha)).clamp(min=1e-5) + + apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias, + orig_dtype) + + return scales + + +@torch.no_grad() +def capture_activation_range(model, + tokenizer, + dataset, + num_samples=512, + seq_len=512): + model.eval() + device = next(model.parameters()).device + act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None}) + + def stat_tensor(name, tensor, act_scales, key): + hidden_dim = tensor.shape[-1] 
+ tensor = tensor.view(-1, hidden_dim).abs().detach() + comming_max = torch.max(tensor, dim=0)[0].float() + + if act_scales[name][key] is None: + act_scales[name][key] = comming_max + else: + act_scales[name][key] = torch.max(act_scales[name][key], + comming_max) + + def stat_input_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + stat_tensor(name, x, act_scales, "x") + stat_tensor(name, y, act_scales, "y") + + if act_scales[name]["w"] is None: + act_scales[name]["w"] = m.weight.abs().clip(1e-8, + None).max(dim=0)[0] + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, nn.Linear) or isinstance(m, Conv1D): + hooks.append( + m.register_forward_hook( + functools.partial(stat_input_hook, name=name))) + + for i in tqdm(range(num_samples), desc="calibrating model"): + input_ids = tokenizer(dataset[i]["text"], + return_tensors="pt", + max_length=seq_len, + truncation=True).input_ids.to(device) + model(input_ids) + + for h in hooks: + h.remove() + + return act_scales + + +@torch.no_grad() +def smooth_gpt_model(model, scales, alpha): + # Smooth the activation and weights with smoother = $\diag{s}$ + for name, module in model.named_modules(): + if not isinstance(module, GPT2Block): + continue + + # qkv_proj + layer_name = name + ".attn.c_attn" + smoother = smooth_gemm(module.attn.c_attn.weight.T, + scales[layer_name]["x"], module.ln_1.weight, + module.ln_1.bias, alpha) + scales[layer_name]["x"] = scales[layer_name]["x"] / smoother + scales[layer_name]["w"] = module.attn.c_attn.weight.abs().max(dim=0)[0] + + # fc1 + layer_name = name + ".mlp.c_fc" + smoother = smooth_gemm(module.mlp.c_fc.weight.T, + scales[layer_name]["x"], module.ln_2.weight, + module.ln_2.bias, alpha) + scales[layer_name]["x"] = scales[layer_name]["x"] / smoother + scales[layer_name]["w"] = module.mlp.c_fc.weight.abs().max(dim=0)[0] + + +def get_tllm_linear_sq_weight(vals, + prefix, + shape, + tensor_parallel, + is_qkv=False, + per_token=False, + per_channel=False, + last_prefix=None, + bias=None, + smoother_value=None, + smoother_shape=None, + rank=0, + cat_dim=0, + multi_query_mode=False): + results = {} + + def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): + q, k, v = np.split(data, [local_dim, local_dim + head_size], axis=-1) + q_split = np.split(q, tp_size, axis=-1) + k_split = np.split(k, tp_size, axis=-1) + v_split = np.split(v, tp_size, axis=-1) + return [ + np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1) + for ii in range(tp_size) + ][cur_rank] + + col_shape = shape if (is_qkv or per_channel) else [1, 1] + + if per_token: + if per_channel: + original_weights = np.array(vals["weight.int8.col"]) + else: + original_weights = np.array(vals["weight.int8"]) + local_dim = original_weights.shape[0] + head_size = (original_weights.shape[1] - local_dim) // 2 + + if multi_query_mode: + cur_weights = multi_query_split(original_weights, local_dim, + head_size, tensor_parallel, rank) + else: + cur_weights = np.split(original_weights, + tensor_parallel, + axis=cat_dim)[rank] + if is_qkv: + hidden_dim = cur_weights.shape[0] + cur_weights = cur_weights.reshape(hidden_dim, -1) + results[prefix + + 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + if smoother_value is None: + results[last_prefix] = torch.from_numpy( + np.array([1.0], dtype=np.float32)) + + if per_channel: + cur_per_channel_value = vals["scale_w_quant_orig.col"] + if smoother_value is None: + if multi_query_mode: + cur_per_channel_value = multi_query_split( + vals["scale_w_quant_orig.col"], 
local_dim, head_size, + tensor_parallel, rank) + else: + cur_per_channel_value = np.split( + vals["scale_w_quant_orig.col"], + tensor_parallel, + axis=cat_dim)[rank] + else: + cur_per_channel_value = vals["scale_w_quant_orig"] + if is_qkv: + if multi_query_mode: + cur_per_channel_value = multi_query_split( + vals["scale_w_quant_orig"], local_dim, head_size, + tensor_parallel, rank) + else: + cur_per_channel_value = np.split(vals["scale_w_quant_orig"], + tensor_parallel, + axis=cat_dim)[rank] + + results[prefix + 'per_channel_scale'] = torch.from_numpy( + np.array(cur_per_channel_value, + dtype=np.float32).reshape(col_shape)).contiguous() + else: + if per_channel: + original_weights = np.array(vals["weight.int8.col"]) + else: + original_weights = np.array(vals["weight.int8"]) + local_dim = original_weights.shape[0] + head_size = (original_weights.shape[1] - local_dim) // 2 + + if multi_query_mode: + cur_weights = multi_query_split(original_weights, local_dim, + head_size, tensor_parallel, rank) + else: + cur_weights = np.split(original_weights, + tensor_parallel, + axis=cat_dim)[rank] + if is_qkv: + hidden_dim = cur_weights.shape[0] + cur_weights = cur_weights.reshape(hidden_dim, -1) + results[prefix + + 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + + if per_channel: + cur_per_channel_value = vals["scale_y_accum_quant.col"] + if smoother_value is None: + if multi_query_mode: + cur_per_channel_value = multi_query_split( + vals["scale_y_accum_quant.col"], local_dim, head_size, + tensor_parallel, rank) + else: + cur_per_channel_value = np.split( + vals["scale_y_accum_quant.col"], + tensor_parallel, + axis=cat_dim)[rank] + else: + cur_per_channel_value = vals["scale_y_accum_quant"] + # QKV is always per_channel + if is_qkv: + if multi_query_mode: + cur_per_channel_value = multi_query_split( + vals["scale_y_accum_quant"], local_dim, head_size, + tensor_parallel, rank) + else: + cur_per_channel_value = np.split( + vals["scale_y_accum_quant"], + tensor_parallel, + axis=cat_dim)[rank] + + results[prefix + 'per_channel_scale'] = torch.from_numpy( + np.array([cur_per_channel_value], + dtype=np.float32).reshape(col_shape)).contiguous() + + results[last_prefix] = torch.from_numpy( + np.array([vals['scale_x_orig_quant']], + dtype=np.float32)).contiguous() + + results[prefix + 'act_scale'] = torch.from_numpy( + np.array([[vals["scale_y_quant_orig"]]], + dtype=np.float32)).contiguous() + + if smoother_value is not None: + cur_smoother_value = np.split(smoother_value, + tensor_parallel, + axis=cat_dim)[rank] + results[prefix + 'smoother'] = cur_smoother_value.reshape( + smoother_shape).contiguous().to(torch.float32) + + if bias is not None: + results[prefix + 'bias'] = bias + + return results + + +def convert_hf_gpt_legacy(hf_model: AutoModelForCausalLM, + hf_config: AutoConfig, + gpt_variant: str, + mapping: Mapping, + dtype: str = 'float32', + use_parallel_embedding: bool = False, + sharding_dim: int = 0, + share_embedding_table: bool = False, + use_smooth_quant=False, + per_channel=False, + per_token=False, + int8_kv_cache=False, + act_range=None): + weights = {} + tik = time.time() + + model_params = dict(hf_model.named_parameters()) + dtype = getattr(torch, dtype) + num_attention_heads = hf_config.n_head + hidden_size = hf_config.n_embd + vocab_size = hf_config.vocab_size + num_kv_heads = hf_config.n_kv_head + num_hidden_layers = hf_config.n_layer + multi_query_mode = (num_kv_heads != num_attention_heads) + tensor_parallel = mapping.tp_size + + layers_range = 
mapping.pp_layers(num_hidden_layers) + for l in layers_range: + prefix = f'transformer.h.{l}' + tllm_prex = f'transformer.layers.{l-layers_range[0]}' + + if gpt_variant == 'santacoder': + q_w, q_b = get_weight_and_bias(model_params, + f'{prefix}.attn.q_attn', dtype) + kv_w, kv_b = get_weight_and_bias(model_params, + f'{prefix}.attn.kv_attn', dtype) + qkv_w = torch.cat([q_w, kv_w], dim=-1) + qkv_b = torch.cat([q_b, kv_b], dim=-1) + else: + qkv_w, qkv_b = get_weight_and_bias(model_params, + f'{prefix}.attn.c_attn', dtype) + if gpt_variant in ['gpt2', 'santacoder']: + qkv_w = qkv_w.t().contiguous() # transpose for Conv1D + + if use_smooth_quant: + qkv_out_dim = qkv_w.shape[0] + qkv_w_numpy = qkv_w.t().numpy() + if not multi_query_mode: + qkv_w_numpy = qkv_w_numpy.reshape(hidden_size, 3, hidden_size) + int8_weights = generate_int8(qkv_w_numpy, + act_range.get(f'{prefix}.attn.c_attn'), + is_qkv=True, + multi_query_mode=multi_query_mode) + qkv_b = split_qkv(qkv_b, mapping.tp_rank, mapping.tp_size, + hidden_size, num_attention_heads, num_kv_heads) + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + f'{tllm_prex}.attention.qkv.', + [1, qkv_out_dim // tensor_parallel], + tensor_parallel, + is_qkv=True, + per_token=per_token, + per_channel=per_channel, + last_prefix=f'{tllm_prex}.input_layernorm.scale_to_int', + bias=qkv_b, + smoother_value=None, + smoother_shape=None, + rank=rank, + cat_dim=-1, + multi_query_mode=multi_query_mode)) + else: + qkv_w = split_qkv(qkv_w, mapping.tp_rank, mapping.tp_size, + hidden_size, num_attention_heads, num_kv_heads) + qkv_b = split_qkv(qkv_b, mapping.tp_rank, mapping.tp_size, + hidden_size, num_attention_heads, num_kv_heads) + weights.update( + get_tllm_linear_weight(qkv_w, f'{tllm_prex}.attention.qkv', + qkv_b)) + + if int8_kv_cache: + qkv_w_numpy = qkv_w.t().numpy() + if not multi_query_mode: + qkv_w_numpy = qkv_w_numpy.reshape(hidden_size, 3, hidden_size) + int8_weights = generate_int8(qkv_w_numpy, + act_range.get(f'{prefix}.attn.c_attn'), + is_qkv=True, + multi_query_mode=multi_query_mode) + weights[ + f'{tllm_prex}.attention.kv_cache_scaling_factor'] = torch.from_numpy( + np.array([int8_weights['scale_y_quant_orig']], + dtype=np.float32)).contiguous() + + attn_dense_w, attn_dense_b = get_weight_and_bias( + model_params, f'{prefix}.attn.c_proj', dtype) + if gpt_variant in ['gpt2', 'santacoder']: + attn_dense_w = attn_dense_w.t().contiguous() # transpose for Conv1D + if use_smooth_quant: + attn_dense_w_numpy = attn_dense_w.t().numpy() + int8_weights = generate_int8(attn_dense_w_numpy, + act_range.get(f'{prefix}.attn.c_proj')) + # change it to the real smoother if dense layer is applied smooth quant + fake_smoother_value = torch.ones([1, hidden_size], + dtype=torch.float32) + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + f'{tllm_prex}.attention.dense.', [1, hidden_size], + tensor_parallel, + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix= + f'{tllm_prex}.attention.quantization_scaling_factor', + bias=attn_dense_b, + smoother_value=fake_smoother_value, + smoother_shape=[1, hidden_size // tensor_parallel], + rank=rank, + cat_dim=0)) + else: + attn_dense_w = split(attn_dense_w, + mapping.tp_rank, + mapping.tp_size, + is_column=False) + weights.update( + get_tllm_linear_weight(attn_dense_w, + f'{tllm_prex}.attention.dense', + attn_dense_b)) + + mlp_fc_w, mlp_fc_b = get_weight_and_bias(model_params, + f'{prefix}.mlp.c_fc', dtype) + if gpt_variant in ['gpt2', 'santacoder']: + mlp_fc_w = 
mlp_fc_w.t().contiguous() # transpose for Conv1D + if use_smooth_quant: + mlp_fc_w_numpy = mlp_fc_w.t().numpy() + int8_weights = generate_int8(mlp_fc_w_numpy, + act_range.get(f'{prefix}.mlp.c_fc')) + mlp_fc_b = split(mlp_fc_b, + mapping.tp_rank, + mapping.tp_size, + is_column=True) + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + f'{tllm_prex}.mlp.fc.', + [1, 4 * hidden_size // tensor_parallel], + tensor_parallel, + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=f'{tllm_prex}.post_layernorm.scale_to_int', + bias=mlp_fc_b, + smoother_value=None, + smoother_shape=None, + rank=rank, + cat_dim=-1)) + else: + mlp_fc_w = split(mlp_fc_w, + mapping.tp_rank, + mapping.tp_size, + is_column=True) + mlp_fc_b = split(mlp_fc_b, + mapping.tp_rank, + mapping.tp_size, + is_column=True) + weights.update( + get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.fc', + mlp_fc_b)) + + mlp_proj_w, mlp_proj_b = get_weight_and_bias(model_params, + f'{prefix}.mlp.c_proj', + dtype) + if gpt_variant in ['gpt2', 'santacoder']: + mlp_proj_w = mlp_proj_w.t().contiguous() # transpose for Conv1D + if use_smooth_quant: + mlp_proj_w_numpy = mlp_proj_w.t().numpy() + int8_weights = generate_int8(mlp_proj_w_numpy, + act_range.get(f'{prefix}.mlp.c_proj')) + # change it to the real smoother if proj layer is applied smooth quant + fake_smoother_value = torch.ones([1, 4 * hidden_size], + dtype=torch.float32) + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + f'{tllm_prex}.mlp.proj.', [1, hidden_size], + tensor_parallel, + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=f'{tllm_prex}.mlp.quantization_scaling_factor', + bias=mlp_proj_b, + smoother_value=fake_smoother_value, + smoother_shape=[1, 4 * hidden_size // tensor_parallel], + rank=rank, + cat_dim=0)) + else: + mlp_proj_w = split(mlp_proj_w, + mapping.tp_rank, + mapping.tp_size, + is_column=False) + weights.update( + get_tllm_linear_weight(mlp_proj_w, f'{tllm_prex}.mlp.proj', + mlp_proj_b)) + + input_ln_w, input_ln_b = get_weight_and_bias(model_params, + f'{prefix}.ln_1', dtype) + weights[f'{tllm_prex}.input_layernorm.weight'] = input_ln_w + if input_ln_b is not None: + weights[f'{tllm_prex}.input_layernorm.bias'] = input_ln_b + + post_ln_w, post_ln_b = get_weight_and_bias(model_params, + f'{prefix}.ln_2', dtype) + weights[f'{tllm_prex}.post_layernorm.weight'] = post_ln_w + if post_ln_b is not None: + weights[f'{tllm_prex}.post_layernorm.bias'] = post_ln_b + + if mapping.is_first_pp_rank(): + embed_w = get_weight(model_params, 'transformer.wte', dtype) + if not use_parallel_embedding: + weights['transformer.vocab_embedding.weight'] = embed_w + else: + if sharding_dim == 0: + if vocab_size % mapping.tp_size != 0: + vocab_size_padded = pad_vocab_size(vocab_size, + mapping.tp_size) + pad_width = vocab_size_padded - vocab_size + embed_w = torch.nn.functional.pad(embed_w, + (0, 0, 0, pad_width), + value=0) + else: + assert hidden_size % mapping.tp_size == 0 + weights['transformer.vocab_embedding.weight'] = split( + embed_w, + mapping.tp_rank, + mapping.tp_size, + is_column=(sharding_dim == 0)) + + pos_embed_w = get_weight(model_params, 'transformer.wpe', dtype) + if pos_embed_w is not None: + weights['transformer.position_embedding.weight'] = pos_embed_w + + if mapping.is_last_pp_rank(): + embed_w = get_weight(model_params, 'transformer.wte', dtype) + if not share_embedding_table: + if vocab_size % mapping.tp_size != 0: + vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) + 
pad_width = vocab_size_padded - vocab_size + embed_w = torch.nn.functional.pad(embed_w, (0, 0, 0, pad_width), + value=0) + weights['lm_head.weight'] = split(embed_w.clone(), + mapping.tp_rank, + mapping.tp_size, + is_column=True) + ln_f_w, ln_f_b = get_weight_and_bias(model_params, 'transformer.ln_f', + dtype) + weights['transformer.ln_f.weight'] = ln_f_w + if ln_f_b is not None: + weights['transformer.ln_f.bias'] = ln_f_b + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Weights loaded. Total time: {t}') + return weights + + +def cpu_map_location(storage, loc): + return storage.cpu() + + +def gpu_map_location(storage, loc): + if loc.startswith("cuda"): + training_gpu_idx = int(loc.split(":")[1]) + inference_gpu_idx = training_gpu_idx % torch.cuda.device_count() + return storage.cuda(inference_gpu_idx) + elif loc.startswith("cpu"): + return storage.cpu() + else: + raise ValueError(f"Not handled {loc}") + + +# The field names are the same as in .nemo config file +# Defaults and their locations in NeMo code are given for commit 9c7926db4ae375b77dae7eb57656213de1dd76a5 in main branch +# The commit from main is used instead of a release because there are `rotary_base` commit was introduced recently. +NemoRotaryEmbeddingParameters = namedtuple( + "NemoRotaryEmbeddingParameters", + [ + "position_embedding_type", "rotary_percentage", + "seq_len_interpolation_factor", "rotary_base" + ], + defaults=[ + # "position_embedding_type", the default is taken from + # https://github.com/NVIDIA/NeMo/blob/9c7926db4ae375b77dae7eb57656213de1dd76a5/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L370 + "learned_absolute", + # "rotary_percentage", the default is taken from + # https://github.com/NVIDIA/NeMo/blob/9c7926db4ae375b77dae7eb57656213de1dd76a5/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L370 + 1.0, + # "seq_len_interpolation_factor", the default is take from + # https://github.com/NVIDIA/NeMo/blob/9c7926db4ae375b77dae7eb57656213de1dd76a5/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L388 + None, + # "rotary_base", the default is taken from + # https://github.com/NVIDIA/NeMo/blob/9c7926db4ae375b77dae7eb57656213de1dd76a5/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L389 + 10000, + ]) + + +def set_parameter_from_config(params: Dict[str, Any], nemo_config: Dict[str, + Any], + param_name: str) -> None: + if param_name in nemo_config: + params[param_name] = nemo_config[param_name] + else: + LOGGER.debug( + f"A parameter '{param_name}' is missing in nemo checkpoint. " + f"The default value {repr(NemoRotaryEmbeddingParameters._field_defaults[param_name])} will be used." 
+ ) + + +def extract_rotary_parameters_from_nemo_config( + nemo_config: Dict[str, Any]) -> NemoRotaryEmbeddingParameters: + params = {} + set_parameter_from_config(params, nemo_config, "position_embedding_type") + set_parameter_from_config(params, nemo_config, "rotary_percentage") + set_parameter_from_config(params, nemo_config, + "seq_len_interpolation_factor") + set_parameter_from_config(params, nemo_config, "rotary_base") + return NemoRotaryEmbeddingParameters(**params) + + +def nemo_to_gpt_config(nemo_model_config, vocab_size, eos_id, bos_id): + convertion_dict = { + "activation_function": "activation", + "layer_norm_epsilon": "layernorm_epsilon", + "n_embd": "hidden_size", + "n_head": "num_attention_heads", + "n_layer": "num_layers", + "n_positions": "max_position_embeddings", + "rotary_pct": "rotary_percentage", + "bias": "bias", + "intermediate_size": "ffn_hidden_size", + } + + kwargs = { + key: nemo_model_config[value] + for key, value in convertion_dict.items() if value in nemo_model_config + } + kwargs["vocab_size"] = vocab_size + kwargs["eos_token_id"] = eos_id + kwargs["bos_token_id"] = bos_id + + return GPT2Config(**kwargs) + + +def copy_tokenizer_files(config, out_dir): + basenames = { + "model": "tokenizer", + "vocab_file": "vocab", + "merge_file": "merges", + } + + for key in basenames.keys(): + if config[key] is None: + continue + path = Path(config[key]) + if not path.exists(): + LOGGER.debug(f"Tokenizer {key}: {path} file not found") + continue + + dst_path = out_dir / f"{basenames[key]}{path.suffix}" + LOGGER.debug(f"Copy tokenizer {key}: {path}->{dst_path}") + shutil.copy(path.as_posix(), dst_path.as_posix()) + + +def add_rotary_parameters_to_ini_config( + config: configparser.ConfigParser, + rotary_parameters: NemoRotaryEmbeddingParameters) -> None: + if rotary_parameters.position_embedding_type == "rope": + if rotary_parameters.rotary_percentage > 1.0 or rotary_parameters.rotary_percentage <= 0.0: + raise ValueError( + f"Rotary percentage has to suffice 0.0 < rotary_percentage <= 1.0, whereas " + f"rotary_percentage={rotary_parameters.rotary_percentage}") + config["gpt"]["rotary_pct"] = str(rotary_parameters.rotary_percentage) + config["gpt"]["rotary_base"] = str(rotary_parameters.rotary_base) + if rotary_parameters.seq_len_interpolation_factor is not None: + if rotary_parameters.seq_len_interpolation_factor <= 1.0: + raise ValueError( + f"Rotary scaling is supported only for seq_len_interpolation_factor > 1.0. " + f"Got seq_len_interpolation_factor={rotary_parameters.seq_len_interpolation_factor}" + ) + config["gpt"]["rotary_scaling_type"] = "linear" + config["gpt"]["rotary_scaling_factor"] = str( + float(rotary_parameters.seq_len_interpolation_factor)) + else: + # As in HF rotary_pct > 0.0 triggers RoPE. 
Dislabe RoPE if different embedding type is used + config["gpt"]["rotary_pct"] = "0.0" + + +def update_tokenizer_paths(tokenizer_config: Dict, + tokenizer_file_paths: Dict[str, Optional[str]]): + for key, new_path in tokenizer_file_paths.items(): + old_path = tokenizer_config[key] + if old_path is None: + continue + old_path = Path(old_path) + if new_path: + LOGGER.debug(f"Update tokenizer {key} {old_path} -> {new_path}") + tokenizer_config[key] = new_path.as_posix() + elif not old_path.exists(): + LOGGER.warning( + f"Tokenizer {key}'s path {old_path} does not exists: set it to None" + ) + tokenizer_config[key] = None + return tokenizer_config + + +def build_tokenizer(tokenizer_config: Dict): + if tokenizer_config["library"] == "sentencepiece": + tokenizer = T5Tokenizer(tokenizer_config["model"], extra_ids=0) + elif "GPT2" in tokenizer_config["type"]: + tokenizer = GPT2Tokenizer(tokenizer_config["vocab_file"], + tokenizer_config["merge_file"]) + else: + raise ValueError( + f'Tokenizer type {tokenizer_config["library"]} not handled') + + if tokenizer.bos_token_id is None: + tokenizer.add_special_tokens({"bos_token": ""}) + if tokenizer.eos_token_id is None: + tokenizer.add_special_tokens({"eos_token": ""}) + + return tokenizer + + +def get_eos_bos_ids_from_tokenizer_config( + tokenizer_config: Dict[str, Any]) -> Tuple[int, int]: + tokenizer = build_tokenizer(tokenizer_config) + return tokenizer.eos_token_id, tokenizer.bos_token_id + + +def nemo_config_to_ini_config( + nemo_model_config: Dict[str, Any], + eos_id: int, + bos_id: int, + vocab_size: int, + storage_type: str, +) -> configparser.ConfigParser: + gpt_model_config = nemo_to_gpt_config(nemo_model_config, vocab_size, eos_id, + bos_id) + config = configparser.ConfigParser() + config["gpt"] = {k: str(v) for k, v in vars(gpt_model_config).items()} + config["gpt"]["storage_dtype"] = storage_type + add_rotary_parameters_to_ini_config( + config, extract_rotary_parameters_from_nemo_config(nemo_model_config)) + return config + + +def add_special_tokens_to_tokenizer(tokenizer): + + # Need to add cls, sep, mask tokens to the tokenizer if they don't exist. + # If cls, sep and mask are not attributes of the tokenizer, add it. + if not hasattr(tokenizer, 'cls_token'): + tokenizer.add_special_tokens({'cls_token': ''}) + if not hasattr(tokenizer.tokenizer, 'sep_id'): + tokenizer.add_special_tokens({'sep_token': ''}) + if not hasattr(tokenizer.tokenizer, 'mask_id'): + tokenizer.add_special_tokens({'mask_token': ''}) + + # bos, eos, pad and unk may be present in the provided spm .model file, if they are, use it. 
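+    # A strictly positive id reported by the underlying SentencePiece model means the
+    # token is already defined there, so its piece is reused; otherwise a new special
+    # token is registered on the wrapper tokenizer.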
+ if not hasattr(tokenizer, 'pad_token'): + if hasattr(tokenizer.tokenizer, + 'pad_id') and tokenizer.tokenizer.pad_id() > 0: + tokenizer.pad_token = tokenizer.tokenizer.id_to_piece( + tokenizer.tokenizer.pad_id()) + else: + tokenizer.add_special_tokens({'pad_token': ''}) + else: + tokenizer.add_special_tokens({'pad_token': ''}) + + if not hasattr(tokenizer, 'bos_token'): + if hasattr(tokenizer.tokenizer, + 'bos_id') and tokenizer.tokenizer.bos_id() > 0: + tokenizer.bos_token = tokenizer.tokenizer.id_to_piece( + tokenizer.tokenizer.bos_id()) + else: + tokenizer.add_special_tokens({'bos_token': ''}) + else: + tokenizer.add_special_tokens({'bos_token': ''}) + + if not hasattr(tokenizer, 'eos_token'): + if hasattr(tokenizer.tokenizer, + 'eos_id') and tokenizer.tokenizer.eos_id() > 0: + tokenizer.eos_token = tokenizer.tokenizer.id_to_piece( + tokenizer.tokenizer.eos_id()) + else: + tokenizer.add_special_tokens({'eos_token': ''}) + else: + tokenizer.add_special_tokens({'eos_token': ''}) + + +def unpack_nemo_ckpt(nemo_archive_path: Union[str, Path], + out_dir_path: Union[str, Path]): + nemo_archive_path = Path(nemo_archive_path) + if not nemo_archive_path.exists(): + raise FileNotFoundError(f"{nemo_archive_path} does not exist") + + for tar_mode in ["r:", "r:gz"]: + try: + with tarfile.open(nemo_archive_path, mode=tar_mode) as tar_file: + + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_members(tar_file): + members = [] + for member in tar_file.getmembers(): + member_path = os.path.join(out_dir_path, member.name) + if not is_within_directory(out_dir_path, member_path): + raise Exception( + "Attempted Path Traversal in Tar File") + members.append(member) + return members + + tar_file.extractall(out_dir_path, + members=safe_members(tar_file), + numeric_owner=False) + + return out_dir_path + except tarfile.ReadError: + pass + + raise RuntimeError(f"Could not unpack {nemo_archive_path}") + + +def extract_layers_with_prefix(model_, prefix): + length_to_trim = len(prefix) + model_state = model_.get("state_dict", model_) + return { + key[length_to_trim:]: model_state[key] + for key in model_state.keys() if prefix in key + } + + +class UnpackedNemoCheckpointDir: + + def __init__(self, + checkpoints_dir: Union[str, Path], + load_checkpoints_to_cpu: bool = False): + self._checkpoints_dir = Path(checkpoints_dir) + self._load_checkpoints_to_cpu = load_checkpoints_to_cpu + + @property + @functools.lru_cache + def model_config(self): + model_config = None + + model_config_filename = "model_config.yaml" + model_configs_paths = list( + self._checkpoints_dir.rglob(model_config_filename)) + if model_configs_paths: + if len(model_configs_paths) > 1: + raise RuntimeError( + f"There are more than single {model_config_filename} " + f"in {self._checkpoints_dir}: {', '.join(map(lambda p: p.as_posix(), model_configs_paths))}" + ) + model_config_path = model_configs_paths[0] + LOGGER.debug("Loading model config from %s", model_config_path) + with model_config_path.open("r") as model_config_file: + model_config = yaml.load(model_config_file, + Loader=yaml.SafeLoader) + else: + LOGGER.debug("Searching model config in checkpoints") + # try to obtain from checkpoint + checkpoint_name = self.checkpoint_name + checkpoints_paths = sorted( + self._checkpoints_dir.rglob(checkpoint_name)) + if checkpoints_paths: + # assume that parallel 
ranks 0 checkpoint should have model config embedded + checkpoint_path = checkpoints_paths[0] + + map_location_fn = cpu_map_location if self._load_checkpoints_to_cpu else gpu_map_location + + model_00 = torch.load(checkpoint_path, + map_location=map_location_fn) + if "hyper_parameters" in model_00 and "cfg" in model_00[ + "hyper_parameters"]: + model_config = model_00["hyper_parameters"]["cfg"] + LOGGER.debug("Loaded model config from checkpoint %s", + checkpoint_path) + else: + LOGGER.debug("Could not find model config in checkpoint %s", + checkpoint_path) + del model_00 + + if model_config is None: + LOGGER.warning( + "Could not find checkpoint with NeMo model config in %s", + self._checkpoints_dir) + + LOGGER.debug("Loaded model config %s", model_config) + + return model_config + + @property + def checkpoints_dir(self): + return self._checkpoints_dir + + def get_checkpoints_paths(self, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1): + """ + Injects tensor/pipeline model parallel ranks into the filepath. + Does nothing if not using model parallelism. + """ + + checkpoint_path_without_rank = self.checkpoints_dir / self.checkpoint_name + + def _inject_parallel_ranks(tp_rank, pp_rank): + if tensor_model_parallel_size > 1 or pipeline_model_parallel_size > 1: + if pipeline_model_parallel_size is None or pipeline_model_parallel_size == 1: + checkpoint_path = (checkpoint_path_without_rank.parent / + f"mp_rank_{tp_rank:02d}" / + checkpoint_path_without_rank.name) + else: + checkpoint_path = ( + checkpoint_path_without_rank.parent / + f"tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}" / + checkpoint_path_without_rank.name) + return checkpoint_path + else: + return checkpoint_path_without_rank + + return [[ + _inject_parallel_ranks(tp_rank=tp_rank, pp_rank=pp_rank) + for pp_rank in range(pipeline_model_parallel_size) + ] for tp_rank in range(tensor_model_parallel_size)] + + @property + @functools.lru_cache + def checkpoint_name(self): + patterns = [ + "model_weights.ckpt", # older megatron checkpoints + "*last.ckpt", # newer format of checkpoints + ] + for pattern in patterns: + model_files = sorted(list(self._checkpoints_dir.rglob(pattern))) + if model_files: + return model_files[0].name + + raise ValueError( + f"Could not find checkpoint files in {self._checkpoints_dir}") + + @functools.lru_cache + def get_tokenizer_file_path(self, tokenizer_key, file_key, + default_filename_pattern): + model_config = self.model_config + file_property = None + if tokenizer_key in model_config and file_key in model_config[ + tokenizer_key]: + file_property = model_config[tokenizer_key][file_key] + elif file_key in model_config: + file_property = model_config[file_key] + + LOGGER.debug("model_config[%s][%s]=%s", tokenizer_key, file_key, + file_property) + + if file_property and file_property.startswith("nemo:"): + filename = file_property.split("nemo:")[1] + filename_pattern = f"*{filename}" + elif file_property and file_property.startswith("/artifacts/"): + filename = Path(file_property).name + filename_pattern = f"*{filename}" + elif file_property is None or file_property == "None": + filename_pattern = None + else: + filename_pattern = default_filename_pattern + LOGGER.warning( + f"Tokenizer file from config: {tokenizer_key}.{file_key}={file_property} " + f"looks like unsupported path. Pattern {filename_pattern} will be used." 
+ ) + + file_path = None + if filename_pattern is not None: + files_paths = list(self._checkpoints_dir.glob(filename_pattern)) + if files_paths: + assert len(files_paths) == 1 + file_path = files_paths[0] + + return file_path + + @functools.lru_cache + def get_all_tokenizer_file_paths(self): + return { + "model": + self.get_tokenizer_file_path("tokenizer", "model", "*.model"), + "vocab_file": + self.get_tokenizer_file_path("tokenizer", "vocab_file", "*vocab*"), + "merge_file": + self.get_tokenizer_file_path("tokenizer", "merge_file", + "*merge*.txt"), + } + + +def load_nemo_gpt_config( + unpacked_checkpoints_dir: UnpackedNemoCheckpointDir) -> GPT2Config: + nemo_model_config = unpacked_checkpoints_dir.model_config + + training_tp_size = nemo_model_config.get("tensor_model_parallel_size", 1) + training_pp_size = nemo_model_config.get("pipeline_model_parallel_size", 1) + + checkpoints_paths = unpacked_checkpoints_dir.get_checkpoints_paths( + training_tp_size, + training_pp_size, + ) + if unpacked_checkpoints_dir._load_checkpoints_to_cpu: + map_location_fn = cpu_map_location + else: + map_location_fn = gpu_map_location + model_00 = torch.load(checkpoints_paths[0][0], map_location=map_location_fn) + vocab_size = model_00[ + "model.language_model.embedding.word_embeddings.weight"].shape[ + 0] * training_tp_size + del model_00 + + hf_config = GPT2Config( + vocab_size=vocab_size, + n_positions=nemo_model_config['max_position_embeddings'], + n_embd=nemo_model_config['hidden_size'], + n_layer=nemo_model_config['num_layers'], + n_head=nemo_model_config['num_attention_heads'], + n_inner=nemo_model_config['ffn_hidden_size'], + activation_function=nemo_model_config['activation'], + layer_norm_epsilon=nemo_model_config['layernorm_epsilon'], + ) + hf_config.n_kv_head = hf_config.n_head + hf_config.bias = nemo_model_config['bias'] + # hf_config.apply_query_key_layer_scaling = nemo_model_config['apply_query_key_layer_scaling'] + hf_config.apply_query_key_layer_scaling = False + hf_config.position_embedding_type = 'rope_gpt_neox' + hf_config.rotary_pct = nemo_model_config['rotary_percentage'] + + tokenizer_config = update_tokenizer_paths( + nemo_model_config["tokenizer"], + unpacked_checkpoints_dir.get_all_tokenizer_file_paths()) + + return hf_config, tokenizer_config + + +@torch.no_grad() +def convert_nemo_gpt(unpacked_checkpoints_dir: UnpackedNemoCheckpointDir, + mapping: Mapping, + dtype: str = 'float32'): + nemo_model_config = unpacked_checkpoints_dir.model_config + + checkpoints_paths = unpacked_checkpoints_dir.get_checkpoints_paths( + nemo_model_config.get("tensor_model_parallel_size", 1), + nemo_model_config.get("pipeline_model_parallel_size", 1), + ) + + if unpacked_checkpoints_dir._load_checkpoints_to_cpu: + map_location_fn = cpu_map_location + else: + map_location_fn = gpu_map_location + dtype = str_dtype_to_torch(dtype) + + # load position_embedding from rank 0 + model_00 = torch.load(checkpoints_paths[0][0], map_location=map_location_fn) + model_00 = model_00.get("state_dict", model_00) + + has_position_embedding = "model.language_model.embedding.position_embeddings.weight" in model_00 + has_lm_head = "model.language_model.output_layer.weight" in model_00 + + num_layers = nemo_model_config["num_layers"] + training_tp_size = nemo_model_config.get("tensor_model_parallel_size", 1) + training_pp_size = nemo_model_config.get("pipeline_model_parallel_size", 1) + inference_tp_size = mapping.tp_size + inference_tp_rank = mapping.tp_rank + + apply_layernorm_1p = (nemo_model_config.get('normalization', + 
'') == "layernorm1p") + split_gated_activation = ("swiglu" + in nemo_model_config.get('activation', "gelu")) + num_attention_heads = nemo_model_config["num_attention_heads"] + # use_attention_nemo_shape = True + transpose_weights = True + # multi_query_mode = False + local_dim = None + + # merge_factor: how many TP training nodes are merged into an inference TP node + # split_factor: in how many parts a TP training node is split + gcd = np.gcd(training_tp_size, inference_tp_size) + merge_factor = training_tp_size // gcd + split_factor = inference_tp_size // gcd + + model_level_weights = defaultdict(list) + + def handle_model_level_weights(model, tp_idx: int, pp_idx: int): + if tp_idx == 0 and pp_idx == 0: + if has_position_embedding: + val = model[ + "model.language_model.embedding.position_embeddings.weight"] + model_level_weights[ + "transformer.position_embedding.weight"].append(val) + if pp_idx == 0: + val = model.get( + "state_dict", + model)["model.language_model.embedding.word_embeddings.weight"] + model_level_weights["transformer.vocab_embedding.weight"].append( + val) + if has_lm_head and pp_idx == training_pp_size - 1: + val = model.get("state_dict", + model)["model.language_model.output_layer.weight"] + model_level_weights["lm_head.weight"].append(val) + + weights = {} + tik = time.time() + tp_rank = inference_tp_rank // split_factor + # for tp_rank in range(training_tp_size // merge_factor): + for pp_rank in range(training_pp_size): + models = [] + for k in range(merge_factor): + rank_weights = checkpoints_paths[tp_rank * merge_factor + + k][pp_rank] + model = torch.load(rank_weights, map_location=map_location_fn) + handle_model_level_weights(model, tp_rank * merge_factor + k, + pp_rank) + layers = extract_layers_with_prefix( + model, "model.language_model.encoder.") + models.append(layers) + + for name in models[0].keys(): + params = [model[name] for model in models] + if transpose_weights and params[0].ndim == 2: + params = [p.T for p in params] + if "layernorm.weight" in name and apply_layernorm_1p: + params = [p + 1.0 for p in params] + + l = retrieved_layer_index_from_name(name) + if l is not None: + new_l = l + pp_rank * num_layers // training_pp_size + prefix = f'transformer.layers.{new_l}' + + if 'attention.query_key_value' in name: + if name.endswith('weight'): + hidden_dim = params[0].shape[0] + if local_dim is None: + local_dim = params[0].shape[-1] // 3 + + # multi_query_mode = False; use_attention_nemo_shape = True + head_num = num_attention_heads // training_tp_size + size_per_head = hidden_dim // num_attention_heads + params = [ + param.reshape(hidden_dim, head_num, 3, + size_per_head) for param in params + ] + params = [param.permute(0, 2, 1, 3) for param in params] + params = [ + param.reshape(hidden_dim, 3, local_dim) + for param in params + ] + cat_dim = -1 + param = torch.concat(params, dim=cat_dim) + param = torch.chunk(param, split_factor, + dim=cat_dim)[inference_tp_rank % + split_factor] + weights[ + f'{prefix}.attention.qkv.weight'] = param.reshape( + hidden_dim, -1).t() + else: + if local_dim is None: + local_dim = params[0].shape[-1] // 3 + + # multi_query_mode = False; use_attention_nemo_shape = True + head_num = num_attention_heads // training_tp_size + size_per_head = local_dim // head_num + params = [ + param.reshape(head_num, 3, size_per_head) + for param in params + ] + params = [param.permute(1, 0, 2) for param in params] + params = [ + param.reshape(3, local_dim) for param in params + ] + cat_dim = -1 + param = torch.concat(params, 
dim=cat_dim) + param = torch.chunk(param, split_factor, + dim=cat_dim)[inference_tp_rank % + split_factor] + weights[f'{prefix}.attention.qkv.bias'] = param.reshape( + -1) + + elif 'attention.dense' in name: + if name.endswith('weight'): + cat_dim = 0 + param = torch.concat(params, dim=cat_dim) + param = torch.chunk(param, split_factor, + dim=cat_dim)[inference_tp_rank % + split_factor] + weights[f'{prefix}.attention.dense.weight'] = param.t() + else: + weights[f'{prefix}.attention.dense.bias'] = params[0] + + elif 'mlp.dense_h_to_4h' in name: + if name.endswith('weight'): + if split_gated_activation: + params = [torch.chunk(p, 2, dim=-1) for p in params] + params, gate_params = list(zip(*params)) + cat_dim = -1 + param = torch.concat(params, dim=cat_dim) + param = torch.chunk(param, split_factor, + dim=cat_dim)[inference_tp_rank % + split_factor] + weights[f'{prefix}.mlp.fc.weight'] = param.t() + if split_gated_activation: + gate_param = torch.concat(gate_params, dim=cat_dim) + gate_param = torch.chunk( + gate_param, split_factor, + dim=cat_dim)[inference_tp_rank % split_factor] + weights[f'{prefix}.mlp.gate.weight'] = gate_param.t( + ) + else: + if split_gated_activation: + params = [torch.chunk(p, 2, dim=-1) for p in params] + params, gate_params = list(zip(*params)) + cat_dim = -1 + param = torch.concat(params, dim=cat_dim) + param = torch.chunk(param, split_factor, + dim=cat_dim)[inference_tp_rank % + split_factor] + weights[f'{prefix}.mlp.fc.bias'] = param + if split_gated_activation: + gate_param = torch.concat(gate_params, dim=cat_dim) + gate_param = torch.chunk( + gate_param, split_factor, + dim=cat_dim)[inference_tp_rank % split_factor] + weights[f'{prefix}.mlp.gate.bias'] = gate_param + + elif 'mlp.dense_4h_to_h' in name: + if name.endswith('weight'): + cat_dim = 0 + param = torch.concat(params, dim=cat_dim) + param = torch.chunk(param, split_factor, + dim=cat_dim)[inference_tp_rank % + split_factor] + weights[f'{prefix}.mlp.proj.weight'] = param.t() + else: + weights[f'{prefix}.mlp.proj.bias'] = params[0] + + elif 'input_layernorm' in name: + if name.endswith('weight'): + weights[f'{prefix}.input_layernorm.weight'] = params[0] + else: + weights[f'{prefix}.input_layernorm.bias'] = params[0] + elif 'post_attention_layernorm' in name: + if name.endswith('weight'): + weights[f'{prefix}.post_layernorm.weight'] = params[0] + else: + weights[f'{prefix}.post_layernorm.bias'] = params[0] + + elif 'final_layernorm' in name: + if name.endswith('weight'): + weights['transformer.ln_f.weight'] = params[0] + else: + weights['transformer.ln_f.bias'] = params[0] + + for key, params in model_level_weights.items(): + weights[key] = torch.concat(params, dim=0) + + weights = { + key: param.to(dtype).contiguous() + for key, param in weights.items() + } + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Weights loaded. Total time: {t}') + return weights + + +if __name__ == '__main__': + # TODO(qijun): Currently, the convert script depends on a torch op: + # torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix, + # which is included in tensorrt_llm Python package. Otherwise, the convert + # script does not need to import tensorrt_llm. Will remove it after reimplementing + # the op with PyTorch. 
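+    # Conversion flow below: parse CLI arguments, derive the quantization settings,
+    # load the HF or NeMo source model, write the TensorRT-LLM config.json, then
+    # convert and save the weights for every rank as rank<N>.safetensors
+    # (optionally with parallel workers).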
+ print(tensorrt_llm.__version__) + args = parse_arguments() + world_size = args.tp_size * args.pp_size + + tik = time.time() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + quant_algo = None + kv_cache_quant_algo = None + plugin_weight_only_quant_type = None + if args.use_weight_only: + if args.weight_only_precision == 'int8': + plugin_weight_only_quant_type = torch.int8 + quant_algo = 'W8A16' + elif args.weight_only_precision == 'int4': + plugin_weight_only_quant_type = torch.quint4x2 + quant_algo = 'W4A16' + elif args.smoothquant: + if args.per_token and args.per_channel: + quant_algo = 'W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN' + elif not args.per_token and not args.per_channel: + quant_algo = 'W8A8_SQ_PER_TENSOR_PLUGIN' + elif not args.per_token and args.per_channel: + quant_algo = 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN' + elif args.per_token and not args.per_channel: + quant_algo = 'W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN' + + if args.int8_kv_cache: + kv_cache_quant_algo = "INT8" + + if args.model_dir is not None: + hf_config, gpt_variant = load_gpt_config(args.model_dir, + args.gpt_variant) + elif args.nemo_ckpt_path is not None: + nemo_dir = Path(args.output_dir) / "unpacked" + nemo_dir = unpack_nemo_ckpt(args.nemo_ckpt_path, nemo_dir) + unpacked_checkpoints_dir = UnpackedNemoCheckpointDir( + nemo_dir, load_checkpoints_to_cpu=not args.load_nemo_on_gpu) + hf_config, tokenizer_config = load_nemo_gpt_config( + unpacked_checkpoints_dir) + copy_tokenizer_files(tokenizer_config, Path(args.output_dir)) + args.use_parallel_embedding = True + args.embedding_sharding_dim = 0 + else: + raise NotImplementedError("No source model path specified!") + + config = { + 'architecture': + 'GPTForCausalLM', + 'dtype': + args.dtype, + 'num_hidden_layers': + hf_config.n_layer, + 'num_attention_heads': + hf_config.n_head, + 'num_key_value_heads': + hf_config.n_kv_head, + 'hidden_size': + hf_config.n_embd, + 'intermediate_size': + hf_config.n_inner, + 'norm_epsilon': + hf_config.layer_norm_epsilon, + 'vocab_size': + hf_config.vocab_size, + 'position_embedding_type': + getattr(hf_config, 'position_embedding_type', 'learned_absolute'), + 'max_position_embeddings': + hf_config.n_positions, + 'hidden_act': + hf_config.activation_function, + 'use_parallel_embedding': + args.use_parallel_embedding, + 'embedding_sharding_dim': + args.embedding_sharding_dim, + 'share_embedding_table': + args.use_embedding_sharing, + 'quantization': { + 'quant_algo': quant_algo, + 'kv_cache_quant_algo': kv_cache_quant_algo, + }, + 'mapping': { + 'world_size': world_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + }, + 'bias': + getattr(hf_config, 'bias', True), + 'apply_query_key_layer_scaling': + getattr(hf_config, 'apply_query_key_layer_scaling', False), + 'rotary_pct': + getattr(hf_config, 'rotary_pct', 1.0), + 'max_lora_rank': + args.max_lora_rank, + 'lora_target_modules': + args.lora_target_modules, + } + + with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: + json.dump(config, f, indent=4) + + if args.model_dir is not None: + hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir, + trust_remote_code=True, + device_map="auto", + torch_dtype="auto") + if args.smoothquant is not None or args.int8_kv_cache: + os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( + "TOKENIZERS_PARALLELISM", "false") + dataset = load_dataset("lambada", + split="validation", + cache_dir=args.dataset_cache_dir) + tokenizer = AutoTokenizer.from_pretrained(args.model_dir) + act_range = 
capture_activation_range(hf_model, tokenizer, dataset) + if args.smoothquant is not None: + smooth_gpt_model(hf_model, act_range, args.smoothquant) + + def covert_and_save(rank): + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size) + + if args.model_dir is not None: + if args.smoothquant is not None or args.int8_kv_cache: + weights = convert_hf_gpt_legacy( + hf_model, + hf_config, + gpt_variant, + mapping, + dtype=args.dtype, + use_parallel_embedding=args.use_parallel_embedding, + sharding_dim=args.embedding_sharding_dim, + share_embedding_table=args.use_embedding_sharing, + use_smooth_quant=(args.smoothquant is not None), + per_channel=args.per_channel, + per_token=args.per_token, + int8_kv_cache=args.int8_kv_cache, + act_range=act_range, + ) + else: + weights = convert_hf_gpt( + hf_model, + hf_config, + gpt_variant, + mapping, + dtype=args.dtype, + use_parallel_embedding=args.use_parallel_embedding, + sharding_dim=args.embedding_sharding_dim, + share_embedding_table=args.use_embedding_sharing, + use_weight_only=args.use_weight_only, + plugin_weight_only_quant_type=plugin_weight_only_quant_type, + ) + + elif args.nemo_ckpt_path is not None: + weights = convert_nemo_gpt(unpacked_checkpoints_dir, mapping, + args.dtype) + + safetensors.torch.save_file( + weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) + + if args.workers == 1: + for rank in range(world_size): + covert_and_save(rank) + else: + with ThreadPoolExecutor(max_workers=args.workers) as p: + futures = [ + p.submit(covert_and_save, rank) for rank in range(world_size) + ] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." + + if args.model_dir is not None: + del hf_model + elif args.nemo_ckpt_path is not None: + shutil.rmtree(nemo_dir) + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Total time of converting checkpoints: {t}') diff --git a/examples/gpt/hf_gpt_convert.py b/examples/gpt/hf_gpt_convert.py deleted file mode 100644 index c80f84bae..000000000 --- a/examples/gpt/hf_gpt_convert.py +++ /dev/null @@ -1,346 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -''' -Convert huggingface GPT model. Use https://huggingface.co/gpt2 as demo. 
-''' -import argparse -import configparser -import dataclasses -import os -import platform -from pathlib import Path - -import torch -import torch.multiprocessing as multiprocessing -from smoothquant import capture_activation_range, smooth_gemm -from tqdm import tqdm -from transformers import AutoModelForCausalLM # transformers-4.10.0-py3 -from transformers import AutoTokenizer -from transformers.models.gpt2.modeling_gpt2 import GPT2Block -from utils.convert import split_and_save_weight - -from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy - - -@dataclasses.dataclass(frozen=True) -class ProgArgs: - out_dir: str - in_file: str - tensor_parallelism: int = 1 - processes: int = 4 - calibrate_kv_cache: bool = False - smoothquant: float = None - model: str = "gpt" - storage_type: str = "fp32" - dataset_cache_dir: str = None - load_model_on_cpu: bool = False - convert_model_on_cpu: bool = False - - @staticmethod - def parse(args=None) -> 'ProgArgs': - parser = argparse.ArgumentParser( - formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument('--out-dir', - '-o', - type=str, - help='file name of output directory', - required=True) - parser.add_argument('--in-file', - '-i', - type=str, - help='file name of input checkpoint file', - required=True) - parser.add_argument('--tensor-parallelism', - '-tp', - type=int, - help='Requested tensor parallelism for inference', - default=1) - parser.add_argument( - "--processes", - "-p", - type=int, - help= - "How many processes to spawn for conversion (default: 4). Set it to a lower value to reduce RAM usage.", - default=4) - parser.add_argument( - "--calibrate-kv-cache", - "-kv", - action="store_true", - help= - "Generate scaling factors for KV cache. Used for storing KV cache in int8." - ) - parser.add_argument( - "--smoothquant", - "-sq", - type=float, - default=None, - help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)" - " to Smoothquant the model, and output int8 weights." - " A good first try is 0.5. 
Must be in [0, 1]") - parser.add_argument( - "--model", - default="gpt2", - type=str, - help="Specify GPT variants to convert checkpoints correctly", - choices=["gpt2", "santacoder", "starcoder", "starcoder2"]) - parser.add_argument("--storage-type", - "-t", - type=str, - default="float32", - choices=["float32", "float16", "bfloat16"]) - parser.add_argument("--dataset-cache-dir", - type=str, - default=None, - help="cache dir to load the hugging face dataset") - parser.add_argument("--load-model-on-cpu", action="store_true") - parser.add_argument("--convert-model-on-cpu", action="store_true") - return ProgArgs(**vars(parser.parse_args(args))) - - -@torch.no_grad() -def smooth_gpt_model(model, scales, alpha): - # Smooth the activation and weights with smoother = $\diag{s}$ - for name, module in model.named_modules(): - if not isinstance(module, GPT2Block): - continue - - # qkv_proj - layer_name = name + ".attn.c_attn" - smoother = smooth_gemm(module.attn.c_attn.weight.T, - scales[layer_name]["x"], module.ln_1.weight, - module.ln_1.bias, alpha) - scales[layer_name]["x"] = scales[layer_name]["x"] / smoother - scales[layer_name]["w"] = module.attn.c_attn.weight.abs().max(dim=0)[0] - - # fc1 - layer_name = name + ".mlp.c_fc" - smoother = smooth_gemm(module.mlp.c_fc.weight.T, - scales[layer_name]["x"], module.ln_2.weight, - module.ln_2.bias, alpha) - scales[layer_name]["x"] = scales[layer_name]["x"] / smoother - scales[layer_name]["w"] = module.mlp.c_fc.weight.abs().max(dim=0)[0] - - -# SantaCoder separates Q projection from KV projection -def concat_qkv_weight_bias(q, hf_key, hf_model, model_type): - if model_type == "starcoder2": - k = hf_model.state_dict()[hf_key.replace("q_proj", - "k_proj")].to(q.device) - v = hf_model.state_dict()[hf_key.replace("q_proj", - "v_proj")].to(q.device) - if len(q.shape) == 2: - k = k.transpose(0, 1) - v = v.transpose(0, 1) - return torch.cat([q, k, v], dim=-1) - else: - kv = hf_model.state_dict()[hf_key.replace("q_attn", - "kv_attn")].to(q.device) - return torch.cat([q, kv], dim=-1) - - -# StarCoder uses nn.Linear for these following ops whose weight matrix is transposed compared to transformer.Conv1D -def transpose_weights(hf_name, param, model_type): - - weight_to_transpose = [] - if model_type == "starcoder": - weight_to_transpose = ["c_attn", "c_proj", "c_fc"] - elif model_type == "starcoder2": - weight_to_transpose = ["self_attn", "c_proj", "c_fc"] - if any([k in hf_name for k in weight_to_transpose]): - if len(param.shape) == 2: - param = param.transpose(0, 1) - return param - - -def gpt_to_ft_name(orig_name): - global_weights = { - "transformer.wpe.weight": "model.wpe", - "transformer.wte.weight": "model.wte", - "transformer.ln_f.bias": "model.final_layernorm.bias", - "transformer.ln_f.weight": "model.final_layernorm.weight", - "lm_head.weight": "model.lm_head.weight", - # StarCoder2 - "model.embed_tokens.weight": "model.wte", - "model.norm.weight": "model.final_layernorm.weight", - "model.norm.bias": "model.final_layernorm.bias" - } - - if orig_name in global_weights: - return global_weights[orig_name] - - _, _, layer_idx, *weight_name = orig_name.split(".") - layer_idx = int(layer_idx) - weight_name = "transformer." 
+ ".".join(weight_name) - - per_layer_weights = { - "transformer.ln_1.bias": "input_layernorm.bias", - "transformer.ln_1.weight": "input_layernorm.weight", - "transformer.attn.c_attn.bias": "attention.query_key_value.bias", - "transformer.attn.c_attn.weight": "attention.query_key_value.weight", - "transformer.attn.q_attn.weight": "attention.query.weight", - "transformer.attn.q_attn.bias": "attention.query.bias", - "transformer.attn.kv_attn.weight": "attention.key_value.weight", - "transformer.attn.kv_attn.bias": "attention.key_value.bias", - "transformer.attn.c_proj.bias": "attention.dense.bias", - "transformer.attn.c_proj.weight": "attention.dense.weight", - "transformer.ln_2.bias": "post_attention_layernorm.bias", - "transformer.ln_2.weight": "post_attention_layernorm.weight", - "transformer.mlp.c_fc.bias": "mlp.dense_h_to_4h.bias", - "transformer.mlp.c_fc.weight": "mlp.dense_h_to_4h.weight", - "transformer.mlp.c_proj.bias": "mlp.dense_4h_to_h.bias", - "transformer.mlp.c_proj.weight": "mlp.dense_4h_to_h.weight", - # StarCoder2 - "transformer.input_layernorm.bias": "input_layernorm.bias", - "transformer.input_layernorm.weight": "input_layernorm.weight", - "transformer.self_attn.q_proj.bias": "attention.query.bias", - "transformer.self_attn.q_proj.weight": "attention.query.weight", - "transformer.self_attn.k_proj.weight": "attention.key.weight", - "transformer.self_attn.k_proj.bias": "attention.key.bias", - "transformer.self_attn.v_proj.weight": "attention.value.weight", - "transformer.self_attn.v_proj.bias": "attention.value.bias", - "transformer.self_attn.o_proj.bias": "attention.dense.bias", - "transformer.self_attn.o_proj.weight": "attention.dense.weight", - "transformer.post_attention_layernorm.bias": - "post_attention_layernorm.bias", - "transformer.post_attention_layernorm.weight": - "post_attention_layernorm.weight", - "transformer.mlp.c_fc.bias": "mlp.dense_h_to_4h.bias", - "transformer.mlp.c_fc.weight": "mlp.dense_h_to_4h.weight", - "transformer.mlp.c_proj.bias": "mlp.dense_4h_to_h.bias", - "transformer.mlp.c_proj.weight": "mlp.dense_4h_to_h.weight" - } - return f"layers.{layer_idx}.{per_layer_weights[weight_name]}" - - -@torch.no_grad() -def hf_gpt_converter(args: ProgArgs): - infer_tp = args.tensor_parallelism - multi_query_mode = True if args.model in ["santacoder", "starcoder" - ] else False - saved_dir = Path(args.out_dir) / f"{infer_tp}-gpu" - saved_dir.mkdir(parents=True, exist_ok=True) - - # load position_embedding from rank 0 - model = AutoModelForCausalLM.from_pretrained(args.in_file, - torch_dtype="auto", - device_map="auto", - trust_remote_code=True) - if args.load_model_on_cpu: - model = model.cpu() - torch.cuda.empty_cache() - act_range = {} - if args.smoothquant is not None or args.calibrate_kv_cache: - os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( - "TOKENIZERS_PARALLELISM", "false") - from datasets import load_dataset - dataset = load_dataset("lambada", - split="validation", - cache_dir=args.dataset_cache_dir) - act_range = capture_activation_range( - model, AutoTokenizer.from_pretrained(args.in_file), dataset) - if args.smoothquant is not None: - smooth_gpt_model(model, act_range, args.smoothquant) - - config = configparser.ConfigParser() - config["gpt"] = {} - for key in vars(args): - config["gpt"][key] = f"{vars(args)[key]}" - for k, v in vars(model.config).items(): - config["gpt"][k] = f"{v}" - config["gpt"]["storage_dtype"] = args.storage_type - config["gpt"]["multi_query_mode"] = str(multi_query_mode) - num_attention_heads = 
int(config['gpt'].get("num_attention_heads", 0)) - num_key_value_heads = 1 if multi_query_mode else int(config['gpt'].get( - "num_key_value_heads", num_attention_heads)) - with open(saved_dir / "config.ini", 'w') as configfile: - config.write(configfile) - - storage_type = str_dtype_to_torch(args.storage_type) - - global_ft_weights = [ - "model.wpe", "model.wte", "model.final_layernorm.bias", - "model.final_layernorm.weight", "model.lm_head.weight" - ] - - int8_outputs = None - if args.calibrate_kv_cache: - int8_outputs = "kv_cache_only" - if args.smoothquant is not None: - int8_outputs = "all" - - starmap_args = [] - for name, param in model.named_parameters(): - if "weight" not in name and "bias" not in name: - continue - ft_name = gpt_to_ft_name(name) - - if args.convert_model_on_cpu: - param = param.cpu() - param = transpose_weights(name, param, args.model) - if ft_name in global_ft_weights: - torch_to_numpy(param.to(storage_type).cpu()).tofile( - saved_dir / f"{ft_name}.bin") - else: - if 'q_attn' in name or 'q_proj' in name: - param = concat_qkv_weight_bias(param, name, model, args.model) - ft_name = ft_name.replace("query", "query_key_value") - # Needed by QKV projection weight split. With multi_query_mode one does not simply take - # out_dim and divide it by 3 to get local_dim because out_dim = local_dim + 2 * head_size - local_dim = model.transformer.h[ - 0].attn.embed_dim if multi_query_mode else None - if args.processes == 1: - split_and_save_weight( - 0, saved_dir, infer_tp, ft_name, param.to(storage_type), - storage_type, act_range.get(name.replace(".weight", "")), { - "int8_outputs": int8_outputs, - "multi_query_mode": multi_query_mode, - "local_dim": local_dim, - "num_attention_heads": num_attention_heads, - "num_key_value_heads": num_key_value_heads - }) - else: - starmap_args.append( - (0, saved_dir, infer_tp, ft_name, param.to(storage_type), - storage_type, act_range.get(name.replace(".weight", "")), { - "int8_outputs": int8_outputs, - "multi_query_mode": multi_query_mode, - "local_dim": local_dim, - "num_attention_heads": num_attention_heads, - "num_key_value_heads": num_key_value_heads - })) - - starmap_args = tqdm(starmap_args, desc="saving weights") - if args.processes > 1: - with multiprocessing.Pool(args.processes) as pool: - pool.starmap(split_and_save_weight, starmap_args) - - -def run_conversion(args: ProgArgs): - if args.processes > 1 and platform.system() == "Windows": - print( - "Resetting processes to 1 because multi-process on Windows is not implemented." - ) - args = dataclasses.replace(args, processes=1) - - print("\n=============== Arguments ===============") - for key, value in vars(args).items(): - print(f"{key}: {value}") - print("========================================") - hf_gpt_converter(args) - - -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") - run_conversion(ProgArgs.parse()) diff --git a/examples/gpt/nemo_ckpt_convert.py b/examples/gpt/nemo_ckpt_convert.py deleted file mode 100755 index 79fe40673..000000000 --- a/examples/gpt/nemo_ckpt_convert.py +++ /dev/null @@ -1,263 +0,0 @@ -#! /usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import datetime -import logging -import multiprocessing -import sys -import tempfile -from collections import defaultdict -from pathlib import Path - -import numpy as np -import torch -from tqdm import tqdm -from utils.convert import (cpu_map_location, gpu_map_location, - split_and_save_weight) -from utils.nemo import (UnpackedNemoCheckpointDir, copy_tokenizer_files, - extract_layers_with_prefix, - get_eos_bos_ids_from_tokenizer_config, - nemo_config_to_ini_config, unpack_nemo_ckpt, - update_tokenizer_paths) - -from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy - -LOGGER = logging.getLogger(__name__) - - -def rename_key(old_key: str, pp_rank: int, num_layers: int, pp_size: int): - new_key = old_key - - if "layers." in old_key: - split_key = old_key.split(".") - split_key[1] = str(int(split_key[1]) + pp_rank * num_layers // pp_size) - new_key = ".".join(split_key) - - if "self_attention" in new_key: - new_key = new_key.replace("self_attention", "attention") - return new_key - - -@torch.no_grad() -def convert_checkpoint(unpacked_checkpoints_dir: UnpackedNemoCheckpointDir, - args): - nemo_model_config = unpacked_checkpoints_dir.model_config - - checkpoints_paths = unpacked_checkpoints_dir.get_checkpoints_paths( - nemo_model_config.get("tensor_model_parallel_size", 1), - nemo_model_config.get("pipeline_model_parallel_size", 1), - ) - - # if checkpoints files could be found - start preparing output dir - out_dir = create_out_dir(args) - - map_location_fn = gpu_map_location if args.load_checkpoints_on_gpu else cpu_map_location - storage_type = str_dtype_to_torch(args.storage_type) - - # load position_embedding from rank 0 - model_00 = torch.load(checkpoints_paths[0][0], map_location=map_location_fn) - model_00 = model_00.get("state_dict", model_00) - - has_position_embedding = "model.language_model.embedding.position_embeddings.weight" in model_00 - has_lm_head = "model.language_model.output_layer.weight" in model_00 - - num_layers = nemo_model_config["num_layers"] - training_tp_size = nemo_model_config.get("tensor_model_parallel_size", 1) - training_pp_size = nemo_model_config.get("pipeline_model_parallel_size", 1) - inference_tp_size = args.tensor_parallelism - - export_config = { - "apply_layernorm_1p": - nemo_model_config.get('normalization', '') == "layernorm1p", - "tp_size": - training_tp_size, - "split_gated_activation": - "swiglu" in nemo_model_config.get('activation', "gelu"), - "num_attention_heads": - nemo_model_config["num_attention_heads"], - "use_attention_nemo_shape": - True, - "transpose_weights": - True, - } - - # merge_factor: how many TP training nodes are merged into an inference TP node - # split_factor: in how many parts a TP training node is split - gcd = np.gcd(training_tp_size, inference_tp_size) - merge_factor = training_tp_size // gcd - split_factor = inference_tp_size // gcd - - model_level_weights = defaultdict(list) - - def handle_model_level_weights(model, tp_idx: int, pp_idx: int): - if tp_idx == 0 and pp_idx == 0: - if has_position_embedding: - val = model[ - "model.language_model.embedding.position_embeddings.weight"] - 
# not weight, do not need to transpose - val = torch_to_numpy(val.to(storage_type).cpu()) - val.tofile(out_dir / "model.wpe.bin") - model_level_weights["model.wpe.bin"].append(val) - if pp_idx == 0: - val = model.get( - "state_dict", - model)["model.language_model.embedding.word_embeddings.weight"] - val = torch_to_numpy(val.to(storage_type).cpu()) - model_level_weights["model.wte.bin"].append(val) - if has_lm_head and pp_idx == training_pp_size - 1: - val = model.get("state_dict", - model)["model.language_model.output_layer.weight"] - val = torch_to_numpy(val.to(storage_type).cpu()) - model_level_weights["model.lm_head.weight.bin"].append(val) - - for tp_rank in range(training_tp_size // merge_factor): - for pp_rank in range(training_pp_size): - - models = [] - for k in range(merge_factor): - rank_weights = checkpoints_paths[tp_rank * merge_factor + - k][pp_rank] - model = torch.load(rank_weights, map_location=map_location_fn) - handle_model_level_weights(model, tp_rank * merge_factor + k, - pp_rank) - layers = extract_layers_with_prefix( - model, "model.language_model.encoder.") - models.append(layers) - - starmap_args = [] - for key in models[0].keys(): - starmap_args.append(( - tp_rank, - out_dir, - split_factor, - rename_key(key, pp_rank, num_layers, training_pp_size), - [model[key] for model in models], - storage_type, - None, - export_config, - )) - starmap_args = tqdm(starmap_args, desc="saving weights") - - if args.processes > 1: - with multiprocessing.Pool(args.processes) as pool: - pool.starmap(split_and_save_weight, starmap_args) - else: - # simpler for debug situations - for starmap_arg in starmap_args: - split_and_save_weight(*starmap_arg) - - for key, values in model_level_weights.items(): - model_level_weights[key] = np.concatenate(values, axis=0) - model_level_weights[key].tofile(out_dir / key) - vocab_size = model_level_weights["model.wte.bin"].shape[0] - tokenizer_config = update_tokenizer_paths( - nemo_model_config["tokenizer"], - unpacked_checkpoints_dir.get_all_tokenizer_file_paths()) - copy_tokenizer_files(tokenizer_config, out_dir) - ini_config = nemo_config_to_ini_config( - nemo_model_config, - *get_eos_bos_ids_from_tokenizer_config(tokenizer_config), vocab_size, - args.storage_type) - config_path = out_dir / "config.ini" - with config_path.open("w") as config_file: - ini_config.write(config_file) - - -def create_out_dir(args): - out_dir = Path(args.out_dir) / f"{args.tensor_parallelism}-gpu/" - if not out_dir.exists(): - out_dir.mkdir(parents=True) - return out_dir - - -def main(): - torch.multiprocessing.set_start_method("spawn") - torch.multiprocessing.set_sharing_strategy("file_system") - - parser = argparse.ArgumentParser( - formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument('--out-dir', - '-o', - type=Path, - help='path to output directory', - required=True) - parser.add_argument('--in-file', - '-i', - type=Path, - help='path to input checkpoint file', - required=True) - parser.add_argument('--tensor-parallelism', - '-tp', - type=int, - help='Requested tensor parallelism for inference', - default=1) - parser.add_argument( - "--processes", - "-p", - type=int, - help= - "How many processes to spawn for conversion (default: 4). 
Set it to a lower value to reduce RAM usage.", - default=4) - parser.add_argument("--storage-type", - "-t", - type=str, - default="float32", - choices=["float32", "float16", "bfloat16"]) - parser.add_argument("--load-checkpoints-on-gpu", - action="store_true", - help="Whether to load model weights to GPU") - parser.add_argument("--verbose", - action="store_true", - help="Provide verbose messages") - args = parser.parse_args() - - log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" - logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, - format=log_format) - - print("\n=============== Argument ===============") - for key in vars(args): - print(f"{key}: {vars(args)[key]}") - print("========================================") - - if not args.in_file.exists(): - LOGGER.error("%s does not exists", args.in_file) - sys.exit(1) - - with tempfile.TemporaryDirectory() as temp_dir: - temp_dir = Path(temp_dir) - - # unpack if needed - if args.in_file.is_dir(): - nemo_dir = args.in_file - else: - start_time = datetime.datetime.now() - checkpoint_dir_path = temp_dir / "unpacked" - nemo_dir = unpack_nemo_ckpt(args.in_file, checkpoint_dir_path) - LOGGER.info("Spent %s (h:m:s) to unpack NeMo archive", - datetime.datetime.now() - start_time) - - unpacked_checkpoint_dir = UnpackedNemoCheckpointDir( - nemo_dir, load_checkpoints_to_cpu=not args.load_checkpoints_on_gpu) - - start_time = datetime.datetime.now() - convert_checkpoint(unpacked_checkpoint_dir, args) - LOGGER.info("Spent %s (h:m:s) to convert the model", - datetime.datetime.now() - start_time) - - -if __name__ == "__main__": - main() diff --git a/examples/gpt/nemo_lora_convert.py b/examples/gpt/nemo_lora_convert.py index 67fd350fb..ad508a9e4 100644 --- a/examples/gpt/nemo_lora_convert.py +++ b/examples/gpt/nemo_lora_convert.py @@ -22,11 +22,10 @@ import numpy as np import torch import yaml -from utils.convert import cpu_map_location -from utils.nemo import unpack_nemo_ckpt +from convert_checkpoint import cpu_map_location, unpack_nemo_ckpt from tensorrt_llm._utils import str_dtype_to_torch, to_json_file, torch_to_numpy -from tensorrt_llm.lora_manager import LoraConfig, get_all_nemo_lora_weights +from tensorrt_llm.lora_manager import LoraManager, get_all_nemo_lora_weights log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" logging.basicConfig(format=log_format) @@ -121,7 +120,7 @@ def lora_convert_cpp_runtime(out_dir, weights.append(in_out_weights) weight_config.append( np.array([ - LoraConfig.LORA_MODULE_IDS["attn_qkv"], layer_idx, adapter_size + LoraManager.LORA_MODULE_IDS["attn_qkv"], layer_idx, adapter_size ], dtype=np.int32)) all_weights = np.expand_dims(np.stack(weights), 0) diff --git a/examples/gpt/nemo_prompt_convert.py b/examples/gpt/nemo_prompt_convert.py index 5105917d9..06eeec30f 100755 --- a/examples/gpt/nemo_prompt_convert.py +++ b/examples/gpt/nemo_prompt_convert.py @@ -22,8 +22,7 @@ import numpy as np import torch import yaml -from utils.convert import cpu_map_location -from utils.nemo import unpack_nemo_ckpt +from convert_checkpoint import cpu_map_location, unpack_nemo_ckpt from tensorrt_llm._utils import torch_to_numpy diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt index 3d0c299a3..33618106d 100644 --- a/examples/gpt/requirements.txt +++ b/examples/gpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git 
a/examples/gpt/smoothquant.py b/examples/gpt/smoothquant.py deleted file mode 100644 index b774dded4..000000000 --- a/examples/gpt/smoothquant.py +++ /dev/null @@ -1,155 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -''' -Utilities for SmoothQuant models -''' - -import functools -from collections import defaultdict - -import torch -import torch.nn as nn -from tqdm import tqdm -from transformers.pytorch_utils import Conv1D - - -@torch.no_grad() -def apply_smoothing(scales, - gemm_weights, - layernorm_weights=None, - layernorm_bias=None, - dtype=torch.float32, - layernorm_1p=False): - if not isinstance(gemm_weights, list): - gemm_weights = [gemm_weights] - - if layernorm_weights is not None: - assert layernorm_weights.numel() == scales.numel() - layernorm_weights.div_(scales).to(dtype) - if layernorm_bias is not None: - assert layernorm_bias.numel() == scales.numel() - layernorm_bias.div_(scales).to(dtype) - if layernorm_1p: - layernorm_weights += (1 / scales) - 1 - - for gemm in gemm_weights: - gemm.mul_(scales.view(1, -1)).to(dtype) - - -@torch.no_grad() -def smooth_gemm(gemm_weights, - act_scales, - layernorm_weights=None, - layernorm_bias=None, - alpha=0.5, - weight_scales=None): - if not isinstance(gemm_weights, list): - gemm_weights = [gemm_weights] - orig_dtype = gemm_weights[0].dtype - - for gemm in gemm_weights: - # gemm_weights are expected to be transposed - assert gemm.shape[1] == act_scales.numel() - - if weight_scales is None: - weight_scales = torch.cat( - [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], - dim=0) - weight_scales = weight_scales.max(dim=0)[0] - weight_scales.to(float).clamp(min=1e-5) - scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / - weight_scales.pow(1 - alpha)).clamp(min=1e-5) - - apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias, - orig_dtype) - - return scales - - -@torch.no_grad() -def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5): - if not isinstance(fcs, list): - fcs = [fcs] - for fc in fcs: - assert isinstance(fc, nn.Linear) - assert ln.weight.numel() == fc.in_features == act_scales.numel() - - device, dtype = fcs[0].weight.device, fcs[0].weight.dtype - act_scales = act_scales.to(device=device, dtype=dtype) - weight_scales = torch.cat( - [fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) - weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) - - scales = (act_scales.pow(alpha) / - weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) - - if ln is not None: - ln.weight.div_(scales) - ln.bias.div_(scales) - - for fc in fcs: - fc.weight.mul_(scales.view(1, -1)) - return scales - - -@torch.no_grad() -def capture_activation_range(model, - tokenizer, - dataset, - num_samples=512, - seq_len=512): - model.eval() - device = next(model.parameters()).device - act_scales = defaultdict(lambda: {"x": None, "y": None, 
"w": None}) - - def stat_tensor(name, tensor, act_scales, key): - hidden_dim = tensor.shape[-1] - tensor = tensor.view(-1, hidden_dim).abs().detach() - comming_max = torch.max(tensor, dim=0)[0].float() - - if act_scales[name][key] is None: - act_scales[name][key] = comming_max - else: - act_scales[name][key] = torch.max(act_scales[name][key], - comming_max) - - def stat_input_hook(m, x, y, name): - if isinstance(x, tuple): - x = x[0] - stat_tensor(name, x, act_scales, "x") - stat_tensor(name, y, act_scales, "y") - - if act_scales[name]["w"] is None: - act_scales[name]["w"] = m.weight.abs().clip(1e-8, - None).max(dim=0)[0] - - hooks = [] - for name, m in model.named_modules(): - if isinstance(m, nn.Linear) or isinstance(m, Conv1D): - hooks.append( - m.register_forward_hook( - functools.partial(stat_input_hook, name=name))) - - for i in tqdm(range(num_samples), desc="calibrating model"): - input_ids = tokenizer(dataset[i]["text"], - return_tensors="pt", - max_length=seq_len, - truncation=True).input_ids.to(device) - model(input_ids) - - for h in hooks: - h.remove() - - return act_scales diff --git a/examples/gpt/utils/convert.py b/examples/gpt/utils/convert.py deleted file mode 100644 index b80c8b5dc..000000000 --- a/examples/gpt/utils/convert.py +++ /dev/null @@ -1,360 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - Utilities for exporting a model to our custom format. -""" - -import numpy as np -import torch - -from tensorrt_llm._utils import torch_to_numpy - - -def cpu_map_location(storage, loc): - return storage.cpu() - - -def gpu_map_location(storage, loc): - if loc.startswith("cuda"): - training_gpu_idx = int(loc.split(":")[1]) - inference_gpu_idx = training_gpu_idx % torch.cuda.device_count() - return storage.cuda(inference_gpu_idx) - elif loc.startswith("cpu"): - return storage.cpu() - else: - raise ValueError(f"Not handled {loc}") - - -def save_val(val, dir, key, tp_num=None): - suffix = "bin" if tp_num is None else f"{tp_num}.bin" - val.tofile(dir / f"model.{key}.{suffix}") - - -def save_split(split_vals, dir, key, i, split_factor): - for j, val in enumerate(split_vals): - save_val(val, dir, key, i * split_factor + j) - - -def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): - """ - This function has two purposes: - - compute quantized weights, scaled either per-tensor or per-column - - compute scaling factors - - Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ. - CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W. - CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor. - - Here is the list of what we need (T means per-tensor, C per-column): - - scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. 
(T) - - scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T) - - scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C) - - scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32) - to quant range (int8) (used for CUBLAS) (T, C) - - Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too, - but then the model would change depending on the number of GPUs used. - - For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it - as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V. - """ - - # compute weight scaling factors for fp->int8 and int8->fp - if is_qkv and not multi_query_mode: - scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max( - dim=-1, keepdims=True)[0].cpu().numpy() - scale_w_orig_quant_c = 127. / act_range["w"].reshape(3, - -1).cpu().numpy() - elif is_qkv and multi_query_mode: - raise ValueError( - f"Multi-query w/ int8 quant has not been supported yet") - else: - scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy() - scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() - scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t - scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c - - # compute the rest of needed scaling factors - scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item()) - scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item()) - scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.) - scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t * - scale_w_orig_quant_t) - scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t * - scale_w_orig_quant_c) - if is_qkv: - scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t, - scale_w_orig_quant_c.shape) - scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t, - scale_w_orig_quant_c.shape) - - to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8) - return { - "weight.int8": to_i8(weights * scale_w_orig_quant_t), - "weight.int8.col": to_i8(weights * scale_w_orig_quant_c), - "scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32), - "scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32), - "scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32), - "scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32), - "scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32), - "scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32), - } - - -def write_int8(vals, - dir, - base_key, - split_dim, - tp_rank, - split_factor, - kv_cache_only=False): - if not kv_cache_only: - save_split(np.split(vals["weight.int8"], split_factor, axis=split_dim), - dir, f"{base_key}.weight.int8", tp_rank, split_factor) - save_split( - np.split(vals["weight.int8.col"], split_factor, axis=split_dim), - dir, f"{base_key}.weight.int8.col", tp_rank, split_factor) - - saved_keys_once = ["scale_y_quant_orig"] - if not kv_cache_only: - saved_keys_once += [ - "scale_x_orig_quant", "scale_w_quant_orig", "scale_y_accum_quant" - ] - # per-column scaling factors are loaded per-gpu for ColumnParallel GEMMs (QKV, FC1) - if not kv_cache_only: - if split_dim == -1: - save_split( - np.split(vals["scale_w_quant_orig.col"], - split_factor, - axis=split_dim), dir, - f"{base_key}.scale_w_quant_orig.col", tp_rank, split_factor) - save_split( - 
np.split(vals["scale_y_accum_quant.col"], - split_factor, - axis=split_dim), dir, - f"{base_key}.scale_y_accum_quant.col", tp_rank, split_factor) - else: - saved_keys_once += [ - "scale_w_quant_orig.col", "scale_y_accum_quant.col" - ] - - if tp_rank == 0: - for save_key in saved_keys_once: - save_val(vals[save_key], dir, f"{base_key}.{save_key}") - - -# Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head -# are not split as there is only one head per key/value. -@torch.no_grad() -def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, - storage_type, act_range, config): - use_attention_nemo_shape = config.get("use_attention_nemo_shape", False) - split_gated_activation = config.get("split_gated_activation", False) - multi_query_mode = config.get("multi_query_mode", False) - num_attention_heads = config.get("num_attention_heads", 0) - num_key_value_heads = config.get("num_key_value_heads", num_attention_heads) - tp_size = config.get("tp_size", 1) - int8_outputs = config.get("int8_outputs", None) - local_dim = config.get("local_dim", None) - - save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" - - if not isinstance(vals, list): - vals = [vals] - - if config.get("transpose_weights", False) and vals[0].ndim == 2: - vals = [val.T for val in vals] - if "layernorm.weight" in key and config.get("apply_layernorm_1p", False): - vals = [val + 1.0 for val in vals] - vals = [torch_to_numpy(val.cpu().to(storage_type)) for val in vals] - - if "input_layernorm.weight" in key or "input_layernorm.bias" in key or \ - "attention.dense.bias" in key or "post_attention_layernorm.weight" in key or \ - "post_attention_layernorm.bias" in key or "mlp.dense_4h_to_h.bias" in key or \ - "final_layernorm.weight" in key or "final_layernorm.bias" in key: - - # shared weights, only need to convert the weights of rank 0 - if tp_rank == 0: - save_val(vals[0], saved_dir, key) - - elif "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key: - cat_dim = 0 - val = np.concatenate(vals, axis=cat_dim) - split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) - if act_range is not None and int8_outputs == "all": - base_key = key.replace(".weight", "") - vals_i8 = generate_int8(val, - act_range, - multi_query_mode=multi_query_mode) - write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, - split_factor) - - elif "mlp.dense_h_to_4h.weight" in key or "mlp.dense_h_to_4h.bias" in key: - if split_gated_activation: - splits = [np.split(val, 2, axis=-1) for val in vals] - vals, gates = list(zip(*splits)) - cat_dim = -1 - val = np.concatenate(vals, axis=cat_dim) - split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) - if act_range is not None and int8_outputs == "all": - base_key = key.replace(".weight", "") - vals_i8 = generate_int8(val, - act_range, - multi_query_mode=multi_query_mode) - write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, - split_factor) - - if split_gated_activation: - assert not save_int8 - prefix, dot, suffix = key.rpartition(".") - key = prefix + ".gate" + dot + suffix - - gate = np.concatenate(gates, axis=cat_dim) - split_vals = np.split(gate, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) - - elif "attention.query_key_value.bias" in key: - if local_dim is None: - local_dim = vals[0].shape[-1] // 3 - - if multi_query_mode: - val = vals[0] - # 
out_feature = local_dim + 2 * head_size; assumes local_dim equals to hidden_dim - b_q, b_kv = np.split(val, [local_dim], axis=-1) - b_q_split = np.split(b_q, split_factor, axis=-1) - split_vals = [np.concatenate((i, b_kv), axis=-1) for i in b_q_split] - elif num_attention_heads != num_key_value_heads: - # GQA mode - # split_vals = np.split(vals[0], split_factor, axis=-1) - assert num_key_value_heads % split_factor == 0 - val = vals[0] - qkv_hidden_dim = val.shape[0] - size_per_head = qkv_hidden_dim // (num_attention_heads + - 2 * num_key_value_heads) - num_attention_heads // num_key_value_heads - - val = val.reshape(num_attention_heads + 2 * num_key_value_heads, - size_per_head) - - # Split the QKV to separate variables. - qkv = np.split(val, [ - num_attention_heads, num_attention_heads + num_key_value_heads - ], - axis=0) - - q_split = np.split(qkv[0], split_factor, axis=0) - k_split = np.split(qkv[1], split_factor, axis=0) - v_split = np.split(qkv[2], split_factor, axis=0) - - # Concatenate Q, K, and V together - split_vals = [ - np.concatenate([ - q_split[i].reshape(-1), k_split[i].reshape(-1), - v_split[i].reshape(-1) - ], - axis=0) for i in range(split_factor) - ] - else: - if use_attention_nemo_shape: - head_num = num_attention_heads // tp_size - size_per_head = local_dim // head_num - nemo_shape = (head_num, 3, size_per_head) - vals = [val.reshape(nemo_shape) for val in vals] - vals = [val.transpose(1, 0, 2) for val in vals] - - vals = [val.reshape(3, local_dim) for val in vals] - val = np.concatenate(vals, axis=-1) - split_vals = np.split(val, split_factor, axis=-1) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) - - elif "attention.query_key_value.weight" in key: - hidden_dim = vals[0].shape[0] - if local_dim is None: - local_dim = vals[0].shape[-1] // 3 - if multi_query_mode: - val = vals[0] - # out_feature = local_dim + 2 * head_size; assumes local_dim equals to hidden_dim - head_size = (val.shape[-1] - local_dim) // 2 - val = val.reshape(hidden_dim, local_dim + 2 * head_size) - w_q, w_kv = np.split(val, [local_dim], axis=-1) - w_q_split = np.split(w_q, split_factor, axis=-1) - split_vals = [np.concatenate((i, w_kv), axis=-1) for i in w_q_split] - elif num_attention_heads != num_key_value_heads: - # GQA mode. - assert num_key_value_heads % split_factor == 0 - val = vals[0] - size_per_head = hidden_dim // num_attention_heads - num_attention_heads // num_key_value_heads - - val = val.reshape(hidden_dim, - num_attention_heads + 2 * num_key_value_heads, - size_per_head) - - # Split the QKV to separate variables. 
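# A minimal, self-contained sketch of the head-wise GQA split performed below, using
# made-up sizes for illustration only (hidden_dim=8, num_attention_heads=4,
# num_key_value_heads=2, split_factor=2, hence size_per_head=2 and 4 + 2*2 = 8 fused
# head slots); the names n_q, n_kv, tp, fused and per_rank are illustrative, not from
# the converter itself:
import numpy as np

hidden_dim, n_q, n_kv, tp = 8, 4, 2, 2                 # assumed toy dimensions
size_per_head = hidden_dim // n_q                      # 2 in this toy setup
# fused QKV weight viewed as (hidden_dim, query_heads + 2 * kv_heads, head_size)
fused = np.zeros((hidden_dim, n_q + 2 * n_kv, size_per_head), dtype=np.float32)
# separate the query, key and value head groups along the head axis
q, k, v = np.split(fused, [n_q, n_q + n_kv], axis=1)
# each tensor-parallel rank keeps n_q/tp query heads and n_kv/tp key/value heads,
# re-concatenated into one contiguous per-rank QKV slice
per_rank = [
    np.concatenate(
        [np.split(q, tp, axis=1)[r].reshape(hidden_dim, -1),
         np.split(k, tp, axis=1)[r].reshape(hidden_dim, -1),
         np.split(v, tp, axis=1)[r].reshape(hidden_dim, -1)],
        axis=1) for r in range(tp)
]
# sanity check: (n_q/tp + 2 * n_kv/tp) heads of size_per_head columns per rank
assert per_rank[0].shape == (hidden_dim,
                             (n_q // tp + 2 * (n_kv // tp)) * size_per_head)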
- qkv = np.split(val, [ - num_attention_heads, num_attention_heads + num_key_value_heads - ], - axis=1) - - q_split = np.split(qkv[0], split_factor, axis=1) - k_split = np.split(qkv[1], split_factor, axis=1) - v_split = np.split(qkv[2], split_factor, axis=1) - - # Concatenate Q, K, and V together - split_vals = [ - np.concatenate([ - q_split[i].reshape(hidden_dim, -1), k_split[i].reshape( - hidden_dim, -1), v_split[i].reshape(hidden_dim, -1) - ], - axis=1) for i in range(split_factor) - ] - else: - if use_attention_nemo_shape: - head_num = num_attention_heads // tp_size - size_per_head = hidden_dim // num_attention_heads - vals = [ - val.reshape(hidden_dim, head_num, 3, size_per_head) - for val in vals - ] - vals = [val.transpose(0, 2, 1, 3) for val in vals] - - vals = [val.reshape(hidden_dim, 3, local_dim) for val in vals] - cat_dim = -1 - val = np.concatenate(vals, axis=cat_dim) - split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) - if save_int8: - base_key = key.replace(".weight", "") - vals_i8 = generate_int8(val, - act_range, - is_qkv=True, - multi_query_mode=multi_query_mode) - write_int8(vals_i8, - saved_dir, - base_key, - cat_dim, - tp_rank, - split_factor, - kv_cache_only=int8_outputs == "kv_cache_only") - elif ("attention.query.weight" in key or "attention.query.bias" in key - or "attention.key_value.weight" in key - or "attention.key_value.bias" in key or "attention.key.weight" in key - or "attention.key.bias" in key or "attention.value.weight" in key - or "attention.value.bias" in key): - pass - else: - print(f"[WARNING] {key} not handled by converter") diff --git a/examples/gpt/utils/nemo.py b/examples/gpt/utils/nemo.py deleted file mode 100644 index cf2c27182..000000000 --- a/examples/gpt/utils/nemo.py +++ /dev/null @@ -1,456 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import configparser -import functools -import logging -import os -import shutil -import tarfile -from collections import namedtuple -from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union - -import torch -import yaml -from transformers import GPT2Config, GPT2Tokenizer, T5Tokenizer -from utils.convert import cpu_map_location, gpu_map_location - -LOGGER = logging.getLogger(__name__) - -# The field names are the same as in .nemo config file -# Defaults and their locations in NeMo code are given for commit 9c7926db4ae375b77dae7eb57656213de1dd76a5 in main branch -# The commit from main is used instead of a release because there are `rotary_base` commit was introduced recently. 
-NemoRotaryEmbeddingParameters = namedtuple( - "NemoRotaryEmbeddingParameters", - [ - "position_embedding_type", "rotary_percentage", - "seq_len_interpolation_factor", "rotary_base" - ], - defaults=[ - # "position_embedding_type", the default is taken from - # https://github.com/NVIDIA/NeMo/blob/9c7926db4ae375b77dae7eb57656213de1dd76a5/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L370 - "learned_absolute", - # "rotary_percentage", the default is taken from - # https://github.com/NVIDIA/NeMo/blob/9c7926db4ae375b77dae7eb57656213de1dd76a5/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L370 - 1.0, - # "seq_len_interpolation_factor", the default is take from - # https://github.com/NVIDIA/NeMo/blob/9c7926db4ae375b77dae7eb57656213de1dd76a5/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L388 - None, - # "rotary_base", the default is taken from - # https://github.com/NVIDIA/NeMo/blob/9c7926db4ae375b77dae7eb57656213de1dd76a5/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L389 - 10000, - ]) - - -def set_parameter_from_config(params: Dict[str, Any], nemo_config: Dict[str, - Any], - param_name: str) -> None: - if param_name in nemo_config: - params[param_name] = nemo_config[param_name] - else: - LOGGER.debug( - f"A parameter '{param_name}' is missing in nemo checkpoint. " - f"The default value {repr(NemoRotaryEmbeddingParameters._field_defaults[param_name])} will be used." - ) - - -def extract_rotary_parameters_from_nemo_config( - nemo_config: Dict[str, Any]) -> NemoRotaryEmbeddingParameters: - params = {} - set_parameter_from_config(params, nemo_config, "position_embedding_type") - set_parameter_from_config(params, nemo_config, "rotary_percentage") - set_parameter_from_config(params, nemo_config, - "seq_len_interpolation_factor") - set_parameter_from_config(params, nemo_config, "rotary_base") - return NemoRotaryEmbeddingParameters(**params) - - -def nemo_to_gpt_config(nemo_model_config, vocab_size, eos_id, bos_id): - convertion_dict = { - "activation_function": "activation", - "layer_norm_epsilon": "layernorm_epsilon", - "n_embd": "hidden_size", - "n_head": "num_attention_heads", - "n_layer": "num_layers", - "n_positions": "max_position_embeddings", - "rotary_pct": "rotary_percentage", - "bias": "bias", - "intermediate_size": "ffn_hidden_size", - } - - kwargs = { - key: nemo_model_config[value] - for key, value in convertion_dict.items() if value in nemo_model_config - } - kwargs["vocab_size"] = vocab_size - kwargs["eos_token_id"] = eos_id - kwargs["bos_token_id"] = bos_id - - return GPT2Config(**kwargs) - - -def copy_tokenizer_files(config, out_dir): - basenames = { - "model": "tokenizer", - "vocab_file": "vocab", - "merge_file": "merges", - } - - for key in basenames.keys(): - if config[key] is None: - continue - path = Path(config[key]) - if not path.exists(): - LOGGER.debug(f"Tokenizer {key}: {path} file not found") - continue - - dst_path = out_dir / f"{basenames[key]}{path.suffix}" - LOGGER.debug(f"Copy tokenizer {key}: {path}->{dst_path}") - shutil.copy(path.as_posix(), dst_path.as_posix()) - - -def add_rotary_parameters_to_ini_config( - config: configparser.ConfigParser, - rotary_parameters: NemoRotaryEmbeddingParameters) -> None: - if rotary_parameters.position_embedding_type == "rope": - if rotary_parameters.rotary_percentage > 1.0 or rotary_parameters.rotary_percentage <= 0.0: - raise ValueError( - f"Rotary percentage has to suffice 0.0 < rotary_percentage <= 1.0, whereas " - 
f"rotary_percentage={rotary_parameters.rotary_percentage}") - config["gpt"]["rotary_pct"] = str(rotary_parameters.rotary_percentage) - config["gpt"]["rotary_base"] = str(rotary_parameters.rotary_base) - if rotary_parameters.seq_len_interpolation_factor is not None: - if rotary_parameters.seq_len_interpolation_factor <= 1.0: - raise ValueError( - f"Rotary scaling is supported only for seq_len_interpolation_factor > 1.0. " - f"Got seq_len_interpolation_factor={rotary_parameters.seq_len_interpolation_factor}" - ) - config["gpt"]["rotary_scaling_type"] = "linear" - config["gpt"]["rotary_scaling_factor"] = str( - float(rotary_parameters.seq_len_interpolation_factor)) - else: - # As in HF rotary_pct > 0.0 triggers RoPE. Dislabe RoPE if different embedding type is used - config["gpt"]["rotary_pct"] = "0.0" - - -def update_tokenizer_paths(tokenizer_config: Dict, - tokenizer_file_paths: Dict[str, Optional[str]]): - for key, new_path in tokenizer_file_paths.items(): - old_path = tokenizer_config[key] - if old_path is None: - continue - old_path = Path(old_path) - if new_path: - LOGGER.debug(f"Update tokenizer {key} {old_path} -> {new_path}") - tokenizer_config[key] = new_path.as_posix() - elif not old_path.exists(): - LOGGER.warning( - f"Tokenizer {key}'s path {old_path} does not exists: set it to None" - ) - tokenizer_config[key] = None - return tokenizer_config - - -def build_tokenizer(tokenizer_config: Dict): - if tokenizer_config["library"] == "sentencepiece": - tokenizer = T5Tokenizer(tokenizer_config["model"], extra_ids=0) - elif "GPT2" in tokenizer_config["type"]: - tokenizer = GPT2Tokenizer(tokenizer_config["vocab_file"], - tokenizer_config["merge_file"]) - else: - raise ValueError( - f'Tokenizer type {tokenizer_config["library"]} not handled') - - if tokenizer.bos_token_id is None: - tokenizer.add_special_tokens({"bos_token": ""}) - if tokenizer.eos_token_id is None: - tokenizer.add_special_tokens({"eos_token": ""}) - - return tokenizer - - -def get_eos_bos_ids_from_tokenizer_config( - tokenizer_config: Dict[str, Any]) -> Tuple[int, int]: - tokenizer = build_tokenizer(tokenizer_config) - return tokenizer.eos_token_id, tokenizer.bos_token_id - - -def nemo_config_to_ini_config( - nemo_model_config: Dict[str, Any], - eos_id: int, - bos_id: int, - vocab_size: int, - storage_type: str, -) -> configparser.ConfigParser: - gpt_model_config = nemo_to_gpt_config(nemo_model_config, vocab_size, eos_id, - bos_id) - config = configparser.ConfigParser() - config["gpt"] = {k: str(v) for k, v in vars(gpt_model_config).items()} - config["gpt"]["storage_dtype"] = storage_type - add_rotary_parameters_to_ini_config( - config, extract_rotary_parameters_from_nemo_config(nemo_model_config)) - return config - - -def add_special_tokens_to_tokenizer(tokenizer): - - # Need to add cls, sep, mask tokens to the tokenizer if they don't exist. - # If cls, sep and mask are not attributes of the tokenizer, add it. - if not hasattr(tokenizer, 'cls_token'): - tokenizer.add_special_tokens({'cls_token': ''}) - if not hasattr(tokenizer.tokenizer, 'sep_id'): - tokenizer.add_special_tokens({'sep_token': ''}) - if not hasattr(tokenizer.tokenizer, 'mask_id'): - tokenizer.add_special_tokens({'mask_token': ''}) - - # bos, eos, pad and unk may be present in the provided spm .model file, if they are, use it. 
- if not hasattr(tokenizer, 'pad_token'): - if hasattr(tokenizer.tokenizer, - 'pad_id') and tokenizer.tokenizer.pad_id() > 0: - tokenizer.pad_token = tokenizer.tokenizer.id_to_piece( - tokenizer.tokenizer.pad_id()) - else: - tokenizer.add_special_tokens({'pad_token': ''}) - else: - tokenizer.add_special_tokens({'pad_token': ''}) - - if not hasattr(tokenizer, 'bos_token'): - if hasattr(tokenizer.tokenizer, - 'bos_id') and tokenizer.tokenizer.bos_id() > 0: - tokenizer.bos_token = tokenizer.tokenizer.id_to_piece( - tokenizer.tokenizer.bos_id()) - else: - tokenizer.add_special_tokens({'bos_token': ''}) - else: - tokenizer.add_special_tokens({'bos_token': ''}) - - if not hasattr(tokenizer, 'eos_token'): - if hasattr(tokenizer.tokenizer, - 'eos_id') and tokenizer.tokenizer.eos_id() > 0: - tokenizer.eos_token = tokenizer.tokenizer.id_to_piece( - tokenizer.tokenizer.eos_id()) - else: - tokenizer.add_special_tokens({'eos_token': ''}) - else: - tokenizer.add_special_tokens({'eos_token': ''}) - - -def unpack_nemo_ckpt(nemo_archive_path: Union[str, Path], - out_dir_path: Union[str, Path]): - nemo_archive_path = Path(nemo_archive_path) - if not nemo_archive_path.exists(): - raise FileNotFoundError(f"{nemo_archive_path} does not exist") - - for tar_mode in ["r:", "r:gz"]: - try: - with tarfile.open(nemo_archive_path, mode=tar_mode) as tar_file: - - def is_within_directory(directory, target): - - abs_directory = os.path.abspath(directory) - abs_target = os.path.abspath(target) - - prefix = os.path.commonprefix([abs_directory, abs_target]) - - return prefix == abs_directory - - def safe_members(tar_file): - members = [] - for member in tar_file.getmembers(): - member_path = os.path.join(out_dir_path, member.name) - if not is_within_directory(out_dir_path, member_path): - raise Exception( - "Attempted Path Traversal in Tar File") - members.append(member) - return members - - tar_file.extractall(out_dir_path, - members=safe_members(tar_file), - numeric_owner=False) - - return out_dir_path - except tarfile.ReadError: - pass - - raise RuntimeError(f"Could not unpack {nemo_archive_path}") - - -def extract_layers_with_prefix(model_, prefix): - length_to_trim = len(prefix) - model_state = model_.get("state_dict", model_) - return { - key[length_to_trim:]: model_state[key] - for key in model_state.keys() if prefix in key - } - - -class UnpackedNemoCheckpointDir: - - def __init__(self, - checkpoints_dir: Union[str, Path], - load_checkpoints_to_cpu: bool = False): - self._checkpoints_dir = Path(checkpoints_dir) - self._load_checkpoints_to_cpu = load_checkpoints_to_cpu - - @property - @functools.lru_cache - def model_config(self): - model_config = None - - model_config_filename = "model_config.yaml" - model_configs_paths = list( - self._checkpoints_dir.rglob(model_config_filename)) - if model_configs_paths: - if len(model_configs_paths) > 1: - raise RuntimeError( - f"There are more than single {model_config_filename} " - f"in {self._checkpoints_dir}: {', '.join(map(lambda p: p.as_posix(), model_configs_paths))}" - ) - model_config_path = model_configs_paths[0] - LOGGER.debug("Loading model config from %s", model_config_path) - with model_config_path.open("r") as model_config_file: - model_config = yaml.load(model_config_file, - Loader=yaml.SafeLoader) - else: - LOGGER.debug("Searching model config in checkpoints") - # try to obtain from checkpoint - checkpoint_name = self.checkpoint_name - checkpoints_paths = sorted( - self._checkpoints_dir.rglob(checkpoint_name)) - if checkpoints_paths: - # assume that parallel 
ranks 0 checkpoint should have model config embedded - checkpoint_path = checkpoints_paths[0] - - map_location_fn = cpu_map_location if self._load_checkpoints_to_cpu else gpu_map_location - - model_00 = torch.load(checkpoint_path, - map_location=map_location_fn) - if "hyper_parameters" in model_00 and "cfg" in model_00[ - "hyper_parameters"]: - model_config = model_00["hyper_parameters"]["cfg"] - LOGGER.debug("Loaded model config from checkpoint %s", - checkpoint_path) - else: - LOGGER.debug("Could not find model config in checkpoint %s", - checkpoint_path) - - del model_00 - - if model_config is None: - LOGGER.warning( - "Could not find checkpoint with NeMo model config in %s", - self._checkpoints_dir) - - LOGGER.debug("Loaded model config %s", model_config) - - return model_config - - @property - def checkpoints_dir(self): - return self._checkpoints_dir - - def get_checkpoints_paths(self, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1): - """ - Injects tensor/pipeline model parallel ranks into the filepath. - Does nothing if not using model parallelism. - """ - - checkpoint_path_without_rank = self.checkpoints_dir / self.checkpoint_name - - def _inject_parallel_ranks(tp_rank, pp_rank): - if tensor_model_parallel_size > 1 or pipeline_model_parallel_size > 1: - if pipeline_model_parallel_size is None or pipeline_model_parallel_size == 1: - checkpoint_path = (checkpoint_path_without_rank.parent / - f"mp_rank_{tp_rank:02d}" / - checkpoint_path_without_rank.name) - else: - checkpoint_path = ( - checkpoint_path_without_rank.parent / - f"tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}" / - checkpoint_path_without_rank.name) - return checkpoint_path - else: - return checkpoint_path_without_rank - - return [[ - _inject_parallel_ranks(tp_rank=tp_rank, pp_rank=pp_rank) - for pp_rank in range(pipeline_model_parallel_size) - ] for tp_rank in range(tensor_model_parallel_size)] - - @property - @functools.lru_cache - def checkpoint_name(self): - patterns = [ - "model_weights.ckpt", # older megatron checkpoints - "*last.ckpt", # newer format of checkpoints - ] - for pattern in patterns: - model_files = sorted(list(self._checkpoints_dir.rglob(pattern))) - if model_files: - return model_files[0].name - - raise ValueError( - f"Could not find checkpoint files in {self._checkpoints_dir}") - - @functools.lru_cache - def get_tokenizer_file_path(self, tokenizer_key, file_key, - default_filename_pattern): - model_config = self.model_config - file_property = None - if tokenizer_key in model_config and file_key in model_config[ - tokenizer_key]: - file_property = model_config[tokenizer_key][file_key] - elif file_key in model_config: - file_property = model_config[file_key] - - LOGGER.debug("model_config[%s][%s]=%s", tokenizer_key, file_key, - file_property) - - if file_property and file_property.startswith("nemo:"): - filename = file_property.split("nemo:")[1] - filename_pattern = f"*{filename}" - elif file_property and file_property.startswith("/artifacts/"): - filename = Path(file_property).name - filename_pattern = f"*{filename}" - elif file_property is None or file_property == "None": - filename_pattern = None - else: - filename_pattern = default_filename_pattern - LOGGER.warning( - f"Tokenizer file from config: {tokenizer_key}.{file_key}={file_property} " - f"looks like unsupported path. Pattern {filename_pattern} will be used." 
- ) - - file_path = None - if filename_pattern is not None: - files_paths = list(self._checkpoints_dir.glob(filename_pattern)) - if files_paths: - assert len(files_paths) == 1 - file_path = files_paths[0] - - return file_path - - @functools.lru_cache - def get_all_tokenizer_file_paths(self): - return { - "model": - self.get_tokenizer_file_path("tokenizer", "model", "*.model"), - "vocab_file": - self.get_tokenizer_file_path("tokenizer", "vocab_file", "*vocab*"), - "merge_file": - self.get_tokenizer_file_path("tokenizer", "merge_file", - "*merge*.txt"), - } diff --git a/examples/gpt/weight.py b/examples/gpt/weight.py deleted file mode 100644 index f3594f1b5..000000000 --- a/examples/gpt/weight.py +++ /dev/null @@ -1,694 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import configparser -import logging -import time -from pathlib import Path - -import numpy as np -import torch - -import tensorrt_llm -from tensorrt_llm._utils import (numpy_to_torch, pad_vocab_size, - str_dtype_to_np, str_dtype_to_torch, - torch_to_numpy) -from tensorrt_llm.functional import is_gated_activation -from tensorrt_llm.models import GPTLMHeadModel -from tensorrt_llm.quantization import QuantMode - -LOGGER = logging.getLogger(__name__) - - -def gen_suffix(rank, use_smooth_quant, quant_per_channel): - suffix = f"{rank}.bin" - if use_smooth_quant: - sq_prefix = "int8." - if quant_per_channel: - sq_prefix += "col." - suffix = sq_prefix + suffix - return suffix - - -def extract_layer_idx(name): - ss = name.split('.') - for s in ss: - if s.isdigit(): - return s - return None - - -def split(v, tp_size, idx, dim=0): - if tp_size == 1: - return v - if len(v.shape) == 1: - return np.ascontiguousarray(np.split(v, tp_size)[idx]) - elif len(v.shape) == 2: - return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) - return None - - -def parse_sc2_config(ini_file): - gpt_config = configparser.ConfigParser() - gpt_config.read(ini_file) - - n_embd = gpt_config.getint('gpt', 'hidden_size') - n_head = gpt_config.getint('gpt', 'num_attention_heads') - n_kv_head = gpt_config.getint('gpt', 'num_key_value_heads') - n_layer = gpt_config.getint('gpt', 'num_hidden_layers') - n_positions = gpt_config.getint('gpt', 'max_position_embeddings') - vocab_size = gpt_config.getint('gpt', 'vocab_size') - do_layer_norm_before = gpt_config.getboolean('gpt', - 'do_layer_norm_before', - fallback=True) - rotary_base = gpt_config.getfloat('gpt', 'rope_theta', fallback=None) - rotary_scaling_type = gpt_config.get('gpt', - 'rotary_scaling_type', - fallback=None) - rotary_scaling_factor = gpt_config.get('gpt', - 'rotary_scaling_factor', - fallback=None) - if rotary_scaling_type is None: - if rotary_scaling_factor is not None: - raise ValueError( - f"'rotary_scaling_factor={rotary_scaling_factor}' is found in ini " - f"config file {ini_file}, whereas 'rotary_scaling_type' is missing " - f"in the config. 
The 'rotary_scaling_factor' will be ignored and " - f"rotary scaling will not be used.") - rotary_scaling = None - else: - if rotary_scaling_factor is None: - raise ValueError( - f"'rotary_scaling_factor={rotary_scaling_factor}' was not found " - f"in ini config file {ini_file}, whereas 'rotary_scaling_type' is " - f"provided and equals {repr(rotary_scaling_type)}.") - rotary_scaling = [rotary_scaling_type, rotary_scaling_factor] - rotary_pct = 1.0 - hidden_act = "gelu" - bias = gpt_config.getboolean('gpt', 'use_bias', fallback=True) - inter_size = gpt_config.getint('gpt', 'intermediate_size', fallback=None) - dtype = gpt_config.get('gpt', 'storage_dtype', fallback='float32') - - if inter_size is None: - inter_size = 4 * n_embd - - multi_query_mode = gpt_config.getboolean('gpt', - 'multi_query_mode', - fallback=False) - prompt_num_tasks = gpt_config.getint('gpt', 'prompt_num_tasks', fallback=0) - prompt_max_vocab_size = gpt_config.getint('gpt', - 'prompt_max_vocab_size', - fallback=0) - return { - "n_embd": n_embd, - "n_head": n_head, - "n_kv_head": n_kv_head, - "n_layer": n_layer, - "n_positions": n_positions, - "vocab_size": vocab_size, - "do_layer_norm_before": do_layer_norm_before, - "hidden_act": hidden_act, - "rotary_pct": rotary_pct, - "rotary_base": rotary_base, - "rotary_scaling": rotary_scaling, - "bias": bias, - "inter_size": inter_size, - "multi_query_mode": multi_query_mode, - "dtype": dtype, - "prompt_num_tasks": prompt_num_tasks, - "prompt_max_vocab_size": prompt_max_vocab_size - } - - -def parse_ft_config(ini_file): - gpt_config = configparser.ConfigParser() - gpt_config.read(ini_file) - - if gpt_config.get("gpt", "model", fallback=None) == "starcoder2": - return parse_sc2_config(ini_file) - - n_embd = gpt_config.getint('gpt', 'n_embd') - n_head = gpt_config.getint('gpt', 'n_head') - n_layer = gpt_config.getint('gpt', 'n_layer') - n_positions = gpt_config.getint('gpt', 'n_positions') - vocab_size = gpt_config.getint('gpt', 'vocab_size') - do_layer_norm_before = gpt_config.getboolean('gpt', - 'do_layer_norm_before', - fallback=True) - rotary_base = gpt_config.getfloat('gpt', 'rotary_base', fallback=None) - rotary_scaling_type = gpt_config.get('gpt', - 'rotary_scaling_type', - fallback=None) - rotary_scaling_factor = gpt_config.get('gpt', - 'rotary_scaling_factor', - fallback=None) - if rotary_scaling_type is None: - if rotary_scaling_factor is not None: - raise ValueError( - f"'rotary_scaling_factor={rotary_scaling_factor}' is found in ini " - f"config file {ini_file}, whereas 'rotary_scaling_type' is missing " - f"in the config. 
The 'rotary_scaling_factor' will be ignored and " - f"rotary scaling will not be used.") - rotary_scaling = None - else: - if rotary_scaling_factor is None: - raise ValueError( - f"'rotary_scaling_factor={rotary_scaling_factor}' was not found " - f"in ini config file {ini_file}, whereas 'rotary_scaling_type' is " - f"provided and equals {repr(rotary_scaling_type)}.") - rotary_scaling = [rotary_scaling_type, rotary_scaling_factor] - rotary_pct = gpt_config.getfloat('gpt', 'rotary_pct', fallback=None) - hidden_act = gpt_config.get('gpt', 'activation_function') - bias = gpt_config.getboolean('gpt', 'bias', fallback=True) - inter_size = gpt_config.getint('gpt', 'intermediate_size', fallback=None) - dtype = gpt_config.get('gpt', 'storage_dtype', fallback='float32') - - if inter_size is None: - inter_size = 4 * n_embd - - multi_query_mode = gpt_config.getboolean('gpt', - 'multi_query_mode', - fallback=False) - prompt_num_tasks = gpt_config.getint('gpt', 'prompt_num_tasks', fallback=0) - prompt_max_vocab_size = gpt_config.getint('gpt', - 'prompt_max_vocab_size', - fallback=0) - return { - "n_embd": n_embd, - "n_head": n_head, - "n_kv_head": 1 if multi_query_mode else n_head, - "n_layer": n_layer, - "n_positions": n_positions, - "vocab_size": vocab_size, - "do_layer_norm_before": do_layer_norm_before, - "hidden_act": hidden_act, - "rotary_pct": rotary_pct, - "rotary_base": rotary_base, - "rotary_scaling": rotary_scaling, - "bias": bias, - "inter_size": inter_size, - "multi_query_mode": multi_query_mode, - "dtype": dtype, - "prompt_num_tasks": prompt_num_tasks, - "prompt_max_vocab_size": prompt_max_vocab_size - } - - -def check_embedding_share(dir_path): - share_embedding_table = False - lm_file = dir_path + '/' + 'model.lm_head.weight.bin' - if not Path(lm_file).exists(): - share_embedding_table = True - return share_embedding_table - - -def load_from_ft(tensorrt_llm_gpt: GPTLMHeadModel, - dir_path, - rank=0, - tensor_parallel=1, - dtype='float32', - use_parallel_embedding=False, - sharding_dim=0, - share_embedding_table=False, - scaling_factors=None): - tensorrt_llm.logger.info('Loading weights from FT...') - tik = time.time() - - quant_mode = getattr(tensorrt_llm_gpt, 'quant_mode', QuantMode(0)) - if quant_mode.is_int8_weight_only(): - plugin_weight_only_quant_type = torch.int8 - elif quant_mode.is_int4_weight_only(): - plugin_weight_only_quant_type = torch.quint4x2 - _parsed_params = parse_ft_config(Path(dir_path) / 'config.ini') - n_embd = _parsed_params["n_embd"] - n_head = _parsed_params["n_head"] - n_kv_head = _parsed_params["n_kv_head"] - head_size = n_embd // n_head - n_layer = _parsed_params["n_layer"] - n_positions = _parsed_params["n_positions"] - vocab_size = _parsed_params["vocab_size"] - do_layer_norm_before = _parsed_params["do_layer_norm_before"] - hidden_act = _parsed_params["hidden_act"] - bias = _parsed_params["bias"] - inter_size = _parsed_params["inter_size"] - - np_dtype = str_dtype_to_np(dtype) - - def fromfile(dir_path, name, shape=None, dtype=None): - dtype = np_dtype if dtype is None else dtype - p = dir_path + '/' + name - if Path(p).exists(): - t = np.fromfile(p, dtype=dtype) - if shape is not None: - t = t.reshape(shape) - return t - return None - - def set_smoothquant_scale_factors(module, - pre_scale_weight, - dir_path, - basename, - shape, - per_tok_dyn, - per_channel, - is_qkv=False, - rank=None): - suffix = "bin" - if per_channel: - if rank is not None: - suffix = f"{rank}." + suffix - suffix = "col." 
+ suffix - - col_shape = shape if (per_channel or is_qkv) else [1, 1] - if per_tok_dyn: - if pre_scale_weight is not None: - pre_scale_weight.value = np.array([1.0], dtype=np.float32) - t = fromfile(dir_path, f"{basename}scale_w_quant_orig.{suffix}", - col_shape, np.float32) - module.per_channel_scale.value = t - else: - t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1], - np.float32) - pre_scale_weight.value = t - t = fromfile(dir_path, f"{basename}scale_y_accum_quant.{suffix}", - col_shape, np.float32) - module.per_channel_scale.value = t - t = fromfile(dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1], - np.float32) - module.act_scale.value = t - - # Determine the quantization mode. - quant_mode = getattr(tensorrt_llm_gpt, "quant_mode", QuantMode(0)) - # Do we use SmoothQuant? - use_smooth_quant = quant_mode.has_act_and_weight_quant() - # Do we use quantization per token? - quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling() - # Do we use quantization per channel? - quant_per_channel = quant_mode.has_per_channel_scaling() - - # Do we use INT4/INT8 weight-only? - use_weight_only = quant_mode.is_weight_only() - - # Int8 KV cache - use_int8_kv_cache = quant_mode.has_int8_kv_cache() - - #Enable FP8 Gemm - enable_fp8_qdq = quant_mode.has_fp8_qdq() - - # Debug - suffix = gen_suffix(rank, use_smooth_quant, quant_per_channel) - # The type of weights. - w_type = np_dtype if not use_smooth_quant else np.int8 - - pe = fromfile(dir_path, 'model.wpe.bin', [n_positions, n_embd]) - if pe is not None: - tensorrt_llm_gpt.position_embedding.weight.value = (pe) - - vocab_embedding_weight = fromfile(dir_path, 'model.wte.bin', - [vocab_size, n_embd]) - if not use_parallel_embedding: - tensorrt_llm_gpt.vocab_embedding.weight.value = vocab_embedding_weight - else: - if sharding_dim == 0: - if vocab_size % tensor_parallel != 0: - # padding - vocab_size_padded = pad_vocab_size( - tensorrt_llm_gpt.vocab_embedding.num_embeddings, - tensor_parallel) - pad_width = vocab_size_padded - vocab_size - vocab_embedding_weight = np.pad(vocab_embedding_weight, - ((0, pad_width), (0, 0)), - 'constant', - constant_values=0) - tensorrt_llm_gpt.vocab_embedding.weight.value = np.ascontiguousarray( - split(vocab_embedding_weight, - tensor_parallel, - rank, - dim=sharding_dim)) - - if do_layer_norm_before: - tensorrt_llm_gpt.ln_f.bias.value = (fromfile( - dir_path, 'model.final_layernorm.bias.bin')) - tensorrt_llm_gpt.ln_f.weight.value = (fromfile( - dir_path, 'model.final_layernorm.weight.bin')) - - # share input embedding - if not share_embedding_table: - lm_head_weight = fromfile(dir_path, 'model.lm_head.weight.bin', - [vocab_size, n_embd]) - if lm_head_weight is None: - lm_head_weight = fromfile(dir_path, 'model.wte.bin', - [vocab_size, n_embd]) - if vocab_size % tensor_parallel != 0: - # padding - vocab_size_padded = tensorrt_llm_gpt.lm_head.out_features * tensor_parallel - pad_width = vocab_size_padded - vocab_size - lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), - 'constant', - constant_values=0) - tensorrt_llm_gpt.lm_head.weight.value = np.ascontiguousarray( - split(lm_head_weight, tensor_parallel, rank)) - fake_fp8_sf_dt = np.float32 - for i in range(n_layer): - c_attn_out_dim = ((n_head // tensor_parallel) + - max(n_kv_head // tensor_parallel, 1) * 2) * head_size - gpt_layer = tensorrt_llm_gpt.layers[i] - gpt_layer.input_layernorm.weight.value = (fromfile( - dir_path, 'model.layers.' 
+ str(i) + '.input_layernorm.weight.bin')) - gpt_layer.input_layernorm.bias.value = (fromfile( - dir_path, 'model.layers.' + str(i) + '.input_layernorm.bias.bin')) - t = fromfile( - dir_path, 'model.layers.' + str(i) + - '.attention.query_key_value.weight.' + suffix, - [n_embd, c_attn_out_dim], w_type) - if t is not None: - dst = gpt_layer.attention.qkv.weight - if use_smooth_quant: - dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) - set_smoothquant_scale_factors( - gpt_layer.attention.qkv, - gpt_layer.input_layernorm.scale_to_int, - dir_path, - 'model.layers.' + str(i) + '.attention.query_key_value.', - [1, c_attn_out_dim], - quant_per_token_dyn, - quant_per_channel, - rank=rank, - is_qkv=True) - elif use_weight_only: - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - numpy_to_torch(t), plugin_weight_only_quant_type) - dst.value = torch_to_numpy(processed_torch_weights) - scales = tensorrt_llm_gpt.layers[ - i].attention.qkv.per_channel_scale - scales.value = torch_to_numpy(torch_weight_scales) - else: - dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) - if bias: - t = fromfile( - dir_path, 'model.layers.' + str(i) + - '.attention.query_key_value.bias.' + str(rank) + '.bin') - if t is not None: - dst = gpt_layer.attention.qkv.bias - dst.value = np.ascontiguousarray(t) - if enable_fp8_qdq: - tensorrt_llm_gpt.layers[ - i].attention.qkv.activation_scaling_factor.value = np.array( - [scaling_factors['qkv_act'][i]], dtype=fake_fp8_sf_dt) - tensorrt_llm_gpt.layers[ - i].attention.qkv.weights_scaling_factor.value = np.array( - [scaling_factors['qkv_weights'][i]], dtype=fake_fp8_sf_dt) - tensorrt_llm_gpt.layers[ - i].attention.kv_cache_scaling_factor.value = np.array( - [scaling_factors['qkv_output'][i]], dtype=np.float32) - - dst = gpt_layer.attention.dense.weight - t = fromfile( - dir_path, - 'model.layers.' + str(i) + '.attention.dense.weight.' + suffix, - [n_embd // tensor_parallel, n_embd], w_type) - if use_smooth_quant: - dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) - dense_scale = getattr(gpt_layer.attention, - "quantization_scaling_factor", None) - set_smoothquant_scale_factors( - gpt_layer.attention.dense, dense_scale, dir_path, - 'model.layers.' + str(i) + '.attention.dense.', [1, n_embd], - quant_per_token_dyn, quant_per_channel) - # change it to the real smoother if dense layer is applied smooth quant - gpt_layer.attention.dense.smoother.value = np.ones( - [1, n_embd // tensor_parallel], dtype=np.float32) - elif use_weight_only: - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - numpy_to_torch(t), plugin_weight_only_quant_type) - dst.value = torch_to_numpy(processed_torch_weights) - scales = tensorrt_llm_gpt.layers[ - i].attention.dense.per_channel_scale - scales.value = torch_to_numpy(torch_weight_scales) - else: - dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) - - if bias: - dst = gpt_layer.attention.dense.bias - dst.value = fromfile( - dir_path, - 'model.layers.' 
+ str(i) + '.attention.dense.bias.bin') - if enable_fp8_qdq: - tensorrt_llm_gpt.layers[ - i].attention.dense.activation_scaling_factor.value = np.array( - [scaling_factors['dense_act'][i]], dtype=fake_fp8_sf_dt) - tensorrt_llm_gpt.layers[ - i].attention.dense.weights_scaling_factor.value = np.array( - [scaling_factors['dense_weights'][i]], dtype=fake_fp8_sf_dt) - - dst = gpt_layer.post_layernorm.weight - dst.value = fromfile( - dir_path, - 'model.layers.' + str(i) + '.post_attention_layernorm.weight.bin') - - dst = gpt_layer.post_layernorm.bias - dst.value = fromfile( - dir_path, - 'model.layers.' + str(i) + '.post_attention_layernorm.bias.bin') - t = fromfile( - dir_path, - 'model.layers.' + str(i) + '.mlp.dense_h_to_4h.weight.' + suffix, - [n_embd, inter_size // tensor_parallel], w_type) - if use_smooth_quant: - tensorrt_llm_gpt.layers[ - i].mlp.fc.weight.value = np.ascontiguousarray( - np.transpose(t, [1, 0])) - set_smoothquant_scale_factors(gpt_layer.mlp.fc, - gpt_layer.post_layernorm.scale_to_int, - dir_path, - 'model.layers.' + str(i) + - '.mlp.dense_h_to_4h.', - [1, inter_size // tensor_parallel], - quant_per_token_dyn, - quant_per_channel, - rank=rank) - elif use_weight_only: - dst = gpt_layer.mlp.fc.weight - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - numpy_to_torch(t), plugin_weight_only_quant_type) - dst.value = torch_to_numpy(processed_torch_weights) - scales = gpt_layer.mlp.fc.per_channel_scale - scales.value = torch_to_numpy(torch_weight_scales) - else: - tensorrt_llm_gpt.layers[ - i].mlp.fc.weight.value = np.ascontiguousarray( - np.transpose(t, [1, 0])) - if bias: - gpt_layer.mlp.fc.bias.value = fromfile( - dir_path, 'model.layers.' + str(i) + - '.mlp.dense_h_to_4h.bias.' + str(rank) + '.bin') - if is_gated_activation(hidden_act): - t = fromfile( - dir_path, 'model.layers.' + str(i) + - '.mlp.dense_h_to_4h.gate.weight.' + str(rank) + '.bin', - [n_embd, inter_size // tensor_parallel]) - tensorrt_llm_gpt.layers[ - i].mlp.gate.weight.value = np.ascontiguousarray( - np.transpose(t, [1, 0])) - if enable_fp8_qdq: - tensorrt_llm_gpt.layers[ - i].mlp.fc.activation_scaling_factor.value = np.array( - [scaling_factors['fc_act'][i]], dtype=fake_fp8_sf_dt) - tensorrt_llm_gpt.layers[ - i].mlp.fc.weights_scaling_factor.value = np.array( - [scaling_factors['fc_weights'][i]], dtype=fake_fp8_sf_dt) - - t = fromfile( - dir_path, - 'model.layers.' + str(i) + '.mlp.dense_4h_to_h.weight.' + suffix, - [inter_size // tensor_parallel, n_embd], w_type) - if use_smooth_quant: - tensorrt_llm_gpt.layers[ - i].mlp.proj.weight.value = np.ascontiguousarray( - np.transpose(t, [1, 0])) - proj_scale = getattr(gpt_layer.mlp, "quantization_scaling_factor", - None) - set_smoothquant_scale_factors( - gpt_layer.mlp.proj, proj_scale, dir_path, - 'model.layers.' 
+ str(i) + '.mlp.dense_4h_to_h.', [1, n_embd], - quant_per_token_dyn, quant_per_channel) - # change it to the real smoother if proj layer is applied smooth quant - gpt_layer.mlp.proj.smoother.value = np.ones( - [1, inter_size // tensor_parallel], dtype=np.float32) - elif use_weight_only: - dst = gpt_layer.mlp.proj.weight - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - numpy_to_torch(t), plugin_weight_only_quant_type) - dst.value = torch_to_numpy(processed_torch_weights) - scales = gpt_layer.mlp.proj.per_channel_scale - scales.value = torch_to_numpy(torch_weight_scales) - else: - gpt_layer.mlp.proj.weight.value = (np.ascontiguousarray( - np.transpose(t, [1, 0]))) - if bias: - gpt_layer.mlp.proj.bias.value = fromfile( - dir_path, - 'model.layers.' + str(i) + '.mlp.dense_4h_to_h.bias.bin') - - if use_int8_kv_cache: - t = fromfile( - dir_path, 'model.layers.' + str(i) + - '.attention.query_key_value.scale_y_quant_orig.bin', [1], - np.float32) - gpt_layer.attention.kv_cache_scaling_factor.value = t - - if enable_fp8_qdq: - tensorrt_llm_gpt.layers[ - i].mlp.proj.activation_scaling_factor.value = np.array( - [scaling_factors['proj_act'][i]], dtype=fake_fp8_sf_dt) - tensorrt_llm_gpt.layers[ - i].mlp.proj.weights_scaling_factor.value = np.array( - [scaling_factors['proj_weights'][i]], dtype=fake_fp8_sf_dt) - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}') - - -def load_from_hf_gpt(tensorrt_llm_gpt: GPTLMHeadModel, - hf_gpt, - rank=0, - tensor_parallel=1, - dtype='float32', - multi_query_mode=False): - tensorrt_llm.logger.info('Loading weights from HF GPT...') - tik = time.time() - - valid_lm_head_weight = False - hidden_size = tensorrt_llm_gpt._hidden_size - head_size = tensorrt_llm_gpt._num_heads // hidden_size - for k, v in hf_gpt.state_dict().items(): - torch_dtype = str_dtype_to_torch(dtype) - v = torch_to_numpy(v.to(torch_dtype).detach().cpu()) - if 'wte.weight' in k: - tensorrt_llm_gpt.vocab_embedding.weight.value = v - elif 'wpe.weight' in k: - tensorrt_llm_gpt.position_embedding.weight.value = v - elif 'ln_f.weight' in k: - tensorrt_llm_gpt.ln_f.weight.value = v - elif 'ln_f.bias' in k: - tensorrt_llm_gpt.ln_f.bias.value = v - elif 'lm_head.weight' in k: - tensorrt_llm_gpt.lm_head.weight.value = np.ascontiguousarray( - split(v, tensor_parallel, rank)) - valid_lm_head_weight = True - else: - layer_idx = extract_layer_idx(k) - if layer_idx is None: - continue - idx = int(layer_idx) - if 'ln_1.weight' in k: - tensorrt_llm_gpt.layers[idx].input_layernorm.weight.value = v - elif 'ln_1.bias' in k: - tensorrt_llm_gpt.layers[idx].input_layernorm.bias.value = v - elif 'attn.c_attn.weight' in k: - if multi_query_mode: - # HF-StarCoder uses torch.nn.Linear - w_qkv = v.reshape(hidden_size + 2 * head_size, 3, - hidden_size) - w_q, w_kv = np.split(w_qkv, [hidden_size, 2 * head_size]) - w_q = split(w_q, tensor_parallel, rank) - dst = tensorrt_llm_gpt.layers[idx].attention.qkv.weight - dst.value = np.ascontiguousarray(np.concatenate(w_q, w_kv)) - else: - # HF-GPT uses Conv1D instead of Linear - v = v.transpose() - dst = tensorrt_llm_gpt.layers[idx].attention.qkv.weight - dst.value = np.ascontiguousarray( - split(v, tensor_parallel, rank)) - elif 'attn.c_attn.bias' in k: - if multi_query_mode: - v.reshape(hidden_size + 2 * head_size, 3) - bias_q, bias_kv = np.split(w_qkv, - [hidden_size, 2 * head_size]) - bias_q = split(bias_q, tensor_parallel, rank) 
- dst = tensorrt_llm_gpt.layers[idx].attention.qkv.bias - dst.value = np.ascontiguousarray( - np.concatenate(bias_q, bias_kv)) - else: - dst = tensorrt_llm_gpt.layers[idx].attention.qkv.bias - dst.value = np.ascontiguousarray( - split(v, tensor_parallel, rank)) - elif 'attn.q_attn.weight' in k: - # Get the corresponding kv_atten.weight: - # ex: transformer.h.23.attn.kv_attn.weight - u = hf_gpt.state_dict()[k.replace('q_attn', 'kv_attn')] - u = u.to(torch_dtype).cpu().numpy(force=True) - # HF-SantaCoder uses transformer.Conv1D so we transpose to match shape - # In addition, kv_head must be broadcasted to all ranks so split is not applied - v = split(v.transpose(), tensor_parallel, rank) # W_q - u = u.transpose() # W_kv - dst = tensorrt_llm_gpt.layers[idx].attention.qkv.weight - dst.value = np.ascontiguousarray(np.concatenate((v, u))) - elif 'attn.q_attn.bias' in k: - # Get the corresponding kv_atten.bias: - # ex: transformer.h.23.attn.kv_attn.bias - u = hf_gpt.state_dict()[k.replace('q_attn', 'kv_attn')] - u = u.to(torch_dtype).cpu().numpy(force=True) - v = split(v, tensor_parallel, rank) - dst = tensorrt_llm_gpt.layers[idx].attention.qkv.bias - dst.value = np.ascontiguousarray(np.concatenate((v, u))) - elif 'attn.c_proj.weight' in k: - v = v.transpose() - dst = tensorrt_llm_gpt.layers[idx].attention.dense.weight - dst.value = np.ascontiguousarray( - split(v, tensor_parallel, rank, dim=1)) - elif 'attn.c_proj.bias' in k: - dst = tensorrt_llm_gpt.layers[idx].attention.dense.bias - dst.value = v - elif 'ln_2.weight' in k: - dst = tensorrt_llm_gpt.layers[idx].post_layernorm.weight - dst.value = v - elif 'ln_2.bias' in k: - dst = tensorrt_llm_gpt.layers[idx].post_layernorm.bias - dst.value = v - elif 'mlp.c_fc.weight' in k: - v = v.transpose() - tensorrt_llm_gpt.layers[ - idx].mlp.fc.weight.value = np.ascontiguousarray( - split(v, tensor_parallel, rank)) - elif 'mlp.c_fc.bias' in k: - tensorrt_llm_gpt.layers[ - idx].mlp.fc.bias.value = np.ascontiguousarray( - split(v, tensor_parallel, rank)) - elif 'mlp.c_proj.weight' in k: - v = v.transpose() - tensorrt_llm_gpt.layers[ - idx].mlp.proj.weight.value = np.ascontiguousarray( - split(v, tensor_parallel, rank, dim=1)) - elif 'mlp.c_proj.bias' in k: - tensorrt_llm_gpt.layers[idx].mlp.proj.bias.value = v - - if not valid_lm_head_weight: - # Use wte as lm_head weight to match the load_from_ft implementation. - lm_head_weight = tensorrt_llm_gpt.vocab_embedding.weight.raw_value - vocab_size = hf_gpt.config.vocab_size - if vocab_size % tensor_parallel != 0: - # padding - vocab_size_padded = tensorrt_llm_gpt.lm_head.out_features * tensor_parallel - pad_width = vocab_size_padded - vocab_size - lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), - 'constant', - constant_values=0) - tensorrt_llm_gpt.lm_head.weight.value = np.ascontiguousarray( - split(lm_head_weight, tensor_parallel, rank)) - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - tensorrt_llm.logger.info(f'Weights loaded. 
Total time: {t}') diff --git a/examples/gptneox/requirements.txt b/examples/gptneox/requirements.txt index 648a50ba9..4e510a309 100644 --- a/examples/gptneox/requirements.txt +++ b/examples/gptneox/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets~=2.14.5 rouge_score~=0.1.2 evaluate~=0.4.1 diff --git a/examples/hf_lora_convert.py b/examples/hf_lora_convert.py index efd566bc7..b1e938f6d 100644 --- a/examples/hf_lora_convert.py +++ b/examples/hf_lora_convert.py @@ -15,8 +15,6 @@ # limitations under the License. import argparse import datetime -#from utils.convert import cpu_map_location -#from utils.nemo import unpack_nemo_ckpt import json import logging import re @@ -27,7 +25,7 @@ import torch from tensorrt_llm._utils import str_dtype_to_torch -from tensorrt_llm.lora_manager import LoraConfig +from tensorrt_llm.lora_manager import LoraManager log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" logging.basicConfig(format=log_format) @@ -70,7 +68,7 @@ def get_all_lora_weights(lora_weights): "up_proj": "mlp_gate" } # lora modules on llama hf_modules_to_module_id = { - k: LoraConfig.LORA_MODULE_IDS[v] + k: LoraManager.LORA_MODULE_IDS[v] for k, v in hf_modules_to_trtllm_modules.items() } diff --git a/examples/high-level-api/README.md b/examples/high-level-api/README.md index b4a1c227b..5ae9ffafe 100644 --- a/examples/high-level-api/README.md +++ b/examples/high-level-api/README.md @@ -4,10 +4,10 @@ Here we show you a preview of how it works and how to use it. Note that the APIs are not stable and only support the LLaMA model. We appreciate your patience and understanding as we improve this API. -You can refer to [llm_examples.py](llm_examples.py) for all of the examples, and run it with the [run_examples.sh](./run_examples.sh) script, the command is as follows: +You can refer to [llm_examples.py](llm_examples.py) for all of the examples, and run it with the [run_examples.py](./run_examples.py) script; the command is as follows: ```sh -./run_examples.sh +python3 ./run_examples.py ``` For 7B and 13B models that fit on a single GPU, it should run all the examples automatically and print the results.
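As a quick orientation for the new driver: `run_examples.py` (added in the hunks below) reads its arguments positionally, namely the Hugging Face LLaMA model directory, an optional engine dump directory (defaulting to `./tllm.engine.example`), and an optional examples root. A minimal invocation sketch, assuming a locally downloaded LLaMA checkpoint at a placeholder path:

```sh
# Placeholder model path; the engine directory defaults to ./tllm.engine.example if omitted.
python3 ./run_examples.py ./llama-7b-hf ./tllm.engine.example
```

Note that some of the scripted tasks pass `--tp_size=2`, so running the full set is expected to require at least two GPUs.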
diff --git a/examples/high-level-api/llm_examples.py b/examples/high-level-api/llm_examples.py index 0e7da7b49..7f4d932b5 100644 --- a/examples/high-level-api/llm_examples.py +++ b/examples/high-level-api/llm_examples.py @@ -7,7 +7,7 @@ import torch -from tensorrt_llm import LLM, ModelConfig +from tensorrt_llm import LLM, ModelConfig, logger from tensorrt_llm.hlapi.llm import KvCacheConfig, SamplingConfig from tensorrt_llm.hlapi.utils import get_device_count @@ -229,6 +229,7 @@ def _parse_arguments(): parser.add_argument('--world_size', type=int, default=1) parser.add_argument('--tp_size', type=int, default=1) parser.add_argument('--streaming', action='store_true') + parser.add_argument('--log_level', type=str, default='info') return parser.parse_args() @@ -243,7 +244,7 @@ def _get_functions(): if __name__ == '__main__': args = _parse_arguments() - + logger.set_level(args.log_level) tasks = dict( run_llm_from_huggingface_model=lambda: run_llm_from_huggingface_model( [args.prompt], diff --git a/examples/high-level-api/requirements.txt b/examples/high-level-api/requirements.txt new file mode 100644 index 000000000..171260862 --- /dev/null +++ b/examples/high-level-api/requirements.txt @@ -0,0 +1,2 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031900 diff --git a/examples/high-level-api/run_examples.py b/examples/high-level-api/run_examples.py new file mode 100644 index 000000000..c308e8b0d --- /dev/null +++ b/examples/high-level-api/run_examples.py @@ -0,0 +1,44 @@ +import os +import subprocess +import sys + +PROMPT = "Tell a story" +LLAMA_MODEL_DIR = sys.argv[1] +TMP_ENGINE_DIR = sys.argv[2] if len(sys.argv) > 2 else "./tllm.engine.example" +EXAMPLES_ROOT = sys.argv[3] if len(sys.argv) > 3 else "" +LLM_EXAMPLES = os.path.join(EXAMPLES_ROOT, 'llm_examples.py') + +run_cmd = [ + sys.executable, LLM_EXAMPLES, "--task=run_llm_from_huggingface_model", + f"--prompt={PROMPT}", f"--hf_model_dir={LLAMA_MODEL_DIR}", + f"--dump_engine_dir={TMP_ENGINE_DIR}" +] +subprocess.run(run_cmd, check=True) + +# TP enabled +run_cmd = [ + sys.executable, LLM_EXAMPLES, "--task=run_llm_from_huggingface_model", + f"--prompt={PROMPT}", f"--hf_model_dir={LLAMA_MODEL_DIR}", "--tp_size=2" +] +subprocess.run(run_cmd, check=True) + +run_cmd = [ + sys.executable, LLM_EXAMPLES, "--task=run_llm_from_tllm_engine", + f"--prompt={PROMPT}", f"--hf_model_dir={LLAMA_MODEL_DIR}", + f"--dump_engine_dir={TMP_ENGINE_DIR}" +] +subprocess.run(run_cmd, check=True) + +run_cmd = [ + sys.executable, LLM_EXAMPLES, "--task=run_llm_generate_async_example", + f"--prompt={PROMPT}", f"--hf_model_dir={LLAMA_MODEL_DIR}" +] +subprocess.run(run_cmd, check=True) + +# Both TP and streaming enabled +run_cmd = [ + sys.executable, LLM_EXAMPLES, "--task=run_llm_generate_async_example", + f"--prompt={PROMPT}", f"--hf_model_dir={LLAMA_MODEL_DIR}", "--streaming", + "--tp_size=2" +] +subprocess.run(run_cmd, check=True) diff --git a/examples/high-level-api/run_examples.sh b/examples/high-level-api/run_examples.sh deleted file mode 100755 index f7ef57a10..000000000 --- a/examples/high-level-api/run_examples.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -set -ex - -PROMPT="Tell a story" -LLAMA_MODEL_DIR=$1 -default_engine_dir="./tllm.engine.example" -TMP_ENGINE_DIR="${2:-$default_engine_dir}" - -python3 llm_examples.py --task run_llm_from_huggingface_model \ - --prompt="$PROMPT" \ - --hf_model_dir=$LLAMA_MODEL_DIR \ - --dump_engine_dir=$TMP_ENGINE_DIR - -# TP enabled -python3 llm_examples.py --task run_llm_from_huggingface_model \ - 
--prompt="$PROMPT" \ - --hf_model_dir=$LLAMA_MODEL_DIR \ - --tp_size=2 - -python3 llm_examples.py --task run_llm_from_tllm_engine \ - --prompt="$PROMPT" \ - --hf_model_dir=$LLAMA_MODEL_DIR \ - --dump_engine_dir=$TMP_ENGINE_DIR - -python3 llm_examples.py --task run_llm_generate_async_example \ - --prompt="$PROMPT" \ - --hf_model_dir=$LLAMA_MODEL_DIR - -# Both TP and streaming enabled -python3 llm_examples.py --task run_llm_generate_async_example \ - --prompt="$PROMPT" \ - --hf_model_dir=$LLAMA_MODEL_DIR \ - --streaming \ - --tp_size=2 - -python3 llm_examples.py --task run_llm_with_async_future \ - --prompt="$PROMPT" \ - --hf_model_dir=$LLAMA_MODEL_DIR diff --git a/examples/high-level-api/run_quant_examples.py b/examples/high-level-api/run_quant_examples.py new file mode 100644 index 000000000..0a75adc36 --- /dev/null +++ b/examples/high-level-api/run_quant_examples.py @@ -0,0 +1,22 @@ +import os +import subprocess +import sys + +PROMPT = "Tell a story" +LLAMA_MODEL_DIR = sys.argv[1] +EXAMPLES_ROOT = sys.argv[2] if len(sys.argv) > 2 else "" +LLM_EXAMPLES = os.path.join(EXAMPLES_ROOT, 'llm_examples.py') + +run_cmd = [ + sys.executable, LLM_EXAMPLES, "--task=run_llm_with_quantization", + f"--prompt={PROMPT}", f"--hf_model_dir={LLAMA_MODEL_DIR}", + "--quant_type=int4_awq" +] +subprocess.run(run_cmd, check=True) + +run_cmd = [ + sys.executable, LLM_EXAMPLES, "--task=run_llm_with_quantization", + f"--prompt={PROMPT}", f"--hf_model_dir={LLAMA_MODEL_DIR}", + "--quant_type=fp8" +] +subprocess.run(run_cmd, check=True) diff --git a/examples/high-level-api/run_quant_examples.sh b/examples/high-level-api/run_quant_examples.sh deleted file mode 100755 index fdd3a7762..000000000 --- a/examples/high-level-api/run_quant_examples.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash -set -ex - -PROMPT="Tell a story" -LLAMA_MODEL_DIR=$1 - - -python3 llm_examples.py --task run_llm_with_quantization \ - --prompt="$PROMPT" \ - --hf_model_dir=$LLAMA_MODEL_DIR \ - --quant_type="int4_awq" - -python3 llm_examples.py --task run_llm_with_quantization \ - --prompt="$PROMPT" \ - --hf_model_dir=$LLAMA_MODEL_DIR \ - --quant_type="fp8" diff --git a/examples/internlm/requirements.txt b/examples/internlm/requirements.txt index fe334af0b..dcce2709e 100644 --- a/examples/internlm/requirements.txt +++ b/examples/internlm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets==2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/llama/.gitignore b/examples/llama/.gitignore index b43358a79..02915ec1a 100644 --- a/examples/llama/.gitignore +++ b/examples/llama/.gitignore @@ -1,3 +1,5 @@ llama* tokenizer.model *output* +*.safetensors +*.json diff --git a/examples/llama/README.md b/examples/llama/README.md index 13f44833b..e59bd59e8 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -82,7 +82,7 @@ python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \ trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16 \ --output_dir ./tmp/llama/7B/trt_engines/fp16/2-gpu/ \ --gemm_plugin float16 \ - --world_size 2 + --auto_parallel 2 # Build LLaMA 7B using 2-way tensor parallelism. python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \ @@ -573,22 +573,22 @@ git-lfs clone https://huggingface.co/meta-llama/Llama-2-13b-hf git-lfs clone https://huggingface.co/hfl/chinese-llama-2-lora-13b ``` -* Build engine, setting `--use_lora_plugin` and `--hf_lora_dir`. 
If lora has separate lm_head and embedding, they will replace lm_head and embedding of base model. +* Build engine, setting `--use_lora_plugin` and `--lora_dir`. If lora has separate lm_head and embedding, they will replace lm_head and embedding of base model. ```bash -python convert_checkpoint.py --model_dir /tmp/llama-v2-13b-hf \ - --output_dir ./tllm_checkpoint_2gpu_lora \ +python convert_checkpoint.py --model_dir Llama-2-13b-hf \ + --output_dir ./tllm_checkpoint_2gpu \ --dtype float16 \ - --tp_size 2 \ - --hf_lora_dir /tmp/chinese-llama-2-lora-13b + --tp_size 2 -trtllm-build --checkpoint_dir ./tllm_checkpoint_2gpu_lora \ +trtllm-build --checkpoint_dir ./tllm_checkpoint_2gpu \ --output_dir /tmp/new_lora_13b/trt_engines/fp16/2-gpu/ \ --gemm_plugin float16 \ --lora_plugin float16 \ --max_batch_size 1 \ --max_input_len 512 \ - --max_output_len 50 + --max_output_len 50 \ + --lora_dir chinese-llama-2-lora-13b ``` * Run inference. Need to setup the `lora_dir`. Remember to use lora tokenizer because lora model has larger vocab size. @@ -597,31 +597,30 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_2gpu_lora \ mpirun -n 2 python ../run.py --engine_dir "/tmp/new_lora_13b/trt_engines/fp16/2-gpu/" \ --max_output_len 50 \ --tokenizer_dir "chinese-llama-2-lora-13b/" \ - --input_text "今天天气很好,我到公园的时后," \ - --lora_dir "chinese-llama-2-lora-13b/" \ + --input_text "今天天气很好,我到公园的时候," \ --lora_task_uids 0 \ --no_add_special_tokens \ --use_py_session - Input: "今天天气很好,我到公园的时后," -Output: "发现公园里人很多,有的在打羽毛球,有的在打乒乓球,有的在跳绳,还有的在跑步。我和妈妈来到一个空地上,我和妈妈一起跳绳,我跳了1" + Input: "今天天气很好,我到公园的时候," +Output: "发现公园里到处都是人,有的在跑步,有的在打羽毛球,还有的在跳绳,我和妈妈一起在公园里散步,我和妈妈在公园里散步的时候,看见了一位老爷爷在打羽毛球" ``` Users who want to skip LoRA module may pass uid -1 with `--lora_task_uids -1`. In that case, the model will not run the LoRA module and the results will be -different. +different. Since the LoRA tokenizer, embedding and LM head are still used, +the results will also be different with vanilla LLaMA and significantly degrade compared with `--lora_task_uids 0`. 
```bash mpirun -n 2 python ../run.py --engine_dir "/tmp/new_lora_13b/trt_engines/fp16/2-gpu/" \ --max_output_len 50 \ --tokenizer_dir "chinese-llama-2-lora-13b/" \ - --input_text "今天天气很好,我到公园的时后," \ - --lora_dir "chinese-llama-2-lora-13b/" \ + --input_text "今天天气很好,我到公园的时候," \ --lora_task_uids -1 \ --no_add_special_tokens \ --use_py_session - Input: "今天天气很好,我到公园的时后," -Output: "我看见一个人坐在那边边看书书,我看起来还挺像你,可是我走过过去问了一下他说你是你吗,他说没有,然后我就说你看我看看你像你,他说说你看我像你,我说你是你,他说你是你," + Input: "今天天气很好,我到公园的时候," +Output: "看见好多人们都看书,看书书看书书,看书书看书书书书书书书书书书书书书书书书书书书书书书书书书书书书书书书书书书书书" ``` ### Run LLaMa with several lora checkpoints @@ -649,24 +648,23 @@ git-lfs clone https://huggingface.co/kunishou/Japanese-Alpaca-LoRA-7b-v0 BASE_LLAMA_MODEL=llama-7b-hf/ python convert_checkpoint.py --model_dir ${BASE_LLAMA_MODEL} \ - --output_dir ./tllm_checkpoint_1gpu_lora_rank \ - --dtype float16 \ - --hf_lora_dir /tmp/Japanese-Alpaca-LoRA-7b-v0 \ - --max_lora_rank 8 \ - --lora_target_modules "attn_q" "attn_k" "attn_v" -trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_lora_rank \ + --output_dir ./tllm_checkpoint_1gpu \ + --dtype float16 +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu \ --output_dir /tmp/llama_7b_with_lora_qkv/trt_engines/fp16/1-gpu/ \ --gemm_plugin float16 \ --lora_plugin float16 \ --max_batch_size 8 \ --max_input_len 512 \ - --max_output_len 50 + --max_output_len 50 \ + --lora_dir "luotuo-lora-7b-0.1/" "Japanese-Alpaca-LoRA-7b-v0/" \ + --max_lora_rank 8 \ + --lora_target_modules attn_q attn_k attn_v python ../run.py --engine_dir "/tmp/llama_7b_with_lora_qkv/trt_engines/fp16/1-gpu/" \ --max_output_len 10 \ --tokenizer_dir ${BASE_LLAMA_MODEL} \ --input_text "美国的首都在哪里? \n答案:" "美国的首都在哪里? \n答案:" "美国的首都在哪里? \n答案:" "アメリカ合衆国の首都はどこですか? \n答え:" "アメリカ合衆国の首都はどこですか? \n答え:" "アメリカ合衆国の首都はどこですか? \n答え:" \ - --lora_dir "luotuo-lora-7b-0.1/" "Japanese-Alpaca-LoRA-7b-v0/" \ --lora_task_uids -1 0 1 -1 0 1 \ --use_py_session --top_p 0.5 --top_k 0 ``` diff --git a/examples/llama/convert_checkpoint.py b/examples/llama/convert_checkpoint.py index ab5cd2f77..433ec4d78 100644 --- a/examples/llama/convert_checkpoint.py +++ b/examples/llama/convert_checkpoint.py @@ -1,19 +1,19 @@ import argparse import json import os +import sys import time import traceback from concurrent.futures import ThreadPoolExecutor, as_completed -import safetensors - import tensorrt_llm +from tensorrt_llm._utils import release_gc from tensorrt_llm.layers import MoeConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.models import LLaMAForCausalLM -from tensorrt_llm.models.llama.convert import (create_config_from_hugging_face, - from_hugging_face, quantize) from tensorrt_llm.models.llama.weight import load_from_gptq_llama +from tensorrt_llm.models.modeling_utils import QuantizationConfig +from tensorrt_llm.quantization import mode as quant_algo def parse_arguments(): @@ -155,9 +155,6 @@ def parse_arguments(): help= 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' 'Note: the flag might not take effect when the criteria are not met.') - parser.add_argument('--use_prompt_tuning', - action="store_true", - default=False) parser.add_argument('--output_dir', type=str, default='tllm_checkpoint', @@ -227,85 +224,49 @@ def parse_arguments(): ) args = parser.parse_args() + # changing the default to be consistent as the cli help said. 
+ if args.moe_num_experts and args.moe_top_k == 0: + args.moe_top_k = 1 return args -def args_to_quantization(args: argparse.Namespace): +def args_to_quantization(args: argparse.Namespace) -> QuantizationConfig: '''return config dict with quantization info based on the command line args ''' - config = { - 'quantization': { - 'quant_algo': None, - 'kv_cache_quant_algo': None, - 'exclude_modules': ['lm_head'], - } - } - + quant_config = QuantizationConfig() + quant_config.exclude_modules = ['lm_head'] if args.use_weight_only: if args.weight_only_precision == 'int8': - config['quantization']['quant_algo'] = 'W8A16' + quant_config.quant_algo = quant_algo.W8A16 elif args.weight_only_precision == 'int4': - config['quantization']['quant_algo'] = 'W4A16' + quant_config.quant_algo = quant_algo.W4A16 elif args.smoothquant: if args.per_channel: if args.per_token: - config['quantization'][ - 'quant_algo'] = 'W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN' + quant_config.quant_algo = quant_algo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN else: - config['quantization'][ - 'quant_algo'] = 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN' + quant_config.quant_algo = quant_algo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN else: if args.per_token: - config['quantization'][ - 'quant_algo'] = 'W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN' + quant_config.quant_algo = quant_algo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN else: - config['quantization'][ - 'quant_algo'] = 'W8A8_SQ_PER_TENSOR_PLUGIN' + quant_config.quant_algo = quant_algo.W8A8_SQ_PER_TENSOR_PLUGIN if args.int8_kv_cache: - config['quantization']['kv_cache_quant_algo'] = 'INT8' + quant_config.kv_cache_quant_algo = quant_algo.INT8 if args.weight_only_precision == 'int4_gptq': - config['quantization'].update({ - "group_size": args.group_size, - "has_zero_point": True, - "pre_quant_scale": False, - 'quant_algo': 'W4A16_GPTQ' - }) - return config - - -def has_any_quant(args): - config = args_to_quantization(args) - return config['quantization']['quant_algo'] is not None or config[ - 'quantization']['kv_cache_quant_algo'] is not None + quant_config.group_size = args.group_size + quant_config.has_zero_point = True + quant_config.pre_quant_scale = False + quant_config.quant_algo = quant_algo.W4A16_GPTQ + return quant_config -def create_config_from_args(args: argparse.Namespace): - config = {} - mapping = Mapping(world_size=args.tp_size * args.pp_size, - tp_size=args.tp_size, - pp_size=args.pp_size) - # Need to convert the cli args to the kay-value pairs and override them in the generate config dict. - # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now, - # before the refactor is done. 
- override_fields = {'moe_tp_mode': args.moe_tp_mode} - override_fields.update(args_to_quantization(args)) - override_fields.update(args_to_build_options(args)) - - assert args.model_dir is not None - kwargs = { - 'hf_lora_dir': args.hf_lora_dir, - 'lora_target_modules': args.lora_target_modules, - 'max_lora_rank': args.max_lora_rank, - } - config = create_config_from_hugging_face(args.model_dir, - args.dtype, - mapping, - override_fields=override_fields, - **kwargs) - return config +def has_any_quant(args): + quant_config = args_to_quantization(args) + return quant_config.quant_algo is not None or quant_config.kv_cache_quant_algo is not None def convert_and_save_meta(args, rank): @@ -313,21 +274,15 @@ def convert_and_save_meta(args, rank): tp_size=args.tp_size, pp_size=args.pp_size, rank=rank) - override_fields = {'moe_tp_mode': args.moe_tp_mode} - override_fields.update(args_to_quantization(args)) - override_fields.update(args_to_build_options(args)) - - assert not has_any_quant( - args - ), "quantization from meta checkpoint or empty model were never supported" + assert not has_any_quant(args), \ + "quantization from meta checkpoint or empty model were never supported" assert not args.hf_lora_dir, "lora is only supported when loading from hf model dir for now" - kwargs = {} - assert args.meta_ckpt_dir is not None - llama = LLaMAForCausalLM.from_meta_ckpt(args.meta_ckpt_dir, - args.dtype, - mapping, - override_fileds=override_fields, - **kwargs) + llama = LLaMAForCausalLM.from_meta_ckpt( + args.meta_ckpt_dir, + args.dtype, + mapping, + use_parallel_embedding=args.use_parallel_embedding, + embedding_sharding_dim=args.embedding_sharding_dim) llama.save_checkpoint(args.output_dir, save_config=(rank == 0)) @@ -336,61 +291,67 @@ def args_to_build_options(args): 'use_parallel_embedding': args.use_parallel_embedding, 'embedding_sharding_dim': args.embedding_sharding_dim, 'share_embedding_table': args.use_embedding_sharing, - 'use_prompt_tuning': args.use_prompt_tuning, 'disable_weight_only_quant_plugin': args.disable_weight_only_quant_plugin } def from_cli_args(args): - config = {} - mapping = Mapping(world_size=args.tp_size * args.pp_size, - tp_size=args.tp_size, - pp_size=args.pp_size) - architecture = "LlamaForCausalLM" - n_layer = args.n_layer - n_head = args.n_head - n_embd = args.n_embd - inter_size = args.inter_size - n_kv_head = args.n_kv_head if args.n_kv_head is not None else n_head # default to MHA - vocab_size = args.vocab_size - n_positions = args.n_positions - hidden_act = args.hidden_act - rotary_base = args.rotary_base - rms_norm_eps = args.rms_norm_eps - moe_num_experts = args.moe_num_experts - moe_top_k = args.moe_top_k - moe_tp_mode = args.moe_tp_mode - config['moe_normalization_mode'] = args.moe_renorm_mode - # config values from reading model config - config.update({ - 'architecture': architecture, + n_kv_head = args.n_kv_head if args.n_kv_head is not None else args.n_head + config = { + 'architecture': "LlamaForCausalLM", 'dtype': args.dtype, 'logits_dtype': 'float32', - 'num_hidden_layers': n_layer, - 'num_attention_heads': n_head, - 'hidden_size': n_embd, - 'intermediate_size': inter_size, + 'num_hidden_layers': args.n_layer, + 'num_attention_heads': args.n_head, + 'hidden_size': args.n_embd, + 'intermediate_size': args.inter_size, 'num_key_value_heads': n_kv_head, - 'vocab_size': vocab_size, + 'vocab_size': args.vocab_size, 'position_embedding_type': 'rope_gpt_neox', - 'max_position_embeddings': n_positions, - 'hidden_act': hidden_act, - 'rotary_base': rotary_base, - 
'norm_epsilon': rms_norm_eps, - 'moe_num_experts': moe_num_experts, - 'moe_top_k': moe_top_k, - 'moe_tp_mode': moe_tp_mode, + 'max_position_embeddings': args.n_positions, + 'hidden_act': args.hidden_act, + 'rotary_base': args.rotary_base, + 'norm_epsilon': args.rms_norm_eps, + 'moe_num_experts': args.moe_num_experts, + 'moe_top_k': args.moe_top_k, + 'moe_tp_mode': args.moe_tp_mode, + 'moe_normalization_mode': args.moe_renorm_mode, 'mapping': { - 'world_size': mapping.tp_size * mapping.pp_size, - 'tp_size': mapping.tp_size, - 'pp_size': mapping.pp_size - } - }) + 'world_size': args.tp_size * args.pp_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size + }, + 'quantization': args_to_quantization(args).asdict() + } config.update(args_to_build_options(args)) return config +def preload_model(model_dir): + from transformers import AutoConfig, AutoModelForCausalLM + if "vila" in model_dir: + sys.path.append(model_dir + "/../VILA") + from llava.model import LlavaConfig, LlavaLlamaForCausalLM + AutoConfig.register("llava_llama", LlavaConfig) + AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) + + hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + if hf_config.model_type == "llava": + from transformers import LlavaForConditionalGeneration + hf_llava = LlavaForConditionalGeneration.from_pretrained( + model_dir, torch_dtype="auto") + model = hf_llava.language_model + else: + model = AutoModelForCausalLM.from_pretrained( + model_dir, + device_map='auto', + torch_dtype='auto', + trust_remote_code=True, + ) + return model + + def convert_and_save_hf(args): model_dir = args.model_dir load_model_on_cpu = args.load_model_on_cpu @@ -400,9 +361,8 @@ def convert_and_save_hf(args): # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now, # before the refactor is done. 
override_fields = {'moe_tp_mode': args.moe_tp_mode} - override_fields.update(args_to_quantization(args)) + quantization = args_to_quantization(args) override_fields.update(args_to_build_options(args)) - assert model_dir is not None if args.smoothquant is not None or args.int8_kv_cache: assert not args.load_by_shard, "When using quantization, TRT-LLM needs to load the whole HF model, thus load by shard not supported" @@ -412,55 +372,60 @@ def convert_and_save_hf(args): rank=-1, #intentinoally make -1 to avoid mistake tp_size=args.tp_size, pp_size=args.pp_size) - quantize(args.dtype, - args.model_dir, - args.output_dir, - mapping, - override_fields=override_fields, - dataset_cache_dir=args.dataset_cache_dir, - smoothquant_val=args.smoothquant, - int8_kv_cache=args.int8_kv_cache, - hf_lora_dir=args.hf_lora_dir, - lora_target_modules=args.lora_target_modules, - max_lora_rank=args.max_lora_rank) + LLaMAForCausalLM.quantize(args.model_dir, + args.output_dir, + quantization, + dtype=args.dtype, + mapping=mapping, + override_fields=override_fields, + dataset_cache_dir=args.dataset_cache_dir, + smoothquant_val=args.smoothquant, + hf_lora_dir=args.hf_lora_dir, + lora_target_modules=args.lora_target_modules, + max_lora_rank=args.max_lora_rank) else: - for rank in range(world_size): + # When not loading by shard, preload one complete model and then slice per rank weights from this + # this saves the disk reloading time + hf_model = preload_model(model_dir) if not args.load_by_shard else None + + def convert_and_save_rank(args, rank): mapping = Mapping(world_size=world_size, rank=rank, tp_size=args.tp_size, pp_size=args.pp_size) - #TODO: change to LLaMAForCausalLM.from_hugging_face after refactor is done - llama = from_hugging_face( - LLaMAForCausalLM, + llama = LLaMAForCausalLM.from_hugging_face( model_dir, args.dtype, mapping=mapping, + quantization=quantization, load_by_shard=load_by_shard, load_model_on_cpu=load_model_on_cpu, override_fields=override_fields, hf_lora_dir=args.hf_lora_dir, lora_target_modules=args.lora_target_modules, - max_lora_rank=args.max_lora_rank) + max_lora_rank=args.max_lora_rank, + preloaded_model=hf_model) llama.save_checkpoint(args.output_dir, save_config=(rank == 0)) + del llama + release_gc() + + execute(args.workers, [convert_and_save_rank] * world_size, args) def convert_and_save_gptq(args, rank): - config = create_config_from_args(args) - if rank == 0: - with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: - json.dump(config, f, indent=4) - mapping = Mapping(world_size=config['mapping']['tp_size'] * - config['mapping']['pp_size'], + mapping = Mapping(world_size=args.tp_size * args.pp_size, + tp_size=args.tp_size, rank=rank, - tp_size=config['mapping']['tp_size'], - pp_size=config['mapping']['pp_size']) - weights = load_from_gptq_llama(args.ammo_quant_ckpt_path, - config['num_hidden_layers'], - config['vocab_size'], - mapping, - dtype=config['dtype']) - safetensors.torch.save_file( - weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) + pp_size=args.pp_size) + llama = LLaMAForCausalLM.from_hugging_face( + args.model_dir, + args.dtype, + mapping=mapping, + quantization=args_to_quantization(args), + skip_loading_weights=True) + weights = load_from_gptq_llama(llama.config, args.ammo_quant_ckpt_path) + llama.load(weights) + llama.save_checkpoint(args.output_dir, rank == 0) def execute(workers, func, args): @@ -486,22 +451,19 @@ def main(): print(tensorrt_llm.__version__) args = parse_arguments() - # changing the default to be consistent as the 
cli help said. - if args.moe_num_experts and args.moe_top_k == 0: - args.moe_top_k = 1 world_size = args.tp_size * args.pp_size tik = time.time() if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) - ####### save config - if (args.model_dir is None and args.meta_ckpt_dir is None): + if (args.model_dir is None + and args.meta_ckpt_dir is None): # generate fake config.json config = from_cli_args(args) with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: json.dump(config, f, indent=4) - return elif args.meta_ckpt_dir is not None: + assert args.model_dir is None, "Shall not specify both meta checkpoint dir and hugging face dir" execute(args.workers, [convert_and_save_meta] * world_size, args) elif args.weight_only_precision == 'int4_gptq': assert args.model_dir is not None diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt index c2524c4cb..02a327381 100644 --- a/examples/llama/requirements.txt +++ b/examples/llama/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/mamba/README.md b/examples/mamba/README.md index 83a3fac5a..bc9c3ab88 100644 --- a/examples/mamba/README.md +++ b/examples/mamba/README.md @@ -29,7 +29,9 @@ Install the dependency packages and setup `git-lfs`. ```bash # Install dependencies -pip install -r requirements.txt +git clone --branch v1.1.1 https://github.com/Dao-AILab/causal-conv1d.git +git clone --branch v1.1.1 https://github.com/state-spaces/mamba.git +pip install ./causal-conv1d/ ./mamba/ # Setup git-lfs git lfs install @@ -39,29 +41,29 @@ There are six HF checkpoints available. Use one of the following commands to fet ```bash # mamba-2.8b-slimpj -git clone https://huggingface.co/state-spaces/mamba-2.8b-slimpj +git clone https://huggingface.co/state-spaces/mamba-2.8b-slimpj ./mamba_model/mamba-2.8b-slimpj # mamba-2.8b -git clone https://huggingface.co/state-spaces/mamba-2.8b +git clone https://huggingface.co/state-spaces/mamba-2.8b ./mamba_model/mamba-2.8b # mamba-1.4b -git clone https://huggingface.co/state-spaces/mamba-1.4b +git clone https://huggingface.co/state-spaces/mamba-1.4b ./mamba_model/mamba-1.4b # mamba-790m -git clone https://huggingface.co/state-spaces/mamba-790m +git clone https://huggingface.co/state-spaces/mamba-790m ./mamba_model/mamba-790m # mamba-370m -git clone https://huggingface.co/state-spaces/mamba-370m +git clone https://huggingface.co/state-spaces/mamba-370m ./mamba_model/mamba-370m # mamba-130m -git clone https://huggingface.co/state-spaces/mamba-130m +git clone https://huggingface.co/state-spaces/mamba-130m ./mamba_model/mamba-130m ``` Since mamba models use tokenizer from gpt-neox-20b model, use the following command to fetch the checkpoint of gpt-neox-20b. ```bash # gpt-neox-20b -git clone https://huggingface.co/EleutherAI/gpt-neox-20b +git clone https://huggingface.co/EleutherAI/gpt-neox-20b ./mamba_model/gpt-neox-20b ``` ### 2. 
Convert weights from HF Transformers to TensorRT-LLM format @@ -69,34 +71,34 @@ The [`convert_checkpoint.py`](./convert_checkpoint.py) script converts HF weight ```bash # mamba-2.8b-slimpj -python convert_checkpoint.py --model_dir ./mamba-2.8b-slimpj/ \ +python convert_checkpoint.py --model_dir ./mamba_model/mamba-2.8b-slimpj/ \ --dtype bfloat16 \ - --output_dir ./mamba/mamba-2.8b-slimpj/trt_ckpt/bf16/1-gpu/ + --output_dir ./mamba_model/mamba-2.8b-slimpj/trt_ckpt/bf16/1-gpu/ # mamba-2.8b -python convert_checkpoint.py --model_dir ./mamba-2.8b/ \ +python convert_checkpoint.py --model_dir ./mamba_model/mamba-2.8b/ \ --dtype bfloat16 \ - --output_dir ./mamba/mamba-2.8b/trt_ckpt/bf16/1-gpu/ + --output_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/ # mamba-1.4b -python convert_checkpoint.py --model_dir ./mamba-1.4b/ \ +python convert_checkpoint.py --model_dir ./mamba_model/mamba-1.4b/ \ --dtype float16 \ - --output_dir ./mamba/mamba-1.4b/trt_ckpt/fp16/1-gpu/ + --output_dir ./mamba_model/mamba-1.4b/trt_ckpt/fp16/1-gpu/ # mamba-790m -python convert_checkpoint.py --model_dir ./mamba-790m/ \ +python convert_checkpoint.py --model_dir ./mamba_model/mamba-790m/ \ --dtype float16 \ - --output_dir ./mamba/mamba-790m/trt_ckpt/fp16/1-gpu/ + --output_dir ./mamba_model/mamba-790m/trt_ckpt/fp16/1-gpu/ # mamba-370m -python convert_checkpoint.py --model_dir ./mamba-370m/ \ +python convert_checkpoint.py --model_dir ./mamba_model/mamba-370m/ \ --dtype float16 \ - --output_dir ./mamba/mamba-370m/trt_ckpt/fp16/1-gpu/ + --output_dir ./mamba_model/mamba-370m/trt_ckpt/fp16/1-gpu/ # mamba-130m -python convert_checkpoint.py --model_dir ./mamba-130m/ \ +python convert_checkpoint.py --model_dir ./mamba_model/mamba-130m/ \ --dtype float16 \ - --output_dir ./mamba/mamba-130m/trt_ckpt/fp16/1-gpu/ + --output_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ ``` ### 3. 
Build TensorRT engine(s) @@ -104,7 +106,7 @@ The `trtllm-build` command builds TensorRT-LLM engines from TensorRT-LLM checkpo ```bash # mamba-2.8b-slimpj -trtllm-build --checkpoint_dir ./mamba/mamba-2.8b-slimpj/trt_ckpt/bf16/1-gpu/ \ +trtllm-build --checkpoint_dir ./mamba_model/mamba-2.8b-slimpj/trt_ckpt/bf16/1-gpu/ \ --gpt_attention_plugin disable \ --paged_kv_cache disable \ --remove_input_padding disable \ @@ -112,10 +114,10 @@ trtllm-build --checkpoint_dir ./mamba/mamba-2.8b-slimpj/trt_ckpt/bf16/1-gpu/ \ --max_batch_size 8 \ --max_input_len 924 \ --max_output_len 100 \ - --output_dir ./mamba/mamba-2.8b-slimpj/trt_engines/bf16/1-gpu/ + --output_dir ./mamba_model/mamba-2.8b-slimpj/trt_engines/bf16/1-gpu/ # mamba-2.8b -trtllm-build --checkpoint_dir ./mamba/mamba-2.8b/trt_ckpt/bf16/1-gpu/ \ +trtllm-build --checkpoint_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/ \ --gpt_attention_plugin disable \ --paged_kv_cache disable \ --remove_input_padding disable \ @@ -123,10 +125,10 @@ trtllm-build --checkpoint_dir ./mamba/mamba-2.8b/trt_ckpt/bf16/1-gpu/ \ --max_batch_size 8 \ --max_input_len 924 \ --max_output_len 100 \ - --output_dir ./mamba/mamba-2.8b/trt_engines/bf16/1-gpu/ + --output_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/ # mamba-1.4b -trtllm-build --checkpoint_dir ./mamba/mamba-1.4b/trt_ckpt/fp16/1-gpu/ \ +trtllm-build --checkpoint_dir ./mamba_model/mamba-1.4b/trt_ckpt/fp16/1-gpu/ \ --gpt_attention_plugin disable \ --paged_kv_cache disable \ --remove_input_padding disable \ @@ -134,10 +136,10 @@ trtllm-build --checkpoint_dir ./mamba/mamba-1.4b/trt_ckpt/fp16/1-gpu/ \ --max_batch_size 8 \ --max_input_len 924 \ --max_output_len 100 \ - --output_dir ./mamba/mamba-1.4b/trt_engines/fp16/1-gpu/ + --output_dir ./mamba_model/mamba-1.4b/trt_engines/fp16/1-gpu/ # mamba-790m -trtllm-build --checkpoint_dir ./mamba/mamba-790m/trt_ckpt/fp16/1-gpu/ \ +trtllm-build --checkpoint_dir ./mamba_model/mamba-790m/trt_ckpt/fp16/1-gpu/ \ --gpt_attention_plugin disable \ --paged_kv_cache disable \ --remove_input_padding disable \ @@ -145,10 +147,10 @@ trtllm-build --checkpoint_dir ./mamba/mamba-790m/trt_ckpt/fp16/1-gpu/ \ --max_batch_size 8 \ --max_input_len 924 \ --max_output_len 100 \ - --output_dir ./mamba/mamba-790m/trt_engines/fp16/1-gpu/ + --output_dir ./mamba_model/mamba-790m/trt_engines/fp16/1-gpu/ # mamba-370m -trtllm-build --checkpoint_dir ./mamba/mamba-370m/trt_ckpt/fp16/1-gpu/ \ +trtllm-build --checkpoint_dir ./mamba_model/mamba-370m/trt_ckpt/fp16/1-gpu/ \ --gpt_attention_plugin disable \ --paged_kv_cache disable \ --remove_input_padding disable \ @@ -156,10 +158,10 @@ trtllm-build --checkpoint_dir ./mamba/mamba-370m/trt_ckpt/fp16/1-gpu/ \ --max_batch_size 8 \ --max_input_len 924 \ --max_output_len 100 \ - --output_dir ./mamba/mamba-370m/trt_engines/fp16/1-gpu/ + --output_dir ./mamba_model/mamba-370m/trt_engines/fp16/1-gpu/ # mamba-130m -trtllm-build --checkpoint_dir ./mamba/mamba-130m/trt_ckpt/fp16/1-gpu/ \ +trtllm-build --checkpoint_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ \ --gpt_attention_plugin disable \ --paged_kv_cache disable \ --remove_input_padding disable \ @@ -167,7 +169,7 @@ trtllm-build --checkpoint_dir ./mamba/mamba-130m/trt_ckpt/fp16/1-gpu/ \ --max_batch_size 8 \ --max_input_len 924 \ --max_output_len 100 \ - --output_dir ./mamba/mamba-130m/trt_engines/fp16/1-gpu/ + --output_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/ ``` ### 4. 
Run summarization task with the TensorRT engine(s) @@ -181,48 +183,48 @@ The following section describes how to run a TensorRT-LLM Mamba model to summari # mamba-2.8b-slimpj python ../summarize.py --test_trt_llm \ --use_py_session \ - --hf_model_dir ./mamba-2.8b-slimpj/ \ - --tokenizer_dir ./gpt-neox-20b/ \ + --hf_model_dir ./mamba_model/mamba-2.8b-slimpj/ \ + --tokenizer_dir ./mamba_model/gpt-neox-20b/ \ --data_type bf16 \ - --engine_dir ./mamba/mamba-2.8b-slimpj/trt_engines/bf16/1-gpu/ + --engine_dir ./mamba_model/mamba-2.8b-slimpj/trt_engines/bf16/1-gpu/ # mamba-2.8b python ../summarize.py --test_trt_llm \ --use_py_session \ - --hf_model_dir ./mamba-2.8b/ \ - --tokenizer_dir ./gpt-neox-20b/ \ + --hf_model_dir ./mamba_model/mamba-2.8b/ \ + --tokenizer_dir ./mamba_model/gpt-neox-20b/ \ --data_type bf16 \ - --engine_dir ./mamba/mamba-2.8b/trt_engines/bf16/1-gpu/ + --engine_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/ # mamba-1.4b python ../summarize.py --test_trt_llm \ --use_py_session \ - --hf_model_dir ./mamba-1.4b/ \ - --tokenizer_dir ./gpt-neox-20b/ \ + --hf_model_dir ./mamba_model/mamba-1.4b/ \ + --tokenizer_dir ./mamba_model/gpt-neox-20b/ \ --data_type fp16 \ - --engine_dir ./mamba/mamba-1.4b/trt_engines/fp16/1-gpu/ + --engine_dir ./mamba_model/mamba-1.4b/trt_engines/fp16/1-gpu/ # mamba-790m python ../summarize.py --test_trt_llm \ --use_py_session \ - --hf_model_dir ./mamba-790m/ \ - --tokenizer_dir ./gpt-neox-20b/ \ + --hf_model_dir ./mamba_model/mamba-790m/ \ + --tokenizer_dir ./mamba_model/gpt-neox-20b/ \ --data_type fp16 \ - --engine_dir ./mamba/mamba-790m/trt_engines/fp16/1-gpu/ + --engine_dir ./mamba_model/mamba-790m/trt_engines/fp16/1-gpu/ # mamba-370m python ../summarize.py --test_trt_llm \ --use_py_session \ - --hf_model_dir ./mamba-370m/ \ - --tokenizer_dir ./gpt-neox-20b/ \ + --hf_model_dir ./mamba_model/mamba-370m/ \ + --tokenizer_dir ./mamba_model/gpt-neox-20b/ \ --data_type fp16 \ - --engine_dir ./mamba/mamba-370m/trt_engines/fp16/1-gpu/ + --engine_dir ./mamba_model/mamba-370m/trt_engines/fp16/1-gpu/ # mamba-130m python ../summarize.py --test_trt_llm \ --use_py_session \ - --hf_model_dir ./mamba-130m/ \ - --tokenizer_dir ./gpt-neox-20b/ \ + --hf_model_dir ./mamba_model/mamba-130m/ \ + --tokenizer_dir ./mamba_model/gpt-neox-20b/ \ --data_type fp16 \ - --engine_dir ./mamba/mamba-130m/trt_engines/fp16/1-gpu/ + --engine_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/ ``` diff --git a/examples/mamba/requirements.txt b/examples/mamba/requirements.txt index b6b028069..2c5c4942b 100644 --- a/examples/mamba/requirements.txt +++ b/examples/mamba/requirements.txt @@ -1,4 +1,4 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 mamba-ssm==1.1.1 causal-conv1d==1.1.1 # 1.1.2 needs torch 2.2, while TRT-LLM sticks to pre 2.2 diff --git a/examples/medusa/convert_checkpoint.py b/examples/medusa/convert_checkpoint.py index 8645ee9b6..7c6682cdf 100644 --- a/examples/medusa/convert_checkpoint.py +++ b/examples/medusa/convert_checkpoint.py @@ -23,8 +23,7 @@ from tensorrt_llm._utils import str_dtype_to_torch from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.llama.weight import (load_from_gptq_llama, - load_from_hf_checkpoint) +from tensorrt_llm.models.llama.weight import load_from_hf_checkpoint from tensorrt_llm.models.modeling_utils import PretrainedConfig try: @@ -164,9 +163,6 @@ def parse_arguments(): help= 'Try to reduce the engine size by sharing the 
embedding lookup table between two layers.' 'Note: the flag might not take effect when the criteria are not met.') - parser.add_argument('--use_prompt_tuning', - action="store_true", - default=False) parser.add_argument('--output_dir', type=str, default='tllm_checkpoint', @@ -700,71 +696,6 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): return results -class QkvWeightHelper: - """ A helper utility for loading QKV weights from sharded files. """ - - def __init__(self, config: PretrainedConfig): - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.tp_size = config.mapping.tp_size - self.tp_rank = config.mapping.tp_rank - self.is_mha = self.num_heads == self.num_kv_heads - self._qkv_weights = {} - - @staticmethod - def is_qkv_weight(name): - for k in ['q_proj', 'k_proj', 'v_proj']: - if 'self_attn' in name and k in name: - return True - return False - - def add_weight(self, i: int, name: str, weight: torch.Tensor): - if 'q_proj' in name: - tag = 'q' - elif 'k_proj' in name: - tag = 'k' - elif 'v_proj' in name: - tag = 'v' - else: - raise ValueError(f'Got an unexpected parameter of name {name}') - if i not in self._qkv_weights: - self._qkv_weights[i] = {} - self._qkv_weights[i][tag] = weight - - def is_qkv_prepared(self, layer_idx): - if layer_idx not in self._qkv_weights: - return False - weights = self._qkv_weights[layer_idx] - return 'q' in weights and 'k' in weights and 'v' in weights - - def split_qkv_weights(self, layer_idx): - if not self.is_qkv_prepared(layer_idx): - return None - weights = self._qkv_weights.pop(layer_idx) # to prevent memory leak. - q, k, v = (torch.tensor(weights[t]) for t in ['q', 'k', 'v']) - - if not self.is_mha: - head_size = self.hidden_size // self.num_heads - if self.num_kv_heads < self.tp_size: - # duplicate the KV heads up to tensor_parallel - k = dup_kv_weight(k, self.num_kv_heads, self.tp_size) - v = dup_kv_weight(v, self.num_kv_heads, self.tp_size) - assert k.shape[0] % (self.tp_size * head_size) == 0 - assert v.shape[0] % (self.tp_size * head_size) == 0 - wq = split(q, self.tp_size, self.tp_rank) - wk = split(k, self.tp_size, self.tp_rank) - wv = split(v, self.tp_size, self.tp_rank) - fused_qkv = torch.cat((wq, wk, wv), dim=0) - else: - qkv = torch.cat([q, k, v], dim=0) - qkv = qkv.reshape(3, q.shape[0], q.shape[1]) - fused_qkv = split(qkv, self.tp_size, self.tp_rank, dim=1) - fused_qkv = fused_qkv.reshape(3 * (q.shape[0] // self.tp_size), - q.shape[1]) - return fused_qkv - - def convert_hf_llama(hf_model, mapping, rank=0, @@ -1137,7 +1068,6 @@ def convert_hf_llama(hf_model, 'use_parallel_embedding': args.use_parallel_embedding, 'embedding_sharding_dim': args.embedding_sharding_dim, 'share_embedding_table': args.use_embedding_sharing, - 'use_prompt_tuning': args.use_prompt_tuning, 'max_draft_len': args.max_medusa_token_len, 'num_medusa_heads': args.num_medusa_heads, 'num_medusa_layers': args.num_medusa_layers @@ -1228,11 +1158,7 @@ def covert_and_save(rank, convert_args): pp_size=args.pp_size) if args.use_weight_only and args.weight_only_precision == 'int4_gptq': - - weights = load_from_gptq_llama(args.ammo_quant_ckpt_path, - hf_config, - mapping, - dtype=args.dtype) + assert False, "Never supported" else: if args.load_by_shard: weights = load_from_hf_checkpoint( diff --git a/examples/medusa/requirements.txt b/examples/medusa/requirements.txt index c14ed53cd..b9ba2a338 100644 --- a/examples/medusa/requirements.txt +++ 
b/examples/medusa/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets~=2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/mixtral/requirements.txt b/examples/mixtral/requirements.txt index 598fa22ca..0916a18b2 100644 --- a/examples/mixtral/requirements.txt +++ b/examples/mixtral/requirements.txt @@ -1,4 +1,4 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 -transformers==4.36.1 +tensorrt_llm==0.9.0.dev2024031900 +transformers==4.38.2 accelerate==0.25.0 diff --git a/examples/model_api/llama.py b/examples/model_api/llama.py index 9a6f995e9..252c2ab89 100644 --- a/examples/model_api/llama.py +++ b/examples/model_api/llama.py @@ -1,5 +1,6 @@ import argparse import os +from pathlib import Path from tensorrt_llm.executor import GenerationExecutor from tensorrt_llm.models import LLaMAForCausalLM @@ -17,7 +18,7 @@ def parse_args(): parser = argparse.ArgumentParser(description="Llama single model example") parser.add_argument( "--engine_dir", - type=str, + type=Path, required=True, help= "Directory to save and load the engine. When -c is specified, always rebuild and save to this dir. When -c is not specified, load engine when the engine_dir exists, rebuild otherwise" @@ -43,13 +44,14 @@ def main(): tokenizer_dir = args.hf_model_dir max_batch_size, max_isl, max_osl = 1, 256, 20 - if args.clean_build or not os.path.exists(args.engine_dir): + if args.clean_build or not args.engine_dir.exists(): + args.engine_dir.mkdir(exist_ok=True, parents=True) os.makedirs(args.engine_dir, exist_ok=True) llama = LLaMAForCausalLM.from_hugging_face(args.hf_model_dir) llama.to_trt(max_batch_size, max_isl, max_osl) - llama.save(args.engine_dir) + llama.save(str(args.engine_dir)) - executor = GenerationExecutor(args.engine_dir, tokenizer_dir) + executor = GenerationExecutor.create(args.engine_dir, tokenizer_dir) for inp in read_input(): output = executor.generate(inp, max_new_tokens=20) diff --git a/examples/model_api/llama_multi_gpu.py b/examples/model_api/llama_multi_gpu.py index f1e8b8e71..b09f2af2c 100644 --- a/examples/model_api/llama_multi_gpu.py +++ b/examples/model_api/llama_multi_gpu.py @@ -1,12 +1,13 @@ import argparse import os +from pathlib import Path import torch from mpi4py.futures import MPIPoolExecutor import tensorrt_llm from tensorrt_llm import Mapping, mpi_barrier -from tensorrt_llm.executor import GenerationExecutor +from tensorrt_llm.executor import GenerationExecutorWorker from tensorrt_llm.models import LLaMAForCausalLM @@ -36,10 +37,12 @@ def build_and_run_llama(hf_model_dir, engine_dir, tp_size, rank, clean_build): mpi_barrier() # make sure every rank engine build finished generate_len = 20 # change on your needs, hard code for simplicity here - executor = GenerationExecutor(engine_dir, tokenizer_dir) + executor = GenerationExecutorWorker(Path(engine_dir), tokenizer_dir) - output_streams = executor.generate_async(dataset(), True, - [generate_len] * len(dataset())) + output_streams = executor.generate_async(dataset(), + True, + max_new_tokens=[generate_len] * + len(dataset())) if rank == 0: for stream in output_streams: for state in stream: diff --git a/examples/model_api/llama_quantize.py b/examples/model_api/llama_quantize.py new file mode 100644 index 000000000..278e005e1 --- /dev/null +++ b/examples/model_api/llama_quantize.py @@ -0,0 +1,78 @@ +import argparse +import os +from pathlib import Path + +import tensorrt_llm +import 
tensorrt_llm.quantization.mode as quant_mode +from tensorrt_llm.builder import BuildConfig, build +from tensorrt_llm.executor import GenerationExecutor +from tensorrt_llm.models import LLaMAForCausalLM +from tensorrt_llm.models.modeling_utils import QuantizationConfig + + +def read_input(): + while (True): + input_text = input("<") + if input_text in ("q", "quit"): + break + yield input_text + + +def parse_args(): + parser = argparse.ArgumentParser(description="Llama single model example") + parser.add_argument( + "--cache_dir", + type=str, + required=True, + help= + "Directory to save and load the engine and checkpoint. When -c is specified, always rebuild and save to this dir. When -c is not specified, load engine when the engine_dir exists, rebuild otherwise" + ) + parser.add_argument( + "--hf_model_dir", + type=str, + required=True, + help="Read the model data and tokenizer from this directory") + parser.add_argument( + "-c", + "--clean_build", + default=False, + action="store_true", + help= + "Clean build the engine even if the cache dir exists, be careful, this overwrites the cache dir!!" + ) + return parser.parse_args() + + +def main(): + tensorrt_llm.logger.set_level('verbose') + args = parse_args() + tokenizer_dir = args.hf_model_dir + max_batch_size, max_isl, max_osl = 1, 256, 20 + build_config = BuildConfig(max_input_len=max_isl, + max_output_len=max_osl, + max_batch_size=max_batch_size) + cache_dir = Path(args.cache_dir) + checkpoint_dir = cache_dir / "trtllm_checkpoint" + engine_dir = cache_dir / "trtllm_engine" + + if args.clean_build or not cache_dir.exists(): + os.makedirs(cache_dir, exist_ok=True) + quant_config = QuantizationConfig() + quant_config.quant_algo = quant_mode.W4A16_AWQ + if not checkpoint_dir.exists(): + LLaMAForCausalLM.quantize(args.hf_model_dir, + checkpoint_dir, + quant_config=quant_config, + calib_batches=1) + llama = LLaMAForCausalLM.from_checkpoint(checkpoint_dir) + engine = build(llama, build_config) + engine.save(engine_dir) + + executor = GenerationExecutor(engine_dir, tokenizer_dir) + + for inp in read_input(): + output = executor.generate(inp, max_new_tokens=20) + print(f">{output.text}") + + +main() diff --git a/examples/mpt/requirements.txt b/examples/mpt/requirements.txt index 9f15a5f00..8b80a3084 100644 --- a/examples/mpt/requirements.txt +++ b/examples/mpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/multimodal/build_visual_engine.py b/examples/multimodal/build_visual_engine.py index 22389d9ed..291685920 100644 --- a/examples/multimodal/build_visual_engine.py +++ b/examples/multimodal/build_visual_engine.py @@ -45,7 +45,7 @@ def build_trt_engine(img_height, img_width, output_dir, max_batch_size): parser = trt.OnnxParser(network, logger) with open(onnx_file, 'rb') as model: - if not parser.parse(model.read(), "/".join(onnx_file.split("/"))): + if not parser.parse(model.read(), os.path.abspath(onnx_file)): logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file) for error in range(parser.num_errors): logger.log(trt.Logger.ERROR, parser.get_error(error)) diff --git a/examples/opt/requirements.txt b/examples/opt/requirements.txt index 9f15a5f00..8b80a3084 100644 --- a/examples/opt/requirements.txt +++ b/examples/opt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 
datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/phi/requirements.txt b/examples/phi/requirements.txt index 6896c3fa4..456d84b04 100644 --- a/examples/phi/requirements.txt +++ b/examples/phi/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/quantization/requirements.txt b/examples/quantization/requirements.txt index a2a9ceb9c..010471278 100644 --- a/examples/quantization/requirements.txt +++ b/examples/quantization/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets>=2.14.4 nemo-toolkit[all]<=1.20.0,>=1.18.0 rouge_score~=0.1.2 diff --git a/examples/qwen/README.md b/examples/qwen/README.md index 7343bfca5..f8bb24a5a 100644 --- a/examples/qwen/README.md +++ b/examples/qwen/README.md @@ -1,12 +1,12 @@ # Qwen -This document shows how to build and run a Qwen model in TensorRT-LLM on both single GPU, single node multi-GPU and multi-node multi-GPU. +This document shows how to build and run a Qwen model in TensorRT-LLM on both single GPU, single node multi-GPU. ## Overview The TensorRT-LLM Qwen implementation can be found in [model.py](../../tensorrt_llm/models/qwen/model.py). The TensorRT-LLM Qwen example code is located in [`examples/qwen`](./). There is one main file: -* [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the Qwen model. +* [`convert_checkpoint.py`](./convert_checkpoint.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the Qwen model. In addition, there are two shared files in the parent folder [`examples`](../) for inference and evaluation: @@ -16,13 +16,13 @@ In addition, there are two shared files in the parent folder [`examples`](../) f ## Support Matrix | Model Name | FP16 | FMHA | WO | AWQ | GPTQ | SQ | TP | PP | ST | C++ Runtime | benchmark | IFB | Arch | | :--------------: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---:|:---: | :---------: | :-------: | :---: | :---: | -| Qwen-7B-Chat | Y | Y | Y | Y | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | -| Qwen-14B-Chat | Y | Y | Y | Y* | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | -| Qwen-72B-Chat | Y | Y | Y | - | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | +| Qwen-7B(-Chat) | Y | Y | Y | Y | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | +| Qwen-14B(-Chat) | Y | Y | Y | Y* | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | +| Qwen-72B(-Chat) | Y | Y | Y | - | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | *Please note that Qwen-14B-Chat model supports AWQ only with single GPU. * Model Name: the name of the model, the same as the name on HuggingFace -* FMHA: Fused MultiHead Attention (see introduction below) +* FMHA: Fused MultiHead Attention * WO: Weight Only Quantization (int8 / int4) * AWQ: Activation Aware Weight Quantization (int4) * GPTQ: Generative Pretrained Transformer Quantization (int4) @@ -30,9 +30,9 @@ In addition, there are two shared files in the parent folder [`examples`](../) f * TP: Tensor Parallel * PP: Pipeline Parallel * ST: Strongly Typed -* IFB: In-flight Batching (see introduction below) +* IFB: In-flight Batching -*Currently Qwen models does not support dynamic NTK and logn attention. Therefore, accuracy on long sequence input is not promised. +*Currently Qwen models does not support dynamic NTK and logn attention. 
Therefore, accuracy on long sequence input for the 7B and 14B model is not promised. ## Usage @@ -40,215 +40,139 @@ The TensorRT-LLM Qwen example code locates at [examples/qwen](./). It takes HF w ### Build TensorRT engine(s) -Need to prepare the HF Qwen checkpoint first by following the guides here [Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) or [Qwen-14B-Chat](https://huggingface.co/Qwen/Qwen-14B-Chat) +Need to prepare the HF Qwen checkpoint first by following the guides here [Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) -Create a `tmp/Qwen` directory to store the weights downloaded from huaggingface. -```bash -mkdir -p ./tmp/Qwen -``` +TensorRT-LLM builds TensorRT engine(s) from HF checkpoint. If no checkpoint directory is specified, TensorRT-LLM will build engine(s) with dummy weights. -Store Qwen-7B-Chat or Qwen-14B-Chat separately. -- for Qwen-7B-Chat -```bash -mv Qwen-7B-Chat ./tmp/Qwen/7B -``` -- for Qwen-14B-Chat -``` -mv Qwen-14B-Chat ./tmp/Qwen/14B -``` - -TensorRT-LLM Qwen builds TensorRT engine(s) from HF checkpoint. If no checkpoint directory is specified, TensorRT-LLM will build engine(s) with dummy weights. - -Normally `build.py` only requires single GPU, but if you've already got all the GPUs needed for inference, you could enable parallel-building to make the engine building process faster by adding `--parallel_build` argument. Please note that currently `parallel_build` feature only supports single node. +Normally `trtllm-build` only requires single GPU, but if you've already got all the GPUs needed while inferencing, you could enable parallelly building to make the engine building process faster by adding `--workers` argument. Please note that currently `workers` feature only supports single node. Here're some examples: ```bash # Build a single-GPU float16 engine from HF weights. -# use_gpt_attention_plugin is necessary in Qwen. # Try use_gemm_plugin to prevent accuracy issue. -# It is recommend to use --remove_input_padding along with --use_gpt_attention_plugin for better performance - -# Build the Qwen 7B model using a single GPU and FP16. -python build.py --model_dir ./tmp/Qwen/7B/ \ - --dtype float16 \ - --remove_input_padding \ - --use_gpt_attention_plugin float16 \ - --enable_context_fmha \ - --use_gemm_plugin float16 \ - --output_dir ./tmp/Qwen/7B/trt_engines/fp16/1-gpu/ - -# Build the Qwen 7B model using a single GPU and BF16. -python build.py --model_dir ./tmp/Qwen/7B/ \ - --dtype bfloat16 \ - --remove_input_padding \ - --use_gpt_attention_plugin bfloat16 \ - --enable_context_fmha \ - --use_gemm_plugin bfloat16 \ - --output_dir ./tmp/Qwen/7B/trt_engines/bf16/1-gpu/ - -# Build the Qwen 7B model using a single GPU and apply INT8 weight-only quantization. -python build.py --model_dir ./tmp/Qwen/7B/ \ - --dtype float16 \ - --remove_input_padding \ - --use_gpt_attention_plugin float16 \ - --use_gemm_plugin float16 \ - --use_weight_only \ - --weight_only_precision int8 \ - --output_dir ./tmp/Qwen/7B/trt_engines/int8_weight_only/1-gpu/ - -# Build the Qwen 7B model using a single GPU and apply INT4 weight-only quantization. -python build.py --model_dir ./tmp/Qwen/7B/ \ - --dtype float16 \ - --remove_input_padding \ - --use_gpt_attention_plugin float16 \ - --use_gemm_plugin float16 \ - --use_weight_only \ - --weight_only_precision int4 \ - --output_dir ./tmp/Qwen/7B/trt_engines/int4_weight_only/1-gpu/ - -# Build Qwen 7B using 2-way tensor parallelism. 
-python build.py --model_dir ./tmp/Qwen/7B/ \ - --dtype float16 \ - --remove_input_padding \ - --use_gpt_attention_plugin float16 \ - --enable_context_fmha \ - --use_gemm_plugin float16 \ - --output_dir ./tmp/Qwen/7B/trt_engines/fp16/2-gpu/ \ - --world_size 2 \ - --tp_size 2 - -# Build Qwen 7B using 2-way tensor parallelism and 2-way pipeline parallelism. -python build.py --model_dir ./tmp/Qwen/7B/ \ - --dtype float16 \ - --remove_input_padding \ - --use_gpt_attention_plugin float16 \ - --enable_context_fmha \ - --use_gemm_plugin float16 \ - --output_dir ./tmp/Qwen/7B/trt_engines/fp16/2-gpu/ \ - --world_size 4 \ - --tp_size 2 \ - --pp_size 2 - -# Build Qwen 14B using 2-way tensor parallelism. -python build.py --model_dir ./tmp/Qwen/14B \ - --dtype float16 \ - --remove_input_padding \ - --use_gpt_attention_plugin float16 \ - --enable_context_fmha \ - --use_gemm_plugin float16 \ - --output_dir ./tmp/Qwen/14B/trt_engines/fp16/2-gpu/ \ - --world_size 2 \ - --tp_size 2 - -# Build Qwen 72B using 8-way tensor parallelism. -python build.py --model_dir ./tmp/Qwen/72B \ - --dtype float16 \ - --remove_input_padding \ - --use_gpt_attention_plugin float16 \ - --enable_context_fmha \ - --use_gemm_plugin float16 \ - --output_dir ./tmp/Qwen/72B/trt_engines/fp16/8-gpu/ \ - --world_size 8 \ - --tp_size 8 -``` -**Demo output of engine building:** -```python -python3 build.py --model_dir /llm-models/Qwen-7B-Chat/ --output_dir /engine_qwen -``` -``` -[11/09/2023-00:57:06] [TRT-LLM] [I] Serially build TensorRT engines. -[11/09/2023-00:57:06] [TRT] [I] [MemUsageChange] Init CUDA: CPU +14, GPU +0, now: CPU 118, GPU 427 (MiB) -[11/09/2023-00:57:08] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +1974, GPU +350, now: CPU 2227, GPU 777 (MiB) -[11/09/2023-00:57:08] [TRT-LLM] [W] Invalid timing cache, using freshly created one -[11/09/2023-00:57:14] [TRT-LLM] [I] Loading HF QWen ... from /llm-models/Qwen-7B-Chat/ -...... -[11/09/2023-01:01:34] [TRT] [I] [MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 47322 MiB -[11/09/2023-01:01:34] [TRT-LLM] [I] Total time of building qwen_float16_tp1_rank0.engine: 00:03:44 -[11/09/2023-01:01:34] [TRT-LLM] [I] Config saved to /engine_qwen/config.json. -[11/09/2023-01:01:34] [TRT-LLM] [I] Serializing engine to /engine_qwen/qwen_float16_tp1_rank0.engine... -[11/09/2023-01:01:49] [TRT-LLM] [I] Engine serialized. Total time: 00:00:14 -[11/09/2023-01:01:49] [TRT-LLM] [I] Timing cache serialized to /engine_qwen/model.cache -[11/09/2023-01:01:50] [TRT-LLM] [I] Total time of building all 1 engines: 00:04:43 -``` - - -#### INT8 weight only + INT8 KV cache -For INT8 KV cache, [`hf_qwen_convert.py`](./hf_qwen_convert.py) features a -`--calibrate-kv-cache, -kv` option. Setting `-kv` will calibrate the model, -and then export the scaling factors needed for INT8 KV cache inference. +# Build the Qwen-7B-Chat model using a single GPU and FP16. +python convert_checkpoint.py --model_dir ./tmp/Qwen/7B/ \ + --output_dir ./tllm_checkpoint_1gpu_fp16 \ + --dtype float16 + +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16 \ + --output_dir ./tmp/qwen/7B/trt_engines/fp16/1-gpu \ + --gemm_plugin float16 + +# Build the Qwen-7B-Chat model using a single GPU and BF16. 
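+# Note: bfloat16 engines require an NVIDIA Ampere (SM80) or newer GPU.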
+python convert_checkpoint.py --model_dir ./tmp/Qwen/7B/ \ + --output_dir ./tllm_checkpoint_1gpu_bf16 \ + --dtype bfloat16 + +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_bf16 \ + --output_dir ./tmp/qwen/7B/trt_engines/bf16/1-gpu \ + --gpt_attention_plugin bfloat16 \ + --gemm_plugin bfloat16 + +# Build the Qwen-7B-Chat model using a single GPU and apply INT8 weight-only quantization. +python convert_checkpoint.py --model_dir ./tmp/Qwen/7B/ \ + --output_dir ./tllm_checkpoint_1gpu_fp16_wq \ + --dtype float16 \ + --use_weight_only \ + --weight_only_precision int8 + +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16_wq \ + --output_dir ./tmp/qwen/7B/trt_engines/weight_only/1-gpu/ \ + --gemm_plugin float16 + +# Build the Qwen-7B-Chat model using a single GPU and apply INT4 weight-only quantization. +python convert_checkpoint.py --model_dir ./tmp/Qwen/7B/ \ + --output_dir ./tllm_checkpoint_1gpu_fp16_wq \ + --dtype float16 \ + --use_weight_only \ + --weight_only_precision int4 + +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16_wq \ + --output_dir ./tmp/qwen/7B/trt_engines/weight_only/1-gpu/ \ + --gemm_plugin float16 + +# Build Qwen-7B-Chat using 2-way tensor parallelism. +python convert_checkpoint.py --model_dir ./tmp/Qwen/7B/ \ + --output_dir ./tllm_checkpoint_2gpu_tp2 \ + --dtype float16 \ + --tp_size 2 + +trtllm-build --checkpoint_dir ./tllm_checkpoint_2gpu_tp2 \ + --output_dir ./tmp/qwen/7B/trt_engines/fp16/2-gpu/ \ + --gemm_plugin float16 + +# Build Qwen-7B-Chat using 2-way tensor parallelism and 2-way pipeline parallelism. +python convert_checkpoint.py --model_dir ./tmp/Qwen/7B/ \ + --output_dir ./tllm_checkpoint_4gpu_tp2_pp2 \ + --dtype float16 \ + --tp_size 2 \ + --pp_size 2 +trtllm-build --checkpoint_dir ./tllm_checkpoint_4gpu_tp2_pp2 \ + --output_dir ./tmp/qwen/7B/trt_engines/fp16/4-gpu/ \ + --gemm_plugin float16 + +# Build Qwen-14B-Chat using 2-way tensor parallelism. +python convert_checkpoint.py --model_dir ./tmp/Qwen/14B/ \ + --output_dir ./tllm_checkpoint_2gpu_tp2 \ + --dtype float16 \ + --tp_size 2 + +trtllm-build --checkpoint_dir ./tllm_checkpoint_2gpu_tp2 \ + --output_dir ./tmp/qwen/14B/trt_engines/fp16/2-gpu/ \ + --gemm_plugin float16 \ + +# Build Qwen-72B-Chat using 8-way tensor parallelism. +python convert_checkpoint.py --model_dir ./tmp/Qwen/72B/ \ + --output_dir ./tllm_checkpoint_8gpu_tp8 \ + --dtype float16 \ + --tp_size 8 + +trtllm-build --checkpoint_dir ./tllm_checkpoint_8gpu_tp8 \ + --output_dir ./tmp/qwen/72B/trt_engines/fp16/8-gpu/ \ + --gemm_plugin float16 \ +``` + +#### INT8 KV cache +INT8 KV cache could be enabled to reduce memory footprint. It will bring more performance gains when batch size gets larger. + +For INT8 KV cache, [`convert_checkpoint.py`](./convert_checkpoint.py) features a +`--int8_kv_cache` option. Setting `--int8_kv_cache` will calibrate the model, +and then export the scaling factors needed for INT8 KV cache inference. Remember to set `--strongly_typed` when building the engine if you are not using INT8 weight only quantization at the same time. Example: ```bash -python3 hf_qwen_convert.py \ - -i ./tmp/Qwen/7B/ \ - -o ./tmp/Qwen/7B/int8_kv_cache/ \ - --calibrate-kv-cache -t float16 -``` - -[`build.py`](./build.py) add new options for the support of INT8 KV cache. - -`--int8_kv_cache` is the command-line option to enable INT8 KV cache. 
- -In addition, it could be combined with INT8 weight-only quantization, as follows: - -Examples of INT8 weight-only quantization + INT8 KV cache - -```bash -# Build model with both INT8 weight-only and INT8 KV cache enabled -python build.py --bin_model_dir ./tmp/Qwen/7B/int8_kv_cache/1-gpu/ \ - --dtype float16 \ - --model_dir ./tmp/Qwen/7B \ - --use_gpt_attention_plugin float16 \ - --use_gemm_plugin float16 \ - --output_dir ./tmp/Qwen/7B/trt_engines/int8_kv_cache_weight_only/1-gpu \ - --int8_kv_cache \ - --use_weight_only -``` +python convert_checkpoint.py --model_dir ./tmp/Qwen/7B/ \ + --output_dir ./tllm_checkpoint_1gpu_fp16_int8kv + --dtype float16 \ + --int8_kv_cache -- run -```bash -python3 ../run.py --input_text "你好,请问你叫什么?" \ - --max_output_len=50 \ - --tokenizer_dir ./tmp/Qwen/7B/ \ - --engine_dir=./tmp/Qwen/7B/trt_engines/int8_kv_cache_weight_only/1-gpu +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_sq \ + --output_dir ./engine_outputs \ + --strongly_typed + --gemm_plugin float16 ``` -Test with `../summarize.py`: +[`convert_checkpoint.py`](./convert_checkpoint.py) add new options for the support of INT8 KV cache. -- validate huggingface -```bash -python3 ../summarize.py --test_hf \ - --tokenizer_dir ./tmp/Qwen/7B \ - --model_dir ./tmp/Qwen/7B \ - --max_input_length 2048 \ - --output_len 2048 -``` - -- validate trt-llm -```bash -python3 ../summarize.py --test_trt_llm \ - --tokenizer_dir ./tmp/Qwen/7B \ - --engine_dir ./tmp/Qwen/7B/trt_engines/int8_kv_cache_weight_only/1-gpu \ - --max_input_length 2048 \ - --output_len 2048 -``` - #### SmoothQuant -The smoothquant supports both Qwen v1 and Qwen v2. Unlike the FP16 build where the HF weights are processed and loaded into the TensorRT-LLM directly, the SmoothQuant needs to load INT8 weights which should be pre-processed before building an engine. +The smoothquant supports Qwen models. Unlike the FP16 build where the HF weights are processed and loaded into the TensorRT-LLM directly, the SmoothQuant needs to load INT8 weights which should be pre-processed before building an engine. Example: ```bash -python3 hf_qwen_convert.py -i ./tmp/Qwen/7B -o ./tmp/Qwen/7B/sq0.5/ -sq 0.5 --tensor-parallelism 1 --storage-type float16 +python3 convert_checkpoint.py --model_dir ./tmp/Qwen/7B/ --output_dir ./tllm_checkpoint_1gpu_sq --dtype float16 --smoothquant 0.5 +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_sq \ + --output_dir ./engine_outputs \ + --gemm_plugin float16 ``` -[`build.py`](./build.py) add new options for the support of INT8 inference of SmoothQuant models. +[`convert_checkpoint.py`](./convert_checkpoint.py) add new options for the support of INT8 inference of SmoothQuant models. -`--use_smooth_quant` is the starting point of INT8 inference. By default, it +`--smoothquant` is the starting point of INT8 inference. By default, it will run the model in the _per-tensor_ mode. Then, you can add any combination of `--per-token` and `--per-channel` to get the corresponding behaviors. 
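The three granularities differ only in which axis of a GEMM input the INT8 scaling factors are computed over. The sketch below is illustrative only (plain PyTorch, not the SmoothQuant kernels TensorRT-LLM ships), and the helper name `int8_scales` is made up for the example:

```python
# Conceptual sketch of INT8 scaling granularities; not TensorRT-LLM code.
import torch

def int8_scales(x: torch.Tensor, mode: str) -> torch.Tensor:
    """Scaling factors that map |x| onto the INT8 range [-127, 127] for a
    [num_tokens, hidden_size] activation (or weight) matrix."""
    if mode == "per_tensor":   # one static scale for the whole tensor (default)
        return x.abs().max() / 127.0
    if mode == "per_token":    # one scale per row, chosen at runtime (activations)
        return x.abs().amax(dim=1, keepdim=True) / 127.0
    if mode == "per_channel":  # one static scale per column (weights)
        return x.abs().amax(dim=0, keepdim=True) / 127.0
    raise ValueError(f"unknown mode: {mode}")

x = torch.randn(4, 8)
for mode in ("per_tensor", "per_token", "per_channel"):
    print(mode, int8_scales(x, mode).shape)
```

Per-token and per-channel scales track the data more closely and are usually more accurate than a single per-tensor scale, at the cost of a little extra runtime work.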
@@ -257,99 +181,36 @@ Examples of build invocations: ```bash # Build model for SmoothQuant in the _per_token_ + _per_channel_ mode -python3 build.py --bin_model_dir=./tmp/Qwen/7B/sq0.5/1-gpu/ \ - --use_gpt_attention_plugin float16 \ - --remove_input_padding \ - --enable_context_fmha \ - --use_smooth_quant \ - --per_token \ - --per_channel \ - --model_dir ./tmp/Qwen/7B \ - --output_dir ./tmp/Qwen/7B/trt_engines/sq0.5/1-gpu/ -``` +python3 convert_checkpoint.py --model_dir ./tmp/Qwen/7B/ \ + --output_dir ./tllm_checkpoint_1gpu_sq \ + --dtype float16 \ + --smoothquant 0.5 \ + --per_token \ + --per_channel -- run -```bash -python3 ../run.py --input_text "你好,请问你叫什么?" \ - --max_output_len=50 \ - --tokenizer_dir ./tmp/Qwen/7B/ \ - --engine_dir=./tmp/Qwen/7B/trt_engines/sq0.5/1-gpu/ +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_sq \ + --output_dir ./engine_outputs \ + --gemm_plugin float16 ``` -- summarize -```bash -python ../summarize.py --test_trt_llm \ - --tokenizer_dir ./tmp/Qwen/7B/ \ - --data_type fp16 \ - --engine_dir=./tmp/Qwen/7B/trt_engines/sq0.5/1-gpu/ \ - --max_input_length 2048 \ - --output_len 2048 -``` #### INT4-GPTQ -To run the GPTQ Qwen example, the following steps are required: -1. Install auto-gptq module: +You may find the official GPTQ quantized INT4 weights of Qwen-7B-Chat here: [Qwen-7B-Chat-Int4](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4). And you need to first install auto-gptq: ```bash pip install auto-gptq ``` -2. Download quantized weights, for Qwen-7B-Chat, you can find it [here](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4): +Example of building engine for INT4 GPTQ quantized Qwen model: ```bash -git lfs install -git clone https://huggingface.co/Qwen/Qwen-7B-Chat-Int4 -``` +python3 convert_checkpoint.py --model_dir ./tmp/Qwen-7B-Chat-Int4 \ + --output_dir ./tllm_checkpoint_1gpu_gptq \ + --dtype float16 \ + --use_weight_only \ + --weight_only_precision int4_gptq \ + --per_group \ -3. Build TRT-LLM engine: -```bash -python build.py --model_dir Qwen-7B-Chat-Int4 \ -                --quant_ckpt_path Qwen-7B-Chat-Int4 \ -                --dtype float16 \ -                --remove_input_padding \ -                --use_gpt_attention_plugin float16 \ -                --enable_context_fmha \ -                --use_gemm_plugin float16 \ -                --use_weight_only \ -                --weight_only_precision int4_gptq \ -                --per_group \ -                --world_size 1 \ -                --tp_size 1 \ -                --output_dir ./tmp/Qwen/7B/trt_engines/int4-gptq/1-gpu -``` - -4. Run int4-gptq -```bash -python3 ../run.py --input_text "你好,请问你叫什么?" \ - --max_output_len=50 \ - --tokenizer_dir Qwen-7B-Chat-Int4 \ - --engine_dir=./tmp/Qwen/7B/trt_engines/int4-gptq/1-gpu -``` -``` -...... -Input [Text 0]: "<|im_start|>system -You are a helpful assistant.<|im_end|> -<|im_start|>user -你好,请问你叫什么?<|im_end|> -<|im_start|>assistant -" -Output [Text 0 Beam 0]: "你好,我是通义千问,由阿里云开发。" -``` - -5. 
Summarize -- validate huggingface -```bash -python3 ../summarize.py --test_hf \ - --tokenizer_dir ./tmp/Qwen/7B \ - --model_dir ./tmp/Qwen/7B \ - --max_input_length 2048 \ - --output_len 2048 -``` - -- validate trt-llm -```bash -python3 ../summarize.py --test_trt_llm \ - --tokenizer_dir ./tmp/Qwen/7B \ - --engine_dir ./tmp/Qwen/7B/trt_engines/int4-gptq/1-gpu \ - --max_input_length 2048 \ - --output_len 2048 +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_gptq \ + --output_dir ./tmp/Qwen/7B/trt_engines/int4_GPTQ/1-gpu/ \ + --gemm_plugin float16 ``` #### INT4-AWQ @@ -358,70 +219,27 @@ To run the AWQ Qwen example, the following steps are required: NVIDIA AMMO toolkit is used for AWQ weight quantization. Please see [examples/quantization/README.md](/examples/quantization/README.md#preparation) for AMMO installation instructions. -```bash -python3 ../quantization/quantize.py --model_dir ./tmp/Qwen/7B \ - --dtype float16 \ - --qformat int4_awq \ - --output_dir ./qwen_7b_4bit_gs128_awq.pt \ - --calib_size 32 -``` - -2. TRT-LLM engine: -```bash -python build.py --model_dir ./tmp/Qwen/7B \ - --quant_ckpt_path ./qwen_7b_4bit_gs128_awq.pt \ - --dtype float16 \ - --remove_input_padding \ - --use_gpt_attention_plugin float16 \ - --enable_context_fmha \ - --use_gemm_plugin float16 \ - --use_weight_only \ - --weight_only_precision int4_awq \ - --per_group \ - --world_size 1 \ - --tp_size 1 \ - --output_dir ./tmp/Qwen/7B/trt_engines/int4-awq/1-gpu -``` -3. Run int4-awq -```bash -python3 ../run.py --input_text "你好,请问你叫什么?" \ - --max_output_len=50 \ - --tokenizer_dir ./tmp/Qwen/7B/ \ - --engine_dir=./tmp/Qwen/7B/trt_engines/int4-awq/1-gpu -``` -``` -...... -Input [Text 0]: "<|im_start|>system -You are a helpful assistant.<|im_end|> -<|im_start|>user -你好,请问你叫什么?<|im_end|> -<|im_start|>assistant -" -Output [Text 0 Beam 0]: "你好,我叫通义千问,是由阿里云开发的AI助手。有什么我可以帮助你的吗?" -``` + ```bash + # Quantize Qwen-7B-Chat checkpoint into INT4 AWQ format + python ../quantization/quantize.py --model_dir ./tmp/Qwen/7B/ \ + --dtype float16 \ + --qformat int4_awq \ + --awq_block_size 128 \ + --output_dir ./quantized_int4-awq \ + --calib_size 32 + ``` -4. Summarize -- validate huggingface -```bash -python3 ../summarize.py --test_hf \ - --tokenizer_dir ./tmp/Qwen/7B \ - --model_dir ./tmp/Qwen/7B \ - --max_input_length 2048 \ - --output_len 2048 -``` +2. Build TRT-LLM engine: -- validate trt-llm -```bash -python3 ../summarize.py --test_trt_llm \ - --tokenizer_dir ./tmp/Qwen/7B \ - --engine_dir ./tmp/Qwen/7B/trt_engines/int4-awq/1-gpu \ - --max_input_length 2048 \ - --output_len 2048 -``` + ```bash + trtllm-build --checkpoint_dir ./quantized_int4-awq \ + --output_dir ./tmp/qwen/7B/trt_engines/int4_AWQ/1-gpu/ \ + --gemm_plugin float16 + ``` ### Run -To run a TensorRT-LLM Qwen model using the engines generated by build.py +To run a TensorRT-LLM Qwen model using the engines generated by `trtllm-build` ```bash # With fp16 inference @@ -441,40 +259,77 @@ python3 ../run.py --input_text "你好,请问你叫什么?" \ --max_output_len=50 \ --tokenizer_dir ./tmp/Qwen/7B/ \ --engine_dir=./tmp/Qwen/7B/trt_engines/int8_weight_only/1-gpu/ +``` +``` +Input [Text 0]: "<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +你好,请问你叫什么?<|im_end|> +<|im_start|>assistant +" +Output [Text 0 Beam 0]: "你好,我是来自阿里云的大规模语言模型,我叫通义千问。<|im_end|> +<|im_start|> +<|im_start|> + +" +``` +```bash # With int4 weight only inference python3 ../run.py --input_text "你好,请问你叫什么?" 
\ --max_output_len=50 \ --tokenizer_dir ./tmp/Qwen/7B/ \ --engine_dir=./tmp/Qwen/7B/trt_engines/int4_weight_only/1-gpu/ - -# Run 72B model with 8-gpu -mpirun -n 8 --allow-run-as-root \ - python ../run.py --input_text "What is your name?" \ - --max_output_len=50 \ - --tokenizer_dir ./tmp/Qwen/72B/ \ - --engine_dir=./tmp/Qwen/72B/trt_engines/fp16/8-gpu/ +``` +``` +Input [Text 0]: "<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +你好,请问你叫什么?<|im_end|> +<|im_start|>assistant +" +Output [Text 0 Beam 0]: "我叫通义千问,是由阿里云开发的预训练语言模型。<|im_end|> +" ``` -**Demo output of run.py:** ```bash +# With INT4 GPTQ quantization python3 ../run.py --input_text "你好,请问你叫什么?" \ --max_output_len=50 \ - --tokenizer_dir /llm-models/Qwen-7B-Chat/ \ - --engine_dir /engine_qwen + --tokenizer_dir ./tmp/Qwen-7B-Chat-Int4 \ + --engine_dir=./tmp/Qwen/7B/trt_engines/int4_GPTQ/1-gpu/ +``` +``` +Input [Text 0]: "<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +你好,请问你叫什么?<|im_end|> +<|im_start|>assistant +" +Output [Text 0 Beam 0]: "你好,我是通义千问,由阿里云开发。<|im_end|> +" ``` +```bash +# With INT4 AWQ quantization +python3 ../run.py --input_text "你好,请问你叫什么?" \ + --max_output_len=50 \ + --tokenizer_dir ./tmp/Qwen/7B/ \ + --engine_dir=./tmp/Qwen/7B/trt_engines/int4_AWQ/1-gpu/ ``` -Loading engine from /engine_qwen/qwen_float16_tp1_rank0.engine -Input: "<|im_start|>system +``` +Input [Text 0]: "<|im_start|>system You are a helpful assistant.<|im_end|> <|im_start|>user 你好,请问你叫什么?<|im_end|> <|im_start|>assistant " -Output: "我是来自阿里云的大规模语言模型,我叫通义千问。" +Output [Text 0 Beam 0]: "你好,我是通义千问,由阿里云开发。<|im_end|> +" ``` + ```bash +# Run 72B model with 8-gpu mpirun -n 8 --allow-run-as-root \ python ../run.py --input_text "What is your name?" \ --max_output_len=50 \ @@ -490,12 +345,13 @@ What is your name?<|im_end|> " Output [Text 0 Beam 0]: "I am QianWen, a large language model created by Alibaba Cloud." ``` + ### Summarization using the Qwen model ```bash # Run summarization using the Qwen 7B model in FP16. python ../summarize.py --test_trt_llm \ - --tokenizer_dir ./tmp/Qwen/7B/ \ + --hf_model_dir ./tmp/Qwen/7B/ \ --data_type fp16 \ --engine_dir ./tmp/Qwen/7B/trt_engines/fp16/1-gpu/ \ --max_input_length 2048 \ @@ -503,7 +359,7 @@ python ../summarize.py --test_trt_llm \ # Run summarization using the Qwen 7B model in BF16. python ../summarize.py --test_trt_llm \ - --tokenizer_dir ./tmp/Qwen/7B/ \ + --hf_model_dir ./tmp/Qwen/7B/ \ --data_type fp16 \ --engine_dir ./tmp/Qwen/7B/trt_engines/bf16/1-gpu/ \ --max_input_length 2048 \ @@ -511,7 +367,7 @@ python ../summarize.py --test_trt_llm \ # Run summarization using the Qwen 7B model quantized to INT8. python ../summarize.py --test_trt_llm \ - --tokenizer_dir ./tmp/Qwen/7B/ \ + --hf_model_dir ./tmp/Qwen/7B/ \ --data_type fp16 \ --engine_dir ./tmp/Qwen/7B/trt_engines/int8_weight_only/1-gpu/ \ --max_input_length 2048 \ @@ -519,7 +375,7 @@ python ../summarize.py --test_trt_llm \ # Run summarization using the Qwen 7B model quantized to INT4. python ../summarize.py --test_trt_llm \ - --tokenizer_dir ./tmp/Qwen/7B/ \ + --hf_model_dir ./tmp/Qwen/7B/ \ --data_type fp16 \ --engine_dir ./tmp/Qwen/7B/trt_engines/int4_weight_only/1-gpu/ \ --max_input_length 2048 \ @@ -528,7 +384,7 @@ python ../summarize.py --test_trt_llm \ # Run summarization using the Qwen 7B model in FP16 using two GPUs. 
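# Note: the number of MPI ranks (-n 2 here) must match the tp_size x pp_size the engine was built with.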
mpirun -n 2 --allow-run-as-root \ python ../summarize.py --test_trt_llm \ - --tokenizer_dir ./tmp/Qwen/7B/ \ + --hf_model_dir ./tmp/Qwen/7B/ \ --data_type fp16 \ --engine_dir ./tmp/Qwen/7B/trt_engines/fp16/2-gpu/ \ --max_input_length 2048 \ @@ -537,15 +393,20 @@ mpirun -n 2 --allow-run-as-root \ # Run summarization using the Qwen 14B model in FP16 using two GPUs. mpirun -n 2 --allow-run-as-root \ python ../summarize.py --test_trt_llm \ - --tokenizer_dir ./tmp/Qwen/14B/ \ + --hf_model_dir ./tmp/Qwen/14B/ \ --data_type fp16 \ --engine_dir ./tmp/Qwen/14B/trt_engines/fp16/2-gpu/ \ --max_input_length 2048 \ --output_len 2048 ``` **Demo output of summarize.py:** -```python -python3 ../summarize.py --test_trt_llm --tokenizer_dir /llm-models/Qwen-7B-Chat/ --engine_dir /engine_qwen --max_input_length 2048 --output_len 2048 +```bash +python ../summarize.py --test_trt_llm \ + --hf_model_dir ./tmp/Qwen/7B/ \ + --data_type fp16 \ + --engine_dir ./tmp/Qwen/7B/trt_engines/fp16/1-gpu/ \ + --max_input_length 2048 \ + --output_len 2048 ``` ``` [11/09/2023-02:21:10] [TRT-LLM] [I] Load tokenizer takes: 0.4043385982513428 sec diff --git a/examples/qwen/build.py b/examples/qwen/build.py deleted file mode 100644 index 442676d00..000000000 --- a/examples/qwen/build.py +++ /dev/null @@ -1,808 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import argparse -import math -import os -import time - -# isort: off -import torch -import torch.multiprocessing as mp -import tensorrt as trt -# isort: on -from transformers import AutoConfig, AutoModelForCausalLM -from weight import (load_from_awq_qwen, load_from_binary, load_from_gptq_qwen, - load_from_hf_qwen) - -import tensorrt_llm -from tensorrt_llm import profiler -from tensorrt_llm._common import check_max_num_tokens -from tensorrt_llm._utils import str_dtype_to_trt -from tensorrt_llm.builder import Builder -from tensorrt_llm.layers.attention import PositionEmbeddingType -from tensorrt_llm.logger import logger -from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models import quantize_model -from tensorrt_llm.network import net_guard -from tensorrt_llm.plugin.plugin import ContextFMHAType -from tensorrt_llm.quantization import QuantMode - -MODEL_NAME = "qwen" - -import onnx -from onnx import TensorProto, helper - -now_dir = os.path.dirname(os.path.abspath(__file__)) - - -def trt_dtype_to_onnx(dtype): - if dtype == trt.float16: - return TensorProto.DataType.FLOAT16 - elif dtype == trt.float32: - return TensorProto.DataType.FLOAT - elif dtype == trt.int32: - return TensorProto.DataType.INT32 - else: - raise TypeError("%s is not supported" % dtype) - - -def to_onnx(network, path): - inputs = [] - for i in range(network.num_inputs): - network_input = network.get_input(i) - inputs.append( - helper.make_tensor_value_info( - network_input.name, trt_dtype_to_onnx(network_input.dtype), - list(network_input.shape))) - - outputs = [] - for i in range(network.num_outputs): - network_output = network.get_output(i) - outputs.append( - helper.make_tensor_value_info( - network_output.name, trt_dtype_to_onnx(network_output.dtype), - list(network_output.shape))) - - nodes = [] - for i in range(network.num_layers): - layer = network.get_layer(i) - layer_inputs = [] - for j in range(layer.num_inputs): - ipt = layer.get_input(j) - if ipt is not None: - layer_inputs.append(layer.get_input(j).name) - layer_outputs = [ - layer.get_output(j).name for j in range(layer.num_outputs) - ] - nodes.append( - helper.make_node(str(layer.type), - name=layer.name, - inputs=layer_inputs, - outputs=layer_outputs, - domain="com.nvidia")) - - onnx_model = helper.make_model(helper.make_graph(nodes, - 'attention', - inputs, - outputs, - initializer=None), - producer_name='NVIDIA') - onnx.save(onnx_model, path) - - -def get_engine_name(model, dtype, tp_size, pp_size, rank): - if pp_size == 1: - return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank) - return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size, - pp_size, rank) - - -def serialize_engine(engine, path): - logger.info(f'Serializing engine to {path}...') - tik = time.time() - with open(path, 'wb') as f: - f.write(engine) - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - logger.info(f'Engine serialized. 
Total time: {t}') - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument('--world_size', type=int, default=1) - parser.add_argument('--tp_size', type=int, default=1) - parser.add_argument('--pp_size', type=int, default=1) - parser.add_argument('--model_dir', type=str, default=None) - parser.add_argument('--bin_model_dir', type=str, default=None) - parser.add_argument("--quant_ckpt_path", type=str, default=None) - parser.add_argument('--dtype', - type=str, - default='float16', - choices=['float32', 'bfloat16', 'float16']) - parser.add_argument( - '--timing_cache', - type=str, - default='model.cache', - help= - 'The path of to read timing cache from, will be ignored if the file does not exist' - ) - parser.add_argument( - '--profiling_verbosity', - type=str, - default='layer_names_only', - choices=['layer_names_only', 'detailed', 'none'], - help= - 'The profiling verbosity for the generated TRT engine. Set to detailed can inspect tactic choices and kernel parameters.' - ) - parser.add_argument('--log_level', - type=str, - default='info', - choices=[ - 'internal_error', - 'error', - 'warning', - 'info', - 'verbose', - ]) - parser.add_argument('--vocab_size', type=int, default=32000) - parser.add_argument('--n_layer', type=int, default=32) - parser.add_argument('--n_positions', type=int, default=2048) - parser.add_argument('--n_embd', type=int, default=4096) - parser.add_argument('--n_head', type=int, default=32) - parser.add_argument('--n_kv_head', type=int, default=None) - parser.add_argument('--inter_size', type=int, default=11008) - parser.add_argument('--hidden_act', type=str, default='silu') - parser.add_argument('--rms_norm_eps', type=float, default=1e-06) - parser.add_argument('--max_batch_size', type=int, default=2) - parser.add_argument('--max_input_len', type=int, default=2048) - parser.add_argument('--max_output_len', type=int, default=2048) - parser.add_argument('--max_beam_width', type=int, default=1) - parser.add_argument('--rotary_base', type=float, default=10000.0) - parser.add_argument('--rotary_scaling', nargs=2, type=str, default=None) - parser.add_argument('--use_gpt_attention_plugin', - nargs='?', - const='float16', - type=str, - default=False, - choices=['float16', 'bfloat16', 'float32']) - parser.add_argument('--use_gemm_plugin', - nargs='?', - const='float16', - type=str, - default=False, - choices=['float16', 'bfloat16', 'float32']) - parser.add_argument('--parallel_build', default=False, action='store_true') - parser.add_argument('--enable_context_fmha', - default=False, - action='store_true') - parser.add_argument('--enable_context_fmha_fp32_acc', - default=False, - action='store_true') - parser.add_argument( - '--use_paged_context_fmha', - action='store_true', - help= - 'Activates paged context FMHA. This mode of the context FMHA is required for chunked context, speculative decoding and reuse of KV cache blocks. Context FMHA performance is worse when this mode is on.' - ) - parser.add_argument( - '--multi_block_mode', - default=False, - action='store_true', - help= - 'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \ - It is beneficial when batch x num_heads cannot fully utilize GPU.' - ) - parser.add_argument( - '--disable_xqa', - default=False, - action='store_true', - help= - 'Disable XQA optimization for the generation MHA. See more details in docs/gpt_attention.' 
- ) - parser.add_argument('--visualize', default=False, action='store_true') - parser.add_argument('--enable_debug_output', - default=False, - action='store_true') - parser.add_argument('--gpus_per_node', type=int, default=8) - parser.add_argument('--builder_opt', type=int, default=None) - parser.add_argument( - '--output_dir', - type=str, - default='engine_outputs', - help= - 'The path to save the serialized engine files, timing cache file and model configs' - ) - parser.add_argument('--remove_input_padding', - default=False, - action='store_true') - parser.add_argument( - '--use_fused_mlp', - default=False, - action='store_true', - help= - 'Enable horizontal fusion in GatedMLP, reduces layer input traffic and potentially improves performance. ' - 'For FP8 PTQ, the downside is slight reduction of accuracy because one of the quantization scaling factors are discarded ' - '(0.45734 vs 0.45755 for LLaMA-v2 7B using ammo/examples/hf/instruct_eval/mmlu.py).' - ) - - # Arguments related to the quantization of the model. - parser.add_argument( - '--use_smooth_quant', - default=False, - action="store_true", - help= - 'Use the SmoothQuant method to quantize activations and weights for the various GEMMs.' - 'See --per_channel and --per_token for finer-grained quantization options.' - ) - parser.add_argument( - '--per_channel', - default=False, - action="store_true", - help= - 'By default, we use a single static scaling factor for the GEMM\'s result. ' - 'per_channel instead uses a different static scaling factor for each channel. ' - 'The latter is usually more accurate, but a little slower.') - parser.add_argument( - '--per_token', - default=False, - action="store_true", - help= - 'By default, we use a single static scaling factor to scale activations in the int8 range. ' - 'per_token chooses at run time, and for each token, a custom scaling factor. ' - 'The latter is usually more accurate, but a little slower.') - - parser.add_argument( - '--per_group', - default=False, - action="store_true", - help= - 'By default, we use a single static scaling factor to scale weights in the int4 range. ' - 'per_group chooses at run time, and for each group, a custom scaling factor. ' - 'The flag is built for GPTQ/AWQ quantization.') - parser.add_argument( - "--group_size", - type=int, - default=128, - help="group size used in gptq/awq quantization.", - ) - parser.add_argument( - '--int8_kv_cache', - default=False, - action="store_true", - help= - 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' - ) - parser.add_argument( - '--use_parallel_embedding', - action="store_true", - default=False, - help= - 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' - ) - parser.add_argument( - '--embedding_sharding_dim', - type=int, - default=1, # Meta does TP on hidden dim - choices=[0, 1], - help= - 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' - 'To shard it along hidden dimension, set embedding_sharding_dim=1' - 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' - ) - parser.add_argument( - '--use_weight_only', - default=False, - action="store_true", - help='Quantize weights for the various GEMMs to INT4/INT8.' - 'See --weight_only_precision to set the precision') - parser.add_argument( - '--disable_weight_only_quant_plugin', - default=False, - action="store_true", - help= - 'By default, using plugin implementation for weight quantization. 
Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.' - 'You must also use --use_weight_only for that argument to have an impact.' - ) - parser.add_argument( - '--weight_only_precision', - const='int8', - type=str, - nargs='?', - default='int8', - choices=['int8', 'int4', 'int4_awq', 'int4_gptq'], - help= - 'Define the precision for the weights when using weight-only quantization.' - 'You must also use --use_weight_only for that argument to have an impact.' - ) - parser.add_argument( - '--quantize_lm_head', - default=False, - action="store_true", - help='Quantize lm_head weights as well when using int4_awq.') - parser.add_argument( - '--use_inflight_batching', - action="store_true", - default=False, - help="Activates inflight batching mode of gptAttentionPlugin.") - parser.add_argument( - '--paged_kv_cache', - action="store_true", - default=False, - help= - 'By default we use contiguous KV cache. By setting this flag you enable paged KV cache' - ) - parser.add_argument('--tokens_per_block', - type=int, - default=128, - help='Number of tokens per block in paged KV cache') - - parser.add_argument( - '--max_num_tokens', - type=int, - default=None, - help= - 'Define the max number of tokens supported by the engine, note that it takes no effect if --remove_input_padding is not set' - ) - parser.add_argument( - '--strongly_typed', - default=False, - action="store_true", - help= - 'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.' - ) - parser.add_argument( - '--use_custom_all_reduce', - action='store_true', - help= - 'Activates latency-optimized algorithm for all-reduce instead of NCCL.') - parser.add_argument( - '--max_prompt_embedding_table_size', - type=int, - default=0, - help='Setting to a value > 0 enables support for prompt tuning.') - parser.add_argument( - '--gather_all_token_logits', - action='store_true', - default=False, - help='Enable both gather_context_logits and gather_generation_logits') - parser.add_argument('--gather_context_logits', - action='store_true', - default=False, - help='Gather context logits') - parser.add_argument('--gather_generation_logits', - action='store_true', - default=False, - help='Gather generation logits') - parser.add_argument( - '--use_lookup_plugin', - nargs='?', - const=None, - default=False, - choices=['float16', 'float32', 'bfloat16'], - help="Activates the lookup plugin which enables embedding sharing.") - - args = parser.parse_args() - logger.set_level(args.log_level) - - assert not ( - args.use_smooth_quant and args.use_weight_only - ), "You cannot enable both SmoothQuant and INT8 weight-only together." - - if not args.remove_input_padding: - if args.use_gpt_attention_plugin: - logger.warning( - f"It is recommended to specify --remove_input_padding when using GPT attention plugin" - ) - - if args.use_inflight_batching: - if not args.use_gpt_attention_plugin: - args.use_gpt_attention_plugin = 'float16' - logger.info( - f"Using GPT attention plugin for inflight batching mode. 
Setting to default '{args.use_gpt_attention_plugin}'" - ) - if not args.remove_input_padding: - args.remove_input_padding = True - logger.info( - "Using remove input padding for inflight batching mode.") - if not args.paged_kv_cache: - args.paged_kv_cache = True - logger.info("Using paged KV cache for inflight batching mode.") - - if args.use_smooth_quant: - args.quant_mode = QuantMode.use_smooth_quant(args.per_token, - args.per_channel) - elif args.use_weight_only: - args.quant_mode = QuantMode.from_description( - quantize_weights=True, - quantize_activations=False, - per_token=False, - per_channel=False, - per_group=args.per_group, - use_int4_weights="int4" in args.weight_only_precision) - else: - args.quant_mode = QuantMode(0) - - if args.int8_kv_cache: - args.quant_mode = args.quant_mode.set_int8_kv_cache() - - if args.rotary_scaling is not None: - assert args.use_gpt_attention_plugin, "RoPE scaling is only supported through GPT attention plugin." - rotary_scaling = { - "type": args.rotary_scaling[0], - "factor": float(args.rotary_scaling[1]) - } - assert rotary_scaling["type"] in ["linear", "dynamic"] - assert rotary_scaling["factor"] > 1.0 - args.rotary_scaling = rotary_scaling - - if args.model_dir is not None: - hf_config = AutoConfig.from_pretrained( - args.model_dir, - trust_remote_code=True, - ) - args.inter_size = hf_config.intermediate_size # override the inter_size for QWen - args.n_embd = hf_config.hidden_size - args.n_head = hf_config.num_attention_heads - if hasattr(hf_config, "num_key_value_heads"): - args.n_kv_head = hf_config.num_key_value_heads - args.n_layer = hf_config.num_hidden_layers - args.n_positions = hf_config.max_position_embeddings - args.vocab_size = hf_config.vocab_size - args.hidden_act = "silu" - args.rms_norm_eps = hf_config.layer_norm_epsilon - args.kv_channels = hf_config.kv_channels - args.rotary_base = hf_config.rotary_emb_base - if args.n_kv_head is None: - args.n_kv_head = args.n_head - if args.n_kv_head != args.n_head: - assert (args.n_head % args.n_kv_head) == 0, \ - "MQA/GQA requires the number of heads to be divisible by the number of K/V heads." - assert (args.n_kv_head % args.tp_size) == 0 or (args.tp_size % args.n_kv_head) == 0, \ - "MQA/GQA requires either the number of K/V heads to be divisible by the tensor parallelism size OR " \ - "the tensor parallelism size to be divisible by the number of K/V heads." 
- - assert args.pp_size * args.tp_size == args.world_size - - if args.weight_only_precision == 'int4_awq': - inter_alignment = args.tp_size * 128 - if args.inter_size % inter_alignment != 0: - args.inter_size = int((args.inter_size + inter_alignment - 1) / - inter_alignment) * inter_alignment - logger.info("To use awq we pad intermediate_size to {}.".format( - args.inter_size)) - - if args.quantize_lm_head: - vocab_alignment = args.tp_size * 64 - if args.vocab_size % vocab_alignment != 0: - args.vocab_size = int((args.vocab_size + vocab_alignment - 1) / - vocab_alignment) * vocab_alignment - logger.info("To use awq we pad vocab_size to {}.".format( - args.vocab_size)) - - args.max_num_tokens = check_max_num_tokens( - max_num_tokens=args.max_num_tokens, - max_batch_size=args.max_batch_size, - max_input_len=args.max_input_len, - remove_input_padding=args.remove_input_padding, - enable_context_fmha=args.enable_context_fmha, - tokens_per_block=args.tokens_per_block) - - assert (math.log2(args.tokens_per_block).is_integer() - ), "tokens_per_block must be power of 2" - if args.enable_context_fmha or args.enable_context_fmha_fp32_acc: - assert (args.tokens_per_block >= - 128), "Context fMHA requires >= 128 tokens per block" - if args.gather_all_token_logits: - args.gather_context_logits = True - args.gather_generation_logits = True - return args - - -def get_model_object(args, mapping, trt_dtype=None): - if trt_dtype is None: - trt_dtype = str_dtype_to_trt(args.dtype) - # Initialize Module - tensorrt_llm_qwen = tensorrt_llm.models.QWenForCausalLM( - num_layers=args.n_layer, - num_heads=args.n_head, - num_kv_heads=args.n_kv_head, - hidden_size=args.n_embd, - seq_length=args.max_input_len, - vocab_size=args.vocab_size, - hidden_act=args.hidden_act, - max_position_embeddings=args.n_positions, - dtype=trt_dtype, - mlp_hidden_size=args.inter_size, - position_embedding_type=PositionEmbeddingType.rope_gpt_neox, - mapping=mapping, - rotary_base=args.rotary_base, - rotary_scaling=args.rotary_scaling, - use_parallel_embedding=args.use_parallel_embedding, - embedding_sharding_dim=args.embedding_sharding_dim, - quant_mode=args.quant_mode, - rms_norm_eps=args.rms_norm_eps, - use_fused_mlp=args.use_fused_mlp, - use_prompt_tuning=args.max_prompt_embedding_table_size > 0) - quantize_kwargs = {} - if args.use_smooth_quant or args.use_weight_only: - if args.weight_only_precision == 'int4_awq': - exclude_modules = ['lm_head'] if not args.quantize_lm_head else [] - quantize_kwargs = { - "group_size": args.group_size, - "zero": False, - "pre_quant_scale": True, - "exclude_modules": exclude_modules, - } - elif args.weight_only_precision == 'int4_gptq': - quantize_kwargs = { - "group_size": args.group_size, - "zero": True, - "pre_quant_scale": False, - } - tensorrt_llm_qwen = quantize_model(tensorrt_llm_qwen, args.quant_mode, - **quantize_kwargs) - if args.per_group: - if args.weight_only_precision == 'int4_awq': - load_from_awq_qwen(tensorrt_llm_qwen=tensorrt_llm_qwen, - quant_ckpt_path=args.quant_ckpt_path, - quantize_lm_head=args.quantize_lm_head, - mapping=mapping, - dtype=args.dtype) - else: - load_from_gptq_qwen(tensorrt_llm_qwen=tensorrt_llm_qwen, - quant_ckpt_path=args.quant_ckpt_path, - mapping=mapping, - dtype=args.dtype) - elif args.model_dir is not None and \ - (args.bin_model_dir is None or not os.path.exists(args.bin_model_dir)): - logger.info(f'Loading HF QWen ... 
from {args.model_dir}') - tik = time.time() - hf_qwen = AutoModelForCausalLM.from_pretrained( - args.model_dir, - device_map={ - "transformer": "cpu", - "lm_head": "cpu" - }, # Load to CPU memory - torch_dtype="auto", - trust_remote_code=True, - ) - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - logger.info(f'HF QWen loaded. Total time: {t}') - load_from_hf_qwen(tensorrt_llm_qwen, - hf_qwen, - mapping, - dtype=args.dtype, - multi_query_mode=(args.n_kv_head != args.n_head)) - del hf_qwen - elif args.bin_model_dir is not None: - logger.info(f'Loading QWen ... from {args.bin_model_dir}') - load_from_binary(tensorrt_llm_qwen, - args.bin_model_dir, - mapping, - dtype=args.dtype, - multi_query_mode=(args.n_kv_head != args.n_head)) - else: - raise ValueError( - "You must specify either --model_dir or --bin_model_dir") - - return tensorrt_llm_qwen - - -def update_plugin_configs(args, network): - network.plugin_config.to_legacy_setting() - if args.use_gpt_attention_plugin: - network.plugin_config.set_gpt_attention_plugin( - dtype=args.use_gpt_attention_plugin) - if args.use_gemm_plugin: - network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin) - # Quantization plugins. - if args.use_smooth_quant: - network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype) - network.plugin_config.set_rmsnorm_quantization_plugin(dtype=args.dtype) - network.plugin_config.set_quantize_tensor_plugin() - network.plugin_config.set_quantize_per_token_plugin() - assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc) - if args.enable_context_fmha: - network.plugin_config.set_context_fmha(ContextFMHAType.enabled) - if args.enable_context_fmha_fp32_acc: - network.plugin_config.set_context_fmha( - ContextFMHAType.enabled_with_fp32_acc) - if args.multi_block_mode: - network.plugin_config.enable_mmha_multi_block_mode() - if not args.disable_xqa: - network.plugin_config.enable_xqa_optimization() - - if args.use_weight_only and not args.disable_weight_only_quant_plugin: - if args.per_group: - network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin( - dtype=args.dtype) - else: - network.plugin_config.set_weight_only_quant_matmul_plugin( - dtype=args.dtype) - if args.world_size > 1: - network.plugin_config.set_nccl_plugin(args.dtype, - args.use_custom_all_reduce) - if args.remove_input_padding: - network.plugin_config.enable_remove_input_padding() - if args.paged_kv_cache: - network.plugin_config.enable_paged_kv_cache(args.tokens_per_block) - if args.use_lookup_plugin: - network.plugin_config.set_lookup_plugin(dtype=args.dtype) - if args.use_paged_context_fmha: - assert args.enable_context_fmha or args.enable_context_fmha_fp32_acc, "context fmha must be enabled" - network.plugin_config.set_paged_context_fmha() - return - - -def build_rank_engine(builder: Builder, - builder_config: tensorrt_llm.builder.BuilderConfig, - engine_name, rank, args): - ''' - @brief: Build the engine on the given rank. - @param rank: The rank to build the engine. - @param args: The cmd line arguments. - @return: The built engine. 
- ''' - dtype = str_dtype_to_trt(args.dtype) - mapping = Mapping(world_size=args.world_size, - rank=rank, - tp_size=args.tp_size, - pp_size=args.pp_size) - - assert args.n_layer % args.pp_size == 0, \ - f"num_layers {args.n_layer} must be a multiple of pipeline parallelism size {args.pp_size}" - - profiler.print_memory_usage(f'Rank {rank} Engine build starts') - # Initialize Module - tensorrt_llm_qwen = get_model_object(args, mapping=mapping, trt_dtype=dtype) - profiler.print_memory_usage(f'Rank {rank} model weight loaded.') - - # Module -> Network - network = builder.create_network() - network.trt_network.name = engine_name - update_plugin_configs(args, network) - - with net_guard(network): - # Prepare - network.set_named_parameters(tensorrt_llm_qwen.named_parameters()) - - # Forward - inputs = tensorrt_llm_qwen.prepare_inputs( - max_batch_size=args.max_batch_size, - max_input_len=args.max_input_len, - max_seq_len=args.max_input_len + args.max_output_len, - use_cache=True, - max_beam_width=args.max_beam_width, - max_num_tokens=args.max_num_tokens, - prompt_embedding_table_size=args.max_prompt_embedding_table_size, - gather_context_logits=args.gather_context_logits, - gather_generation_logits=args.gather_generation_logits) - tensorrt_llm_qwen(*inputs) - if args.enable_debug_output: - # mark intermediate nodes' outputs - for k, v in tensorrt_llm_qwen.named_network_outputs(): - v = v.trt_tensor - v.name = k - network.trt_network.mark_output(v) - v.dtype = kv_dtype - if args.visualize: - model_path = os.path.join(args.output_dir, 'test.onnx') - to_onnx(network.trt_network, model_path) - - tensorrt_llm.graph_rewriting.optimize(network) - - engine = None - - # Network -> Engine - engine = builder.build_engine(network, builder_config) - if rank == 0: - config_path = os.path.join(args.output_dir, 'config.json') - builder.save_config(builder_config, config_path) - return engine - - -def get_builder_config_namespace(args, cache): - # NOTE: int8 flag is required to be true when INT8 tensors are exposed to TRT - # TRT-LLM has INT8 I/O when act/weights are quantized without group-scaling (AWQ, GPTQ) - # OR INT8 KV cache is set to contiguous (without paged KV cache enabled). 
- int8_trt_flag = (args.quant_mode.has_act_or_weight_quant() - and not args.quant_mode.has_per_group_scaling()) or ( - not args.paged_kv_cache - and args.quant_mode.has_int8_kv_cache()) - config = argparse.Namespace( - name=MODEL_NAME, - precision=args.dtype, - timing_cache=args.timing_cache if cache is None else cache, - profiling_verbosity=args.profiling_verbosity, - tensor_parallel=args.tp_size, - pipeline_parallel=args.pp_size, - parallel_build=args.parallel_build, - num_layers=args.n_layer, - num_heads=args.n_head, - num_kv_heads=args.n_kv_head, - hidden_size=args.n_embd, - vocab_size=args.vocab_size, - hidden_act=args.hidden_act, - max_position_embeddings=args.n_positions, - max_batch_size=args.max_batch_size, - max_beam_width=args.max_beam_width, - max_input_len=args.max_input_len, - max_output_len=args.max_output_len, - max_num_tokens=args.max_num_tokens, - int8=int8_trt_flag, - quant_mode=args.quant_mode, - strongly_typed=args.strongly_typed, - opt_level=args.builder_opt, - max_prompt_embedding_table_size=args.max_prompt_embedding_table_size, - gather_context_logits=args.gather_context_logits, - gather_generation_logits=args.gather_generation_logits, - mlp_hidden_size=args.inter_size, - ) - return config - - -def build(rank, args): - torch.cuda.set_device(rank % args.gpus_per_node) - logger.set_level(args.log_level) - os.makedirs(args.output_dir, exist_ok=True) - - # when doing serializing build, all ranks share one engine - builder = Builder() - - cache = None - for cur_rank in range(args.world_size): - # skip other ranks if parallel_build is enabled - if args.parallel_build and cur_rank != rank: - continue - builder_config = builder.create_builder_config( - **vars(get_builder_config_namespace(args, cache))) - engine_name = get_engine_name(MODEL_NAME, args.dtype, args.tp_size, - args.pp_size, cur_rank) - engine = build_rank_engine(builder, builder_config, engine_name, - cur_rank, args) - assert engine is not None, f'Failed to build engine for rank {cur_rank}' - - if cur_rank == 0: - # Use in-memory timing cache for multiple builder passes. - if not args.parallel_build: - cache = builder_config.trt_builder_config.get_timing_cache() - - serialize_engine(engine, os.path.join(args.output_dir, engine_name)) - - if rank == 0: - ok = builder.save_timing_cache( - builder_config, os.path.join(args.output_dir, "model.cache")) - assert ok, "Failed to save timing cache." - - -if __name__ == '__main__': - args = parse_arguments() - logger.set_level(args.log_level) - tik = time.time() - if args.parallel_build and args.world_size > 1 and \ - torch.cuda.device_count() >= args.world_size: - logger.warning( - f'Parallel build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.' 
- ) - mp.spawn(build, nprocs=args.world_size, args=(args, )) - else: - args.parallel_build = False - logger.info('Serially build TensorRT engines.') - build(0, args) - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - logger.info(f'Total time of building all {args.world_size} engines: {t}') diff --git a/examples/qwen/convert_checkpoint.py b/examples/qwen/convert_checkpoint.py new file mode 100644 index 000000000..444feb6aa --- /dev/null +++ b/examples/qwen/convert_checkpoint.py @@ -0,0 +1,1421 @@ +import argparse +import functools +import json +import os +import time +import traceback +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List + +import numpy as np +import safetensors +import torch +import torch.nn as nn +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers.pytorch_utils import Conv1D + +import tensorrt_llm +from tensorrt_llm._utils import pad_vocab_size, str_dtype_to_torch +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_dir', type=str, default=None) + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + parser.add_argument('--dtype', + type=str, + default='float16', + choices=['float32', 'bfloat16', 'float16']) + parser.add_argument('--vocab_size', type=int, default=32000) + parser.add_argument('--n_positions', type=int, default=2048) + parser.add_argument('--n_layer', type=int, default=32) + parser.add_argument('--n_head', type=int, default=32) + parser.add_argument('--n_kv_head', type=int, default=None) + parser.add_argument('--n_embd', type=int, default=4096) + parser.add_argument('--inter_size', type=int, default=22016) + parser.add_argument('--rms_norm_eps', type=float, default=1e-06) + + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' + 'See --weight_only_precision to set the precision') + parser.add_argument( + '--disable_weight_only_quant_plugin', + default=False, + action="store_true", + help= + 'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8', 'int4', 'int4_gptq'], + help= + 'Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument( + "--smoothquant", + "-sq", + type=float, + default=None, + help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)" + " to Smoothquant the model, and output int8 weights." + " A good first try is 0.5. Must be in [0, 1]") + parser.add_argument( + '--per_channel', + action="store_true", + default=False, + help= + 'By default, we use a single static scaling factor for the GEMM\'s result. ' + 'per_channel instead uses a different static scaling factor for each channel. 
' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + '--per_token', + action="store_true", + default=False, + help= + 'By default, we use a single static scaling factor to scale activations in the int8 range. ' + 'per_token chooses at run time, and for each token, a custom scaling factor. ' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + '--int8_kv_cache', + default=False, + action="store_true", + help= + 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' + ) + parser.add_argument( + '--ammo_quant_ckpt_path', + type=str, + default=None, + help='Path of a quantized model checkpoint in .npz format') + + parser.add_argument( + '--per_group', + default=False, + action="store_true", + help= + 'By default, we use a single static scaling factor to scale weights in the int4 range. ' + 'per_group chooses at run time, and for each group, a custom scaling factor. ' + 'The flag is built for GPTQ/AWQ quantization.') + + parser.add_argument('--hidden_act', type=str, default='silu') + + parser.add_argument('--rotary_base', type=float, default=10000.0) + parser.add_argument('--rotary_scaling', nargs=2, type=str, default=None) + + parser.add_argument('--group_size', + type=int, + default=128, + help='Group size used in GPTQ/AWQ quantization.') + + parser.add_argument("--storage-type", + "-t", + type=str, + default="fp32", + choices=["fp32", "fp16"]) + parser.add_argument("--dataset-cache-dir", + type=str, + default=None, + help="cache dir to load the hugging face dataset") + parser.add_argument("--load_model_on_cpu", action="store_true") + + parser.add_argument( + '--use_parallel_embedding', + action="store_true", + default=False, + help= + 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help= + 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument( + '--use_embedding_sharing', + action="store_true", + default=False, + help= + 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' + 'Note: the flag might not take effect when the criteria are not met.') + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument( + '--workers', + type=int, + default=1, + help='The number of workers for converting checkpoint in parallel') + parser.add_argument( + '--dense_context_fmha', + default=False, + action='store_true', + help= + 'Enable dense fmha in context phase, otherwise sliding window attention.' + 'If dense_context_fmha=False, the sliding window size is the max attention window size.' 
+ ) + args = parser.parse_args() + return args + + +def load_from_gptq_qwen( + model, + num_hidden_layers=None, + mapping=Mapping(), + dtype="float16", +): + tensorrt_llm.logger.info( + "loading weights from groupwise GPTQ QWen safetensors...") + weights = {} + tik = time.time() + + model_params = {k: v for k, v in model.state_dict().items()} + torch.cuda.empty_cache() + + def torch_split(v, dim): + if v.shape[dim] % mapping.tp_size != 0: + tensorrt_llm.logger.error( + "Current weight shape is invalid for mapping.tp_size=" + + str(mapping.tp_size)) + assert False, "Invalid TP size" + return v.split(v.shape[dim] // mapping.tp_size, + dim=dim)[mapping.tp_rank] + + def unpack_int32_into_int8(w_packed): + # unpack inputs packed in int32/float32 into uint4 and store them in int8 format + w_packed_int4x2 = w_packed.contiguous().view(torch.uint8) + w_unpacked = torch.zeros(w_packed_int4x2.shape[0], + w_packed_int4x2.shape[1] * 2, + dtype=torch.int8) + w_unpacked[:, ::2] = w_packed_int4x2 % 16 + w_unpacked[:, 1::2] = w_packed_int4x2 // 16 + return w_unpacked.contiguous() + + def process_and_assign_weight(v: List[torch.Tensor], + tllm_prex: str, + tp_dim: int = -1): + if tp_dim == -1: + qweight_int32, qzeros_int32, scales_fp16 = [ + item.cpu() for item in v + ] + else: + qweight_int32, qzeros_int32, scales_fp16 = [ + torch_split(item, tp_dim).cpu() for item in v + ] + + USE_UINT4_INPUT = 1 # Set to true if checkpoint store UINT4 weights + USE_GPTQ_FOR_LLAMA = 1 # GPTQ-for-LLaMA added 1 to zeros + + qweight_unpacked_int8 = unpack_int32_into_int8( + qweight_int32.T).T.contiguous() - 8 + qweight_interleaved = preprocessor(packer(qweight_unpacked_int8), + torch.quint4x2).view(torch.float16) + # zeros = zeros * scales + qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32) + if not USE_UINT4_INPUT: + # Correcting UINT4 values back to INT4 order + mask_negative = qzeros_unpacked_int32[qzeros_unpacked_int32 < 0] + mask_positive = qzeros_unpacked_int32[qzeros_unpacked_int32 >= 0] + qzeros_unpacked_int32 = qzeros_unpacked_int32 + 16 * mask_negative - 16 * mask_positive + zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 8 * USE_UINT4_INPUT - + USE_GPTQ_FOR_LLAMA) * scales_fp16 + zeros_x_scales_fp16 = zeros_x_scales_fp16.half() + + results = { + f'{tllm_prex}.weight': qweight_interleaved, + f'{tllm_prex}.weights_scaling_factor': scales_fp16, + f'{tllm_prex}.zero': zeros_x_scales_fp16, + } + return results + + packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 + preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm + torch_dtype = str_dtype_to_torch(dtype) + + # Load weights from GPTQ checkpoint into TRT-LLM module + # 1. vocab_embedding + v = model_params['transformer.wte.weight'] + if mapping.is_first_pp_rank(): + weights['transformer.vocab_embedding.weight'] = v.to(torch_dtype) + + # 2. ln_f + v = model_params['transformer.ln_f.weight'] + if mapping.is_last_pp_rank(): + weights['transformer.ln_f.weight'] = v.to(torch_dtype) + + # 3. lm_head + v = model_params['lm_head.weight'] + if mapping.is_last_pp_rank(): + weights['lm_head.weight'] = torch_split(v, 0).to(torch_dtype) + + # 4. 
Weights inside each layer + layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size + layers_range = list( + range(mapping.pp_rank * layers_per_pipeline_stage, + (mapping.pp_rank + 1) * layers_per_pipeline_stage, 1)) + suffixs = ["qweight", "qzeros", "scales"] + + for l in tqdm(layers_range, desc="loading weight in each layer..."): + layer_idx = l - mapping.pp_rank * layers_per_pipeline_stage + prefix = "transformer.h." + str(layer_idx) + "." + tllm_prex = f'transformer.layers.{l-layers_range[0]}' + # 4.1 attention.qkv + qkv_weight_list = [] + for suf in suffixs: + qkv_part = model_params[prefix + "attn.c_attn." + suf] + qkv_weight_list.append(qkv_part) + weights.update( + process_and_assign_weight(qkv_weight_list, + f'{tllm_prex}.attention.qkv')) + # 4.2 attention.bias + qkv_bias = model_params[prefix + "attn.c_attn.bias"].to( + torch_dtype).cpu().contiguous() + q_emb = qkv_bias.shape[0] // 3 + qkv_bias = qkv_bias.reshape(3, q_emb) + split_v = split(qkv_bias, mapping.tp_size, mapping.rank, dim=1) + split_v = split_v.reshape(3 * (q_emb // mapping.tp_size)) + weights[tllm_prex + ".attention.qkv.bias"] = split_v + # 4.3 attention.dense + qkv_dense_list = [] + for suf in suffixs: + qkv_dense_part = model_params[prefix + "attn.c_proj." + suf] + qkv_dense_list.append(qkv_dense_part) + weights.update( + process_and_assign_weight(qkv_dense_list, + f'{tllm_prex}.attention.dense', + tp_dim=0)) + # 4.4 mlp.gate + mlp_gate_list = [] + for suf in suffixs: + mlp_gate_part = model_params[prefix + "mlp.w1." + suf] + mlp_gate_list.append(mlp_gate_part) + weights.update( + process_and_assign_weight(mlp_gate_list, + f'{tllm_prex}.mlp.gate', + tp_dim=1)) + # 4.5 mlp.proj + mlp_proj_list = [] + for suf in suffixs: + mlp_proj_part = model_params[prefix + "mlp.c_proj." + suf] + mlp_proj_list.append(mlp_proj_part) + weights.update( + process_and_assign_weight(mlp_proj_list, + f'{tllm_prex}.mlp.proj', + tp_dim=0)) + # 4.6 mlp.fc + mlp_fc_list = [] + for suf in suffixs: + mlp_fc_part = model_params[prefix + "mlp.w2." + suf] + mlp_fc_list.append(mlp_fc_part) + weights.update( + process_and_assign_weight(mlp_fc_list, + f'{tllm_prex}.mlp.fc', + tp_dim=1)) + # 4.7 input_layernorm + v = model_params[prefix + "ln_1.weight"] + weights[f'{tllm_prex}.input_layernorm.weight'] = v.to(torch_dtype) + # 4.8 post_layernorm + v = model_params[prefix + "ln_2.weight"] + weights[f'{tllm_prex}.post_layernorm.weight'] = v.to(torch_dtype) + + tok = time.time() + t = time.strftime("%h:%m:%s", time.gmtime(tok - tik)) + tensorrt_llm.logger.info(f"weights loaded. 
total time: {t}") + + return weights + + +def make_context( + tokenizer, + query, + history, + system, + max_input_length, + max_window_size: int = 6144, + chat_format: str = "chatml", +): + if history is None: + history = [] + + if chat_format == "chatml": + im_start, im_end = "<|im_start|>", "<|im_end|>" + im_start_tokens = [tokenizer.im_start_id] + im_end_tokens = [tokenizer.im_end_id] + nl_tokens = tokenizer.encode("\n") + + def _tokenize_str(role, content): + return (f"{role}\n{content}", + tokenizer.encode( + role, + allowed_special=set(), + ) + nl_tokens + tokenizer.encode( + content, + allowed_special=set(), + )) + + system_text, system_tokens_part = _tokenize_str("system", system) + system_tokens = im_start_tokens + system_tokens_part + im_end_tokens + raw_text = "" + context_tokens = [] + + for turn_query, turn_response in reversed(history): + query_text, query_tokens_part = _tokenize_str("user", turn_query) + query_tokens = im_start_tokens + query_tokens_part + im_end_tokens + + response_text, response_tokens_part = _tokenize_str( + "assistant", turn_response) + response_tokens = im_start_tokens + response_tokens_part + im_end_tokens + next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens + prev_chat = ( + f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" + ) + + current_context_size = (len(system_tokens) + + len(next_context_tokens) + + len(context_tokens)) + if current_context_size < max_window_size: + context_tokens = next_context_tokens + context_tokens + raw_text = prev_chat + raw_text + else: + break + + context_tokens = system_tokens + context_tokens + raw_text = f"{im_start}{system_text}{im_end}" + raw_text + context_tokens += (nl_tokens + im_start_tokens + + _tokenize_str("user", query)[1] + im_end_tokens + + nl_tokens + im_start_tokens + + tokenizer.encode("assistant") + nl_tokens) + raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" + + elif chat_format == "raw": + raw_text = query + context_tokens = tokenizer.encode(raw_text) + else: + raise NotImplementedError(f"Unknown chat format {chat_format!r}") + # truncate to max_input_length, truncate from the front + return raw_text, context_tokens[-max_input_length:] + + +def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): + """ + This function has two purposes: + - compute quantized weights, scaled either per-tensor or per-column + - compute scaling factors + + Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ. + CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W. + CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor. + + Here is the list of what we need (T means per-tensor, C per-column): + - scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T) + - scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T) + - scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C) + - scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32) + to quant range (int8) (used for CUBLAS) (T, C) + + Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too, + but then the model would change depending on the number of GPUs used. + + For QKV projection, the behavior is special. 
Even if we have a single matrix to perform QKV projection, we consider it + as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V. + For our GEMM implementation to respect this behavior, we use per-column mode and replicate values along columns. + """ + weights = weights.detach().cpu().numpy() + + # compute weight scaling factors for fp->int8 and int8->fp + if is_qkv and not multi_query_mode: + scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max( + dim=-1, keepdims=True)[0].cpu().numpy() + scale_w_orig_quant_c = 127. / act_range["w"].reshape(3, + -1).cpu().numpy() + elif is_qkv and multi_query_mode: + hidden_dim = weights.shape[0] + local_dim = act_range["w"].shape[0] + kv_dim = (local_dim - hidden_dim) // 2 + scale_w_q = act_range["w"][0:hidden_dim] + scale_w_k = act_range["w"][hidden_dim:hidden_dim + kv_dim] + scale_w_v = act_range["w"][-kv_dim:] + + scale_w_qkv_t = torch.concat([ + scale_w_q.max(dim=0, keepdim=True)[0], + scale_w_k.max(dim=0, keepdim=True)[0], + scale_w_v.max(dim=0, keepdim=True)[0] + ]) + + scale_w_orig_quant_t = 127. / scale_w_qkv_t.cpu().numpy() + scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() + else: + scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy() + scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() + scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t + scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c + + scale_w_orig_quant_c = scale_w_orig_quant_c.astype(np.float32) + scale_w_orig_quant_t = scale_w_orig_quant_t.astype(np.float32) + + # compute the rest of needed scaling factors + scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item()) + scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item()) + scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.) 
+ scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t * + scale_w_orig_quant_t) + scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t * + scale_w_orig_quant_c) + if is_qkv and not multi_query_mode: + scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t, + scale_w_orig_quant_c.shape) + scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t, + scale_w_orig_quant_c.shape) + if is_qkv and multi_query_mode: + scale_q_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[0], + scale_w_q.shape) + scale_k_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[1], + scale_w_k.shape) + scale_v_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[2], + scale_w_v.shape) + scale_y_accum_quant_t = np.concatenate( + [scale_q_y_accum_t, scale_k_y_accum_t, scale_v_y_accum_t]) + scale_w_quant_orig_t = np.concatenate([ + np.broadcast_to(scale_w_quant_orig_t[0], scale_w_q.shape), + np.broadcast_to(scale_w_quant_orig_t[1], scale_w_k.shape), + np.broadcast_to(scale_w_quant_orig_t[2], scale_w_v.shape) + ]) + + to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8) + + if is_qkv and multi_query_mode: + weight_int8 = to_i8(weights / scale_w_quant_orig_t) + else: + weight_int8 = to_i8(weights * scale_w_orig_quant_t) + return { + "weight.int8": weight_int8, + "weight.int8.col": to_i8(weights * scale_w_orig_quant_c), + "scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32), + "scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32), + "scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32), + "scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32), + "scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32), + "scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32), + } + + +@torch.no_grad() +def apply_smoothing(scales, + gemm_weights, + layernorm_weights=None, + layernorm_bias=None, + dtype=torch.float32, + layernorm_1p=False): + if not isinstance(gemm_weights, list): + gemm_weights = [gemm_weights] + + if layernorm_weights is not None: + assert layernorm_weights.numel() == scales.numel() + layernorm_weights.div_(scales).to(dtype) + if layernorm_bias is not None: + assert layernorm_bias.numel() == scales.numel() + layernorm_bias.div_(scales).to(dtype) + if layernorm_1p: + layernorm_weights += (1 / scales) - 1 + + for gemm in gemm_weights: + gemm.mul_(scales.view(1, -1)).to(dtype) + + +@torch.no_grad() +def smooth_gemm(gemm_weights, + act_scales, + layernorm_weights=None, + layernorm_bias=None, + alpha=0.5, + weight_scales=None): + if not isinstance(gemm_weights, list): + gemm_weights = [gemm_weights] + orig_dtype = gemm_weights[0].dtype + + for gemm in gemm_weights: + # gemm_weights are expected to be transposed + assert gemm.shape[1] == act_scales.numel() + + if weight_scales is None: + weight_scales = torch.cat( + [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], + dim=0) + weight_scales = weight_scales.max(dim=0)[0] + weight_scales.to(float).clamp(min=1e-5) + scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / + weight_scales.pow(1 - alpha)).clamp(min=1e-5) + + apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias, + orig_dtype) + + return scales + + +@torch.no_grad() +def smooth_gemm_fc1_gate(fc1_weights, + gate_weights, + act_scales, + layernorm_weights=None, + layernorm_bias=None, + alpha=0.5, + weight_scales=None): + gemm_weights = [] + if not isinstance(fc1_weights, list): + fc1_weights = [fc1_weights] + if not isinstance(gate_weights, list): + gate_weights = 
[gate_weights] + + for i in range(len(fc1_weights)): + gemm_weight = torch.cat([fc1_weights[i], gate_weights[i]], dim=0) + gemm_weights.append(gemm_weight) + + orig_dtype = gemm_weights[0].dtype + + for gemm in gemm_weights: + # gemm_weights are expected to be transposed + assert gemm.shape[1] == act_scales.numel() + + if weight_scales is None: + weight_scales = torch.cat( + [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], + dim=0) + weight_scales = weight_scales.max(dim=0)[0] + weight_scales.to(float).clamp(min=1e-5) + scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / + weight_scales.pow(1 - alpha)).clamp(min=1e-5) + + apply_smoothing(scales, fc1_weights + gate_weights, layernorm_weights, + layernorm_bias, orig_dtype) + + return scales + + +@torch.no_grad() +def smooth_qwen_model(model, scales, alpha, qwen_qkv_para, qwen_smoother): + # Smooth the activation and weights with smoother = $\diag{s}$ + for name, module in model.named_modules(): + if not module._get_name() == "QWenBlock": + continue + # qkv_proj + layer_name = name + ".attn.c_attn" + smoother = smooth_gemm(module.attn.c_attn.weight, + scales[layer_name]["x"], module.ln_1.weight, + None, alpha) + + scales[layer_name]["x"] = scales[layer_name]["x"] / smoother + scales[layer_name]["w"] = module.attn.c_attn.weight.abs().max(dim=1)[0] + + # see transpose_weights function + qwen_qkv_para[layer_name] = module.attn.c_attn.weight.transpose(0, 1) + + # ================================================================= + layer_name = name + ".attn.c_proj" + smoother = smooth_gemm( + module.attn.c_proj.weight, + scales[layer_name]["x"], + None, + None, + alpha=alpha, + ) + qwen_smoother[layer_name] = smoother.float() + + scales[layer_name]["x"] = scales[layer_name]["x"] / smoother + scales[layer_name]["w"] = module.attn.c_proj.weight.abs().max(dim=1)[0] + # ================================================================== + fc1_layer_name = name + ".mlp.w1" + gate_layer_name = name + ".mlp.w2" + + smoother = smooth_gemm_fc1_gate(module.mlp.w1.weight, + module.mlp.w2.weight, + scales[fc1_layer_name]["x"], + module.ln_2.weight, None, alpha) + + scales[fc1_layer_name]["x"] = scales[fc1_layer_name]["x"] / smoother + scales[fc1_layer_name]["w"] = module.mlp.w1.weight.abs().max(dim=1)[0] + + scales[gate_layer_name]["x"] = scales[gate_layer_name]["x"] / smoother + scales[gate_layer_name]["w"] = module.mlp.w2.weight.abs().max(dim=1)[0] + + # ================================================================== + layer_name = name + ".mlp.c_proj" + smoother = smooth_gemm(module.mlp.c_proj.weight, + scales[layer_name]["x"], None, None, alpha) + qwen_smoother[layer_name] = smoother.float() + scales[layer_name]["x"] = scales[layer_name]["x"] / smoother + scales[layer_name]["w"] = module.mlp.c_proj.weight.abs().max(dim=1)[0] + + +@torch.no_grad() +def capture_activation_range(model, + tokenizer, + dataset, + system_prompt, + chat_format, + num_samples=512, + seq_len=512): + model.eval() + device = next(model.parameters()).device + act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None}) + + tokenizer.pad_token_id = tokenizer.im_end_id + + def stat_tensor(name, tensor, act_scales, key): + hidden_dim = tensor.shape[-1] + tensor = tensor.view(-1, hidden_dim).abs().detach() + comming_max = torch.max(tensor, dim=0)[0].float() + + if act_scales[name][key] is None: + act_scales[name][key] = comming_max + else: + act_scales[name][key] = torch.max(act_scales[name][key], + comming_max) + + def stat_input_hook(m, x, 
y, name): + if isinstance(x, tuple): + x = x[0] + stat_tensor(name, x, act_scales, "x") + stat_tensor(name, y, act_scales, "y") + + if act_scales[name]["w"] is None: + act_scales[name]["w"] = m.weight.abs().clip(1e-8, + None).max(dim=1)[0] + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, nn.Linear) or isinstance(m, Conv1D): + hooks.append( + m.register_forward_hook( + functools.partial(stat_input_hook, name=name))) + + for i in tqdm(range(num_samples), desc="calibrating model"): + line = dataset['train'][i]["article"] + line = line + ' TL;DR: ' + line = line.strip() + line = line.replace(" n't", "n't") + _, input_id_list = make_context(tokenizer=tokenizer, + query=line, + history=[], + system=system_prompt, + chat_format=chat_format, + max_input_length=seq_len) + line_encoded = torch.from_numpy(np.array( + input_id_list, dtype=np.int32)).type(torch.int32).unsqueeze(0) + line_encoded = line_encoded.to(device) + model(line_encoded) + for h in hooks: + h.remove() + return act_scales + + +def split(v, tp_size, idx, dim=0): + if tp_size == 1: + return v + if len(v.shape) == 1: + return torch.chunk(v, tp_size)[idx].contiguous() + else: + return torch.chunk(v, tp_size, dim=dim)[idx].contiguous() + + +def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank): + """ + Splits the QKV matrix according to tensor parallelism + """ + v = v.reshape(3, n_hidden, n_hidden) + split_v = split(v, tensor_parallel, rank, dim=1) + split_v = split_v.reshape(3 * (n_hidden // tensor_parallel), n_hidden) + return split_v.contiguous() + + +def split_qkv_bias_tp(v, n_head, n_hidden, tensor_parallel, rank): + """ + Splits the QKV bias according to tensor parallelism + """ + v = v.reshape(3, n_hidden) + split_v = split(v, tensor_parallel, rank, dim=1) + split_v = split_v.reshape(3 * (n_hidden // tensor_parallel)) + return split_v.contiguous() + + +def split_matrix_tp(v, tensor_parallel, rank, dim): + return split(v, tensor_parallel, rank, dim=dim) + + +def get_weight(config, prefix, dtype): + if config[prefix + '.weight'].dtype != dtype: + config[prefix + '.weight'].data = config[prefix + '.weight'].to(dtype) + return config[prefix + '.weight'] + + +def get_bias(config, prefix, dtype): + if config[prefix + '.bias'].dtype != dtype: + config[prefix + '.bias'].data = config[prefix + '.bias'].to(dtype) + return config[prefix + '.bias'] + + +def get_weight_and_bias(config, prefix, dtype): + return get_weight(config, prefix, dtype), get_bias(config, prefix, dtype) + + +def get_tllm_linear_weight(weight, + prefix, + bias=None, + use_weight_only=False, + plugin_weight_only_quant_type=torch.int8, + dtype='float32', + use_gemm_woq_plugin=True, + postfix='weight'): + results = {} + if use_weight_only: + v = weight.t().contiguous() + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + v.cpu(), plugin_weight_only_quant_type) + if not use_gemm_woq_plugin: + results[prefix + postfix] = v.to(dtype) + else: + results[prefix + postfix] = processed_torch_weights + results[prefix + 'per_channel_scale'] = torch_weight_scales + else: + results[prefix + postfix] = weight.contiguous() + + if bias is not None: + results[prefix + 'bias'] = bias + + return results + + +def get_tllm_linear_sq_weight(vals, + prefix, + shape, + tensor_parallel, + is_qkv=False, + per_token=False, + per_channel=False, + last_prefix=None, + bias=None, + smoother_value=None, + smoother_shape=None, + rank=0, + cat_dim=0, + multi_query_mode=False): + results = {} + + def 
multi_query_split(data, local_dim, head_size, tp_size, cur_rank): + q, k, v = np.split(data, [local_dim, local_dim + head_size], axis=-1) + q_split = np.split(q, tp_size, axis=-1) + k_split = np.split(k, tp_size, axis=-1) + v_split = np.split(v, tp_size, axis=-1) + return [ + np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1) + for ii in range(tp_size) + ][cur_rank] + + col_shape = shape if (is_qkv or per_channel) else [1, 1] + + if per_token: + original_weights = vals["weight.int8.col"] + + local_dim = original_weights.shape[0] + head_size = (original_weights.shape[1] - local_dim) // 2 + if multi_query_mode: + cur_weights = multi_query_split(original_weights, local_dim, + head_size, tensor_parallel, rank) + else: + cur_weights = np.split(original_weights, + tensor_parallel, + axis=cat_dim)[rank] + if is_qkv: + hidden_dim = cur_weights.shape[0] + cur_weights = cur_weights.reshape(hidden_dim, -1) + results[prefix + + 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + if smoother_value is None: + results[last_prefix] = torch.from_numpy( + np.array([1.0], dtype=np.float32)) + + if smoother_value is None: + if multi_query_mode: + cur_per_channel_value = multi_query_split( + vals["scale_w_quant_orig.col"], local_dim, head_size, + tensor_parallel, rank) + else: + cur_per_channel_value = np.split(vals["scale_w_quant_orig.col"], + tensor_parallel, + axis=cat_dim)[rank] + else: + cur_per_channel_value = vals["scale_w_quant_orig.col"] + results[prefix + 'per_channel_scale'] = torch.from_numpy( + np.array(cur_per_channel_value, + dtype=np.float32).reshape(col_shape)).contiguous() + else: + original_weights = np.array(vals["weight.int8"]) + cur_weights = np.split(original_weights, tensor_parallel, + axis=cat_dim)[rank] + + if is_qkv: + hidden_dim = cur_weights.shape[0] + cur_weights = cur_weights.reshape(hidden_dim, -1) + results[prefix + + 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + # 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + + cur_per_channel_value = vals["scale_y_accum_quant"] + + results[prefix + 'per_channel_scale'] = torch.from_numpy( + np.array([cur_per_channel_value], + dtype=np.float32).reshape(col_shape)).contiguous() + + results[last_prefix] = torch.from_numpy( + np.array([vals['scale_x_orig_quant']], + dtype=np.float32)).contiguous() + + results[prefix + 'act_scale'] = torch.from_numpy( + np.array([[vals["scale_y_quant_orig"]]], + dtype=np.float32)).contiguous() + + if smoother_value is not None: + cur_smoother_value = np.split(smoother_value, + tensor_parallel, + axis=cat_dim)[rank] + results[prefix + 'smoother'] = cur_smoother_value.reshape( + smoother_shape).contiguous().to(torch.float32) + + if bias is not None: + results[prefix + 'bias'] = bias + + return results + + +def convert_hf_qwen(hf_model, + mapping, + vocab_size=32000, + dtype='float32', + use_parallel_embedding=False, + sharding_dim=0, + use_weight_only=False, + share_embedding_table=False, + use_gemm_woq_plugin=False, + plugin_weight_only_quant_type=torch.int8, + use_smooth_quant=False, + per_channel=False, + per_token=False, + int8_kv_cache=False, + act_range=[], + qkv_para=[], + smoother=[]): + weights = {} + tik = time.time() + tensor_parallel = mapping.tp_size + model_params = dict(hf_model.named_parameters()) + dtype = getattr(torch, dtype) + num_attention_heads = hf_model.config.num_attention_heads + hidden_size = hf_model.config.hidden_size + intermediate_size = hf_model.config.intermediate_size // 2 # Qwen's actual intermediate_size is one half of what's in 
hf_config + num_key_value_heads = hf_model.config.num_key_value_heads if hasattr( + hf_model.config, "num_key_value_heads") else num_attention_heads + mha_mode = (num_key_value_heads == num_attention_heads) + assert mha_mode == True, "QWen uses MHA." + layers_range = mapping.pp_layers(hf_model.config.num_hidden_layers) + + for l in layers_range: + prefix = f'transformer.h.{l}.' + tllm_prex = f'transformer.layers.{l - layers_range[0]}.' + qkv_weight, qkv_bias = get_weight_and_bias(model_params, + prefix + 'attn.c_attn', + dtype) + qkv_w = split_qkv_tp(qkv_weight, num_attention_heads, hidden_size, + tensor_parallel, mapping.tp_rank) + qkv_b = split_qkv_bias_tp(qkv_bias, num_attention_heads, hidden_size, + tensor_parallel, mapping.tp_rank) + + if use_smooth_quant: + qkv_weight = qkv_para[prefix + 'attn.c_attn'] + qkv_weight = qkv_weight.reshape(hidden_size, 3, hidden_size) + + int8_weights = generate_int8(qkv_weight, + act_range.get(prefix + 'attn.c_attn'), + is_qkv=True, + multi_query_mode=bool(not mha_mode)) + + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + tllm_prex + 'attention.qkv.', [ + 1, 3 * hidden_size // tensor_parallel + if mha_mode else hidden_size // tensor_parallel + + (hidden_size // num_key_value_heads) // + tensor_parallel * 2 + ], + tensor_parallel, + is_qkv=True, + per_token=per_token, + per_channel=per_channel, + last_prefix=tllm_prex + 'input_layernorm.scale_to_int', + bias=qkv_bias, + smoother_value=None, + smoother_shape=None, + rank=mapping.tp_rank, + cat_dim=-1, + multi_query_mode=bool(not mha_mode))) + else: + weights.update( + get_tllm_linear_weight(qkv_w, tllm_prex + 'attention.qkv.', + qkv_b, use_weight_only, + plugin_weight_only_quant_type, dtype, + use_gemm_woq_plugin)) + + if int8_kv_cache: + qkv_y = act_range.get(prefix + 'attn.c_attn')["y"] + + int8_kv_scales = qkv_y.max() / 127. 
+ + kv_cache_weights = {} + + kv_cache_weights[ + tllm_prex + + 'attention.kv_cache_scaling_factor'] = int8_kv_scales.reshape( + [1]) + + weights.update(kv_cache_weights) + + attn_dense_weight = get_weight(model_params, prefix + 'attn.c_proj', + dtype) + split_v = split_matrix_tp(attn_dense_weight, + tensor_parallel, + mapping.tp_rank, + dim=1) + if use_smooth_quant: + attn_dense_weight = attn_dense_weight.t() + int8_weights = generate_int8(attn_dense_weight, + act_range.get(prefix + 'attn.c_proj')) + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + tllm_prex + 'attention.dense.', [1, hidden_size], + tensor_parallel, + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=tllm_prex + + 'attention.quantization_scaling_factor', + smoother_value=smoother[(prefix + 'attn.c_proj')], + smoother_shape=[1, hidden_size // tensor_parallel], + rank=mapping.tp_rank, + cat_dim=0)) + else: + weights.update( + get_tllm_linear_weight(split_v, tllm_prex + 'attention.dense.', + None, use_weight_only, + plugin_weight_only_quant_type, dtype, + use_gemm_woq_plugin)) + + mlp_gate_weight = get_weight(model_params, prefix + 'mlp.w1', dtype) + split_v = split_matrix_tp(mlp_gate_weight, + tensor_parallel, + mapping.tp_rank, + dim=0) + if use_smooth_quant: + mlp_gate_weight = mlp_gate_weight.t() + int8_weights = generate_int8(mlp_gate_weight, + act_range.get(prefix + 'mlp.w1')) + + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + tllm_prex + 'mlp.gate.', + [1, intermediate_size // tensor_parallel], + tensor_parallel, + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=tllm_prex + 'post_layernorm.scale_to_int', + smoother_value=None, + smoother_shape=None, + rank=mapping.tp_rank, + cat_dim=-1)) + else: + weights.update( + get_tllm_linear_weight(split_v, tllm_prex + 'mlp.gate.', None, + use_weight_only, + plugin_weight_only_quant_type, dtype, + use_gemm_woq_plugin)) + + mlp_fc_weight = get_weight(model_params, prefix + 'mlp.w2', dtype) + split_v = split_matrix_tp(mlp_fc_weight, + tensor_parallel, + mapping.tp_rank, + dim=0) + + if use_smooth_quant: + mlp_fc_weight = mlp_fc_weight.t() #verified + int8_weights = generate_int8(mlp_fc_weight, + act_range.get(prefix + 'mlp.w2')) + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + tllm_prex + 'mlp.fc.', + [1, intermediate_size // tensor_parallel], + tensor_parallel, + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=tllm_prex + 'post_layernorm.scale_to_int', + smoother_value=None, + smoother_shape=None, + rank=mapping.tp_rank, + cat_dim=-1)) + else: + weights.update( + get_tllm_linear_weight(split_v, tllm_prex + 'mlp.fc.', None, + use_weight_only, + plugin_weight_only_quant_type, dtype, + use_gemm_woq_plugin)) + + mlp_proj_weight = get_weight(model_params, prefix + 'mlp.c_proj', dtype) + split_v = split_matrix_tp(mlp_proj_weight, + tensor_parallel, + mapping.tp_rank, + dim=1) + + if use_smooth_quant: + mlp_proj_weight = mlp_proj_weight.t() + int8_weights = generate_int8(mlp_proj_weight, + act_range.get(prefix + 'mlp.c_proj')) + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + tllm_prex + 'mlp.proj.', [1, hidden_size], + tensor_parallel, + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=tllm_prex + 'mlp.quantization_scaling_factor', + smoother_value=smoother[prefix + 'mlp.c_proj'], + smoother_shape=[1, intermediate_size // tensor_parallel], + rank=mapping.tp_rank, + cat_dim=0)) + else: + weights.update( + 
get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.', None, + use_weight_only, + plugin_weight_only_quant_type, dtype, + use_gemm_woq_plugin)) + + # Layer norms do not use tensor parallelism + input_ln_weight = get_weight(model_params, prefix + 'ln_1', dtype) + weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight + + post_ln_weight = get_weight(model_params, prefix + 'ln_2', dtype) + weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight + + v = get_weight(model_params, 'transformer.wte', dtype) + + if hf_model.config.tie_word_embeddings: + # lm_head.weight has the same weights as embedding + if mapping.is_last_pp_rank(): + if vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) + pad_width = vocab_size_padded - vocab_size + + v = torch.from_numpy( + np.pad(v.detach().cpu().numpy(), ((0, pad_width), (0, 0)), + 'constant', + constant_values=0)) + weights['lm_head.weight'] = split(v, mapping.tp_size, + mapping.tp_rank) + + if use_parallel_embedding: + v = split_matrix_tp(v, + mapping.tp_size, + mapping.tp_rank, + dim=sharding_dim) + + if mapping.is_first_pp_rank(): + weights['transformer.vocab_embedding.weight'] = v + + lm_head_weights = get_weight(model_params, 'lm_head', dtype) + + if mapping.is_last_pp_rank(): + + if vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) + pad_width = vocab_size_padded - vocab_size + + lm_head_weights = torch.from_numpy( + np.pad(lm_head_weights.detach().cpu().numpy(), + ((0, pad_width), (0, 0)), + 'constant', + constant_values=0)) + weights['lm_head.weight'] = split_matrix_tp(lm_head_weights, + tensor_parallel, + mapping.tp_rank, + dim=0) + ln_f_w = get_weight(model_params, 'transformer.ln_f', dtype) + weights['transformer.ln_f.weight'] = ln_f_w + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Weights loaded. Total time: {t}') + return weights + + +def main(): + # TODO(qijun): Currently, the convert script depends on a torch op: + # torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix, + # which is included in tensorrt_llm Python package. Otherwise, the convert + # script does not need to import tensorrt_llm. Will remove it after reimplementing + # the op with PyTorch. + print(tensorrt_llm.__version__) + args = parse_arguments() + world_size = args.tp_size * args.pp_size + + tik = time.time() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + hf_config = None + if args.model_dir is not None: + hf_config = AutoConfig.from_pretrained(args.model_dir, + trust_remote_code=True) + args.model_type = hf_config.model_type + args.n_head = hf_config.num_attention_heads + args.inter_size = hf_config.intermediate_size + args.n_layer = hf_config.num_hidden_layers + args.n_embd = hf_config.hidden_size + if hasattr(hf_config, "num_key_value_heads"): + args.n_kv_head = hf_config.num_key_value_heads + args.rms_norm_eps = hf_config.layer_norm_epsilon + args.vocab_size = hf_config.vocab_size + args.n_positions = hf_config.max_position_embeddings + args.rotary_base = hf_config.rotary_emb_base + args.n_kv_head = args.n_kv_head or args.n_head + + if args.rotary_scaling is not None: + # assert args.use_gpt_attention_plugin, "RoPE scaling is only supported through GPT attention plugin." 
+ rotary_scaling = { + "type": args.rotary_scaling[0], + "factor": float(args.rotary_scaling[1]) + } + assert rotary_scaling["type"] in ["linear", "dynamic"] + assert rotary_scaling["factor"] > 1.0 + args.rotary_scaling = rotary_scaling + + config = { + 'architecture': "QWenForCausalLM", + 'dtype': args.dtype, + 'logits_dtype': 'float32', + 'num_hidden_layers': args.n_layer, + 'num_attention_heads': args.n_head, + 'hidden_size': args.n_embd, + 'intermediate_size': args.inter_size, + 'num_key_value_heads': args.n_kv_head, + 'vocab_size': args.vocab_size, + 'position_embedding_type': 'rope_gpt_neox', + 'max_position_embeddings': args.n_positions, + 'hidden_act': args.hidden_act, + 'rotary_base': args.rotary_base, + 'rotary_scaling': args.rotary_scaling, + 'norm_epsilon': args.rms_norm_eps, + 'quantization': { + 'quant_algo': None, + 'kv_cache_quant_algo': None, + }, + 'mapping': { + 'world_size': world_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + }, + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'share_embedding_table': args.use_embedding_sharing, + 'dense_context_fmha': args.dense_context_fmha, + 'disable_weight_only_quant_plugin': + args.disable_weight_only_quant_plugin + } + + if args.use_weight_only: + if args.weight_only_precision == 'int8': + config['quantization']['quant_algo'] = 'W8A16' + elif args.weight_only_precision == 'int4': + config['quantization']['quant_algo'] = 'W4A16' + elif args.smoothquant: + if args.per_channel: + if args.per_token: + config['quantization'][ + 'quant_algo'] = 'W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN' + else: + config['quantization'][ + 'quant_algo'] = 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN' + else: + if args.per_token: + config['quantization'][ + 'quant_algo'] = 'W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN' + else: + config['quantization'][ + 'quant_algo'] = 'W8A8_SQ_PER_TENSOR_PLUGIN' + + if args.int8_kv_cache: + config['quantization']['kv_cache_quant_algo'] = 'INT8' + + if args.weight_only_precision == 'int4_gptq': + config['quantization'].update({ + "group_size": args.group_size, + "has_zero_point": True, + "pre_quant_scale": False, + 'quant_algo': 'W4A16_GPTQ' + }) + + with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: + json.dump(config, f, indent=4) + + if args.model_dir is None: + return + + if args.weight_only_precision == 'int8': + plugin_weight_only_quant_type = torch.int8 + elif args.weight_only_precision == 'int4': + plugin_weight_only_quant_type = torch.quint4x2 + + act_range = {} + qwen_qkv_para = {} + # smoother for inputs of self_attn.o_proj and mlp.down_proj + qwen_smoother = {} + model = None + if args.model_dir is not None: + if args.use_weight_only and args.weight_only_precision == 'int4_gptq': + model = AutoModelForCausalLM.from_pretrained( + args.model_dir, device_map="auto", + trust_remote_code=True).eval().cpu() + else: + model = AutoModelForCausalLM.from_pretrained( + args.model_dir, + device_map='auto' if not args.load_model_on_cpu else 'cpu', + torch_dtype='auto' if not args.smoothquant else torch.float16, + trust_remote_code=True, + ).half() + + if args.smoothquant is not None or args.int8_kv_cache: + os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( + "TOKENIZERS_PARALLELISM", "false") + if args.load_model_on_cpu: + logger.warning( + "Note that running capture_activation_range on cpu would be very small." 
+ ) + dataset = load_dataset("ccdv/cnn_dailymail", + '3.0.0', + cache_dir=args.dataset_cache_dir) + system_prompt = "You are a useful assistant, please directly output the corresponding summary according to the article entered by the user." + gen_config_path = os.path.join(args.model_dir, + 'generation_config.json') + with open(gen_config_path, 'r') as f: + gen_config = json.load(f) + chat_format = gen_config['chat_format'] + act_range = capture_activation_range( + model, + AutoTokenizer.from_pretrained(args.model_dir, + trust_remote_code=True, + use_fast=False, + padding_side='left'), dataset, + system_prompt, chat_format) + if args.smoothquant is not None: + smooth_qwen_model(model, act_range, args.smoothquant, + qwen_qkv_para, qwen_smoother) + convert_args = { + 'hf_model': model, + 'act_range': act_range, + 'qwen_qkv_para': qwen_qkv_para, + 'qwen_smoother': qwen_smoother, + } + + def covert_and_save(rank, convert_args): + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size) + + if args.use_weight_only and args.weight_only_precision == 'int4_gptq': + weights = load_from_gptq_qwen(convert_args['hf_model'], + args.n_layer, + mapping, + dtype=args.dtype) + + else: + weights = convert_hf_qwen( + convert_args['hf_model'], + mapping, + vocab_size=args.vocab_size, + dtype=args.dtype, + use_weight_only=args.use_weight_only, + use_gemm_woq_plugin=not args.disable_weight_only_quant_plugin, + plugin_weight_only_quant_type=plugin_weight_only_quant_type, + use_parallel_embedding=args.use_parallel_embedding, + sharding_dim=args.embedding_sharding_dim, + share_embedding_table=args.use_embedding_sharing, + use_smooth_quant=args.smoothquant, + per_channel=args.per_channel, + per_token=args.per_token, + int8_kv_cache=args.int8_kv_cache, + act_range=convert_args['act_range'], + qkv_para=convert_args['qwen_qkv_para'], + smoother=convert_args['qwen_smoother']) + + safetensors.torch.save_file( + weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) + + if args.workers == 1: + + for rank in range(world_size): + covert_and_save(rank, convert_args) + else: + with ThreadPoolExecutor(max_workers=args.workers) as p: + futures = [ + p.submit(covert_and_save, rank, convert_args) + for rank in range(world_size) + ] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Total time of converting checkpoints: {t}') + + +if __name__ == '__main__': + main() diff --git a/examples/qwen/hf_qwen_convert.py b/examples/qwen/hf_qwen_convert.py deleted file mode 100644 index 248a95f13..000000000 --- a/examples/qwen/hf_qwen_convert.py +++ /dev/null @@ -1,361 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -''' -Convert huggingface QWen-7B-Chat model to numpy file. -Use https://huggingface.co/Qwen/Qwen-7B-Chat as demo. -''' -import argparse -import configparser -import dataclasses -import json -import os -from pathlib import Path - -import torch -import torch.multiprocessing as multiprocessing -from smoothquant import capture_activation_range, smooth_gemm, smooth_gemm_mlp -from tqdm import tqdm -from transformers import AutoModelForCausalLM # transformers-4.10.0-py3 -from transformers import AutoTokenizer, GenerationConfig -# for debug -from utils.convert import split_and_save_weight - -from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy - -now_dir = os.path.dirname(os.path.abspath(__file__)) - - -@dataclasses.dataclass(frozen=True) -class ProgArgs: - out_dir: str - in_file: str - max_input_len: int = 2048 - tensor_parallelism: int = 1 - processes: int = 1 - calibrate_kv_cache: bool = False - smoothquant: float = None - model: str = "qwen" - storage_type: str = "fp32" - dataset_cache_dir: str = None - - @staticmethod - def parse(args=None) -> 'ProgArgs': - parser = argparse.ArgumentParser( - formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument('--out-dir', - '-o', - type=str, - help='file name of output directory', - required=True) - parser.add_argument('--in-file', - '-i', - type=str, - help='file name of input checkpoint file', - required=True) - parser.add_argument( - '--max_input_len', - type=int, - help= - "This should be consistent with the max_input_len you used when building engine.", - default=2048) - parser.add_argument('--tensor-parallelism', - '-tp', - type=int, - help='Requested tensor parallelism for inference', - default=1) - parser.add_argument( - "--processes", - "-p", - type=int, - help= - "How many processes to spawn for conversion (default: 1). Set it to a lower value to reduce RAM usage.", - default=1) - parser.add_argument( - "--calibrate-kv-cache", - "-kv", - action="store_true", - help= - "Generate scaling factors for KV cache. Used for storing KV cache in int8." - ) - parser.add_argument( - "--smoothquant", - "-sq", - type=float, - default=None, - help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)" - " to Smoothquant the model, and output int8 weights." - " A good first try is 0.5. 
Must be in [0, 1]") - parser.add_argument( - "--model", - default="qwen", - type=str, - help="Specify GPT variants to convert checkpoints correctly", - choices=["qwen", "gpt2", "santacoder", "starcoder"]) - parser.add_argument("--storage-type", - "-t", - type=str, - default="float16", - choices=["float32", "float16", "bfloat16"]) - parser.add_argument("--dataset-cache-dir", - type=str, - default=None, - help="cache dir to load the hugging face dataset") - return ProgArgs(**vars(parser.parse_args(args))) - - -@torch.no_grad() -def smooth_qwen_model(model, scales, alpha, qwen_smoother): - # Smooth the activation and weights with smoother = $\diag{s}$ - for name, module in model.named_modules(): - # if not isinstance(module, QWenBlock): - if not str(type(module)).endswith("QWenBlock'>"): - continue - - # qkv_proj - layer_name = name + ".attn.c_attn" - smoother = smooth_gemm(module.attn.c_attn.weight, - scales[layer_name]["x"], - module.ln_1.weight, - alpha=alpha) - scales[layer_name]["x"] = scales[layer_name]["x"] / smoother - scales[layer_name]["w"] = module.attn.c_attn.weight.abs().max(dim=1)[0] - - # attention dense - layer_name = name + ".attn.c_proj" - smoother3 = smooth_gemm( - module.attn.c_proj.weight, - scales[layer_name]["x"], - None, - alpha=alpha, - ) - qwen_smoother[layer_name] = smoother3.float() - - scales[layer_name]["x"] = scales[layer_name]["x"] / smoother3 - scales[layer_name]["w"] = module.attn.c_proj.weight.abs().max(dim=1)[0] - - # mlp w1 / w2, because then use some input hidden_states as input, so we need to smooth it with same scale - mlp_w1_name = name + ".mlp.w1" - mlp_w2_name = name + ".mlp.w2" - smoother2 = smooth_gemm_mlp(module.mlp.w1.weight, - module.mlp.w2.weight, - scales[mlp_w1_name]["x"], - module.ln_2.weight, - alpha=alpha) - scales[mlp_w1_name]["x"] = scales[mlp_w1_name]["x"] / smoother2 - scales[mlp_w2_name]["x"] = scales[mlp_w2_name]["x"] / smoother2 - scales[mlp_w1_name]["w"] = module.mlp.w1.weight.abs().max(dim=1)[0] - scales[mlp_w2_name]["w"] = module.mlp.w2.weight.abs().max(dim=1)[0] - - # mlp c_proj - layer_name = name + ".mlp.c_proj" - smoother4 = smooth_gemm(module.mlp.c_proj.weight, - scales[layer_name]["x"], - None, - alpha=alpha) - qwen_smoother[layer_name] = smoother4.float() - scales[layer_name]["x"] = scales[layer_name]["x"] / smoother4 - scales[layer_name]["w"] = module.mlp.c_proj.weight.abs().max(dim=1)[0] - - -# SantaCoder separates Q projection from KV projection -def concat_qkv_weight_bias(q, hf_key, hf_model): - kv = hf_model.state_dict()[hf_key.replace("q_attn", "kv_attn")] - return torch.cat([q, kv], dim=-1) - - -# StarCoder uses nn.Linear for these following ops whose weight matrix is transposed compared to transformer.Conv1D -def transpose_weights(hf_name, param): - weight_to_transpose = [ - "attn.c_attn", "attn.c_proj", "mlp.c_proj", "mlp.w1", "mlp.w2" - ] - if any([k in hf_name for k in weight_to_transpose]): - if len(param.shape) == 2: - param = param.transpose(0, 1) - return param - - -def convert_qwen_name(orig_name): - global_weights = { - "transformer.wte.weight": "vocab_embedding.weight", - "transformer.ln_f.weight": "ln_f.weight", - "lm_head.weight": "lm_head.weight" - } - - if orig_name in global_weights: - return global_weights[orig_name] - - _, _, layer_idx, *weight_name = orig_name.split(".") - layer_idx = int(layer_idx) - weight_name = "transformer." 
+ ".".join(weight_name) - - per_layer_weights = { - "transformer.ln_1.weight": "ln_1.weight", - "transformer.ln_2.weight": "ln_2.weight", - "transformer.attn.c_attn.weight": "attention.qkv.weight", - "transformer.attn.c_attn.bias": "attention.qkv.bias", - "transformer.attn.c_proj.weight": "attention.dense.weight", - "transformer.mlp.w1.weight": "mlp.w1.weight", - "transformer.mlp.w2.weight": "mlp.w2.weight", - "transformer.mlp.c_proj.weight": "mlp.c_proj.weight", - } - return f"layers.{layer_idx}.{per_layer_weights[weight_name]}" - - -@torch.no_grad() -def hf_qwen_converter(args: ProgArgs): - infer_tp = args.tensor_parallelism - multi_query_mode = True if args.model in ["santacoder", "starcoder" - ] else False - saved_dir = Path(args.out_dir) / f"{infer_tp}-gpu" - saved_dir.mkdir(parents=True, exist_ok=True) - - # load position_embedding from rank 0 - model = AutoModelForCausalLM.from_pretrained( - args.in_file, - device_map= - "auto", # if you gpu memory is not enough, you can set device_map="cpu" - trust_remote_code=True, - torch_dtype=str_dtype_to_torch(args.storage_type), - ).half() # if you gpu memory is not enough, you can set .half() to .float() - model.generation_config = GenerationConfig.from_pretrained( - args.in_file, trust_remote_code=True) - act_range = {} - qwen_smoother = {} - if args.smoothquant is not None or args.calibrate_kv_cache: - os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( - "TOKENIZERS_PARALLELISM", "false") - from datasets import load_dataset - - # copy from summarize.py - dataset_cnn = load_dataset("ccdv/cnn_dailymail", '3.0.0') - dataset = dataset_cnn["test"] - tokenizer = AutoTokenizer.from_pretrained( - args.in_file, - legacy=False, - padding_side='left', - trust_remote_code=True, - ) - gen_config_path = os.path.join(args.in_file, 'generation_config.json') - with open(gen_config_path, 'r') as f: - gen_config = json.load(f) - chat_format = gen_config['chat_format'] - tokenizer.pad_token_id = tokenizer.im_end_id - # use this prompt to make chat model do summarize - system_prompt = "You are a useful assistant, please directly output the corresponding summary according to the article entered by the user." 
- act_range = capture_activation_range( - model, - tokenizer, - dataset, - system_prompt=system_prompt, - chat_format=chat_format, - max_input_len=args.max_input_len, - ) - if args.smoothquant is not None: - smooth_qwen_model(model, act_range, args.smoothquant, qwen_smoother) - - config = configparser.ConfigParser() - config["qwen"] = {} - for key in vars(args): - config["qwen"][key] = f"{vars(args)[key]}" - for k, v in vars(model.config).items(): - config["qwen"][k] = f"{v}" - config["qwen"]["storage_dtype"] = args.storage_type - config["qwen"]["multi_query_mode"] = str(multi_query_mode) - with open(saved_dir / "config.ini", 'w') as configfile: - config.write(configfile) - - storage_type = str_dtype_to_torch(args.storage_type) - - global_weights = ["vocab_embedding.weight", "ln_f.weight", "lm_head.weight"] - - int8_outputs = None - if args.calibrate_kv_cache: - int8_outputs = "kv_cache_only" - if args.smoothquant is not None: - int8_outputs = "all" - - starmap_args = [] - for name, param in tqdm( - model.named_parameters(), - desc="convert and save", - total=len(list(model.parameters())), - ncols=80, - ): - if "weight" not in name and "bias" not in name: - continue - converted_name = convert_qwen_name(name) - if name.replace(".weight", "") in qwen_smoother.keys(): - smoother = qwen_smoother[name.replace(".weight", "")] - starmap_arg = ( - 0, - saved_dir, - infer_tp, - f"{converted_name}.smoother".replace(".weight", ""), - smoother, - storage_type, - None, - { - "int8_outputs": int8_outputs, - "multi_query_mode": multi_query_mode, - "local_dim": None, - }, - ) - if args.processes > 1: - starmap_args.append(starmap_arg) - else: - split_and_save_weight(*starmap_arg) - - param = transpose_weights(name, param) - if converted_name in global_weights: - torch_to_numpy(param.to(storage_type).cpu()).tofile( - saved_dir / f"{converted_name}.bin") - else: - if 'q_attn' in name: - param = concat_qkv_weight_bias(param, name, model) - converted_name = converted_name.replace("query", - "query_key_value") - # Needed by QKV projection weight split. 
With multi_query_mode one does not simply take - # out_dim and divide it by 3 to get local_dim because out_dim = local_dim + 2 * head_size - local_dim = model.transformer.h[ - 0].attn.embed_dim if multi_query_mode else None - starmap_arg = (0, saved_dir, infer_tp, converted_name, - param.to(storage_type), storage_type, - act_range.get(name.replace(".weight", "")), { - "int8_outputs": int8_outputs, - "multi_query_mode": multi_query_mode, - "local_dim": local_dim - }) - if args.processes > 1: - starmap_args.append(starmap_arg) - else: - split_and_save_weight(*starmap_arg) - - if args.processes > 1: - starmap_args = tqdm(starmap_args, desc="saving weights") - with multiprocessing.Pool(args.processes) as pool: - pool.starmap(split_and_save_weight, starmap_args) - - -def run_conversion(args: ProgArgs): - print("\n=============== Arguments ===============") - for key, value in vars(args).items(): - print(f"{key}: {value}") - print("========================================") - hf_qwen_converter(args) - - -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") - run_conversion(ProgArgs.parse()) diff --git a/examples/qwen/requirements.txt b/examples/qwen/requirements.txt index d20e2764b..b8584db87 100644 --- a/examples/qwen/requirements.txt +++ b/examples/qwen/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/qwen/smoothquant.py b/examples/qwen/smoothquant.py deleted file mode 100644 index b1af23889..000000000 --- a/examples/qwen/smoothquant.py +++ /dev/null @@ -1,209 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
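# A minimal sketch of the per-channel migration scale that smooth_gemm (below) computes,
# using hypothetical activation/weight maxima; alpha = 0.5 weights both sides equally.
import torch

_act_max = torch.tensor([8.0, 2.0])     # hypothetical per-channel |activation| maxima
_weight_max = torch.tensor([0.5, 1.0])  # hypothetical per-channel |weight| maxima
_scales = (_act_max.pow(0.5) / _weight_max.pow(1 - 0.5)).clamp(min=1e-5)  # -> [4.0000, 1.4142]
# The GEMM weight columns are multiplied by these scales and the preceding RMSNorm weight
# is divided by them, so layer outputs are unchanged while activation outliers shrink.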
-''' -Utilities for SmoothQuant models -''' - -import functools -import os -import sys -from collections import defaultdict - -import numpy as np -import torch -import torch.nn as nn -from tqdm import tqdm -from transformers.pytorch_utils import Conv1D - -project_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -sys.path.append(project_dir) -from utils.utils import make_context - - -@torch.no_grad() -def apply_smoothing(scales, - gemm_weights, - rmsnorm_weights=None, - dtype=torch.float32, - rmsnorm_1p=False): - if not isinstance(gemm_weights, list): - gemm_weights = [gemm_weights] - - if rmsnorm_weights is not None: - assert rmsnorm_weights.numel() == scales.numel() - rmsnorm_weights.div_(scales).to(dtype) - if rmsnorm_1p: - rmsnorm_weights += (1 / scales) - 1 - - for gemm in gemm_weights: - gemm.mul_(scales.view(1, -1)).to(dtype) - - -@torch.no_grad() -def smooth_gemm(gemm_weights, - act_scales, - rmsnorm_weights=None, - alpha=0.5, - weight_scales=None): - if not isinstance(gemm_weights, list): - gemm_weights = [gemm_weights] - orig_dtype = gemm_weights[0].dtype - - for gemm in gemm_weights: - # gemm_weights are expected to be transposed - assert gemm.shape[1] == act_scales.numel() - - if weight_scales is None: - weight_scales = torch.cat( - [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], - dim=0) - weight_scales = weight_scales.max(dim=0)[0] - weight_scales.to(float).clamp(min=1e-5) - scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / - weight_scales.pow(1 - alpha)).clamp(min=1e-5) - - apply_smoothing(scales, gemm_weights, rmsnorm_weights, orig_dtype) - - return scales - - -@torch.no_grad() -def smooth_gemm_mlp(w1_weights, - w2_weights, - act_scales, - rmsnorm_weights=None, - alpha=0.5, - weight_scales=None): - gemm_weights = [] - if not isinstance(w1_weights, list): - w1_weights = [w1_weights] - if not isinstance(w2_weights, list): - w2_weights = [w2_weights] - - for i in range(len(w1_weights)): - gemm_weight = torch.cat([w1_weights[i], w2_weights[i]], dim=0) - gemm_weights.append(gemm_weight) - - orig_dtype = gemm_weights[0].dtype - - for gemm in gemm_weights: - # gemm_weights are expected to be transposed - assert gemm.shape[1] == act_scales.numel() - - if weight_scales is None: - weight_scales = torch.cat( - [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], - dim=0) - weight_scales = weight_scales.max(dim=0)[0] - weight_scales.to(float).clamp(min=1e-5) - scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / - weight_scales.pow(1 - alpha)).clamp(min=1e-5) - - apply_smoothing(scales, w1_weights + w2_weights, rmsnorm_weights, - orig_dtype) - - return scales - - -@torch.no_grad() -def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5): - if not isinstance(fcs, list): - fcs = [fcs] - for fc in fcs: - assert isinstance(fc, nn.Linear) - assert ln.weight.numel() == fc.in_features == act_scales.numel() - - device, dtype = fcs[0].weight.device, fcs[0].weight.dtype - act_scales = act_scales.to(device=device, dtype=dtype) - weight_scales = torch.cat( - [fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) - weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) - - scales = (act_scales.pow(alpha) / - weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) - - if ln is not None: - ln.weight.div_(scales) - ln.bias.div_(scales) - - for fc in fcs: - fc.weight.mul_(scales.view(1, -1)) - return scales - - -@torch.no_grad() -def capture_activation_range( - model, - 
tokenizer, - dataset, - system_prompt, - chat_format, - max_input_len, - num_samples=512, -): - model.eval() - device = next(model.parameters()).device - act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None}) - - def stat_tensor(name, tensor, act_scales, key): - hidden_dim = tensor.shape[-1] - tensor = tensor.view(-1, hidden_dim).abs().detach() - comming_max = torch.max(tensor, dim=0)[0].float() - - if act_scales[name][key] is None: - act_scales[name][key] = comming_max - else: - act_scales[name][key] = torch.max(act_scales[name][key], - comming_max) - - def stat_input_hook(m, x, y, name): - if isinstance(x, tuple): - x = x[0] - stat_tensor(name, x, act_scales, "x") - stat_tensor(name, y, act_scales, "y") - - if act_scales[name]["w"] is None: - act_scales[name]["w"] = m.weight.abs().clip(1e-8, - None).max(dim=1)[0] - - hooks = [] - for name, m in model.named_modules(): - if isinstance(m, nn.Linear) or isinstance(m, Conv1D): - hooks.append( - m.register_forward_hook( - functools.partial(stat_input_hook, name=name))) - num_samples = min(num_samples, len(dataset)) - for i in tqdm(range(num_samples), desc="calibrating model"): - line = dataset[i]["article"] - line = line + ' TL;DR: ' - line = line.strip() - line = line.replace(" n't", "n't") - # use make_content to generate prompt - _, input_id_list = make_context(tokenizer=tokenizer, - query=line, - history=[], - system=system_prompt, - chat_format=chat_format, - max_input_length=max_input_len) - line_encoded = torch.from_numpy(np.array( - input_id_list, dtype=np.int32)).type(torch.int32).unsqueeze(0) - line_encoded = line_encoded.to(device) - model(line_encoded) - - for h in hooks: - h.remove() - - return act_scales diff --git a/examples/qwen/utils/__init__.py b/examples/qwen/utils/__init__.py deleted file mode 100644 index 71bf6d298..000000000 --- a/examples/qwen/utils/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/examples/qwen/utils/convert.py b/examples/qwen/utils/convert.py deleted file mode 100644 index f6d6809bb..000000000 --- a/examples/qwen/utils/convert.py +++ /dev/null @@ -1,289 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
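# A numeric sketch of the scaling factors that generate_int8 (below) derives from the
# calibration ranges; the maxima used here are hypothetical placeholders.
import numpy as np

_x_max, _w_max, _y_max = 10.0, 0.4, 25.0           # hypothetical |X|, |W|, |Y| maxima
_scale_x_orig_quant = np.float32(127.0 / _x_max)   # fp activation -> int8, applied before the GEMM
_scale_w_orig_quant = np.float32(127.0 / _w_max)   # fp weight -> int8 (per-tensor variant)
_scale_y_quant_orig = np.float32(_y_max / 127.0)   # int8 GEMM output -> fp
_scale_y_accum_quant = np.float32(127.0 / _y_max) / (_scale_x_orig_quant * _scale_w_orig_quant)
# _scale_y_accum_quant maps the int32 accumulator back to the int8 output range (cuBLAS path).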
-""" - Utilities for exporting a model to our custom format. -""" - -import numpy as np -import torch - -from tensorrt_llm._utils import torch_to_numpy - - -def save_val(val, dir, key, tp_num=None): - suffix = "bin" if tp_num is None else f"{tp_num}.bin" - val.tofile(dir / f"model.{key}.{suffix}") - - -def save_split(split_vals, dir, key, i, split_factor): - for j, val in enumerate(split_vals): - save_val(val, dir, key, i * split_factor + j) - - -def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): - """ - This function has two purposes: - - compute quantized weights, scaled either per-tensor or per-column - - compute scaling factors - - Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ. - CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W. - CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor. - - Here is the list of what we need (T means per-tensor, C per-column): - - scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T) - - scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T) - - scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C) - - scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32) - to quant range (int8) (used for CUBLAS) (T, C) - - Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too, - but then the model would change depending on the number of GPUs used. - - For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it - as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V. - """ - - # compute weight scaling factors for fp->int8 and int8->fp - if is_qkv and not multi_query_mode: - scale_w_orig_quant_t = 127. / torch_to_numpy(act_range["w"].reshape( - 3, -1).max(dim=-1, keepdims=True)[0].cpu()).astype(np.float32) - scale_w_orig_quant_c = 127. / torch_to_numpy(act_range["w"].reshape( - 3, -1).cpu()).astype(np.float32) - elif is_qkv and multi_query_mode: - raise ValueError( - f"Multi-query w/ int8 quant has not been supported yet") - else: - scale_w_orig_quant_t = 127. / torch_to_numpy( - act_range["w"].max().cpu()).astype(np.float32) - scale_w_orig_quant_c = 127. / torch_to_numpy( - act_range["w"].cpu()).astype(np.float32) - scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t - scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c - - # compute the rest of needed scaling factors - scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item()) - scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item()) - scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.) 
- scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t * - scale_w_orig_quant_t) - scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t * - scale_w_orig_quant_c) - if is_qkv: - scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t, - scale_w_orig_quant_c.shape) - scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t, - scale_w_orig_quant_c.shape) - - to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8) - return { - "weight.int8": to_i8(weights * scale_w_orig_quant_t), - "weight.int8.col": to_i8(weights * scale_w_orig_quant_c), - "scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32), - "scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32), - "scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32), - "scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32), - "scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32), - "scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32), - } - - -def write_int8(vals, - dir, - base_key, - split_dim, - tp_rank, - split_factor, - kv_cache_only=False): - if not kv_cache_only: - save_split(np.split(vals["weight.int8"], split_factor, axis=split_dim), - dir, f"{base_key}.weight.int8", tp_rank, split_factor) - save_split( - np.split(vals["weight.int8.col"], split_factor, axis=split_dim), - dir, f"{base_key}.weight.int8.col", tp_rank, split_factor) - - saved_keys_once = ["scale_y_quant_orig"] - if not kv_cache_only: - saved_keys_once += [ - "scale_x_orig_quant", "scale_w_quant_orig", "scale_y_accum_quant" - ] - # per-column scaling factors are loaded per-gpu for ColumnParallel GEMMs (QKV, FC1) - if not kv_cache_only: - if split_dim == -1: - save_split( - np.split(vals["scale_w_quant_orig.col"], - split_factor, - axis=split_dim), dir, - f"{base_key}.scale_w_quant_orig.col", tp_rank, split_factor) - save_split( - np.split(vals["scale_y_accum_quant.col"], - split_factor, - axis=split_dim), dir, - f"{base_key}.scale_y_accum_quant.col", tp_rank, split_factor) - else: - saved_keys_once += [ - "scale_w_quant_orig.col", "scale_y_accum_quant.col" - ] - - if tp_rank == 0: - for save_key in saved_keys_once: - save_val(vals[save_key], dir, f"{base_key}.{save_key}") - - -# Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head -# are not split as there is only one head per key/value. 
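# A small illustration of that split with hypothetical sizes (local_dim = 8 query
# columns, head_size = 2, tensor parallelism = 2): each rank keeps half of the query
# columns plus the full, replicated key/value block.
#
#     q, kv = np.split(qkv_bias, [8], axis=-1)                 # 8 query cols, 4 kv cols
#     per_rank = [np.concatenate((part, kv), axis=-1)
#                 for part in np.split(q, 2, axis=-1)]          # two chunks of 4 + 4 cols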
-@torch.no_grad() -def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, - storage_type, act_range, config): - use_attention_nemo_shape = config.get("use_attention_nemo_shape", False) - split_gated_activation = config.get("split_gated_activation", False) - num_attention_heads = config.get("num_attention_heads", 0) - tp_size = config.get("tp_size", 1) - int8_outputs = config.get("int8_outputs", None) - multi_query_mode = config.get("multi_query_mode", False) - local_dim = config.get("local_dim", None) - - save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" - - if not key.endswith(".smoother"): - if not isinstance(vals, list): - vals = [vals] - - if config.get("transpose_weights", False) and vals[0].ndim == 2: - vals = [val.T for val in vals] - if "layernorm.weight" in key and config.get("apply_layernorm_1p", - False): - vals = [val + 1.0 for val in vals] - vals = [torch_to_numpy(val.cpu().to(storage_type)) for val in vals] - else: - vals = torch_to_numpy(vals.cpu()) - - if "ln_1.weight" in key or "ln_1.bias" in key or \ - "attention.dense.bias" in key or \ - "ln_2.weight" in key or "ln_2.bias" in key or \ - "mlp.c_proj.bias" in key or "ln_f.weight" in key: - # "final_layernorm.weight" in key or "final_layernorm.bias" in key: - - # shared weights, only need to convert the weights of rank 0 - if tp_rank == 0: - save_val(vals[0], saved_dir, key) - - elif "attention.dense.weight" in key or "mlp.c_proj.weight" in key: - cat_dim = 0 - val = np.concatenate(vals, axis=cat_dim) - split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) - if act_range is not None and int8_outputs == "all": - base_key = key.replace(".weight", "") - vals_i8 = generate_int8(val, - act_range, - multi_query_mode=multi_query_mode) - write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, - split_factor) - - elif "mlp.w1.weight" in key or "mlp.w2.weight" in key or "mlp.w1.bias" in key or "mlp.w2.bias" in key: - if split_gated_activation: - splits = [np.split(val, 2, axis=-1) for val in vals] - vals, gates = list(zip(*splits)) - cat_dim = -1 - val = np.concatenate(vals, axis=cat_dim) - split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) - if act_range is not None and int8_outputs == "all": - base_key = key.replace(".weight", "") - vals_i8 = generate_int8(val, - act_range, - multi_query_mode=multi_query_mode) - write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, - split_factor) - - if split_gated_activation: - assert not save_int8 - prefix, dot, suffix = key.rpartition(".") - key = prefix + ".gate" + dot + suffix - - gate = np.concatenate(gates, axis=cat_dim) - split_vals = np.split(gate, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) - - elif "attention.qkv.bias" in key: - if local_dim is None: - local_dim = vals[0].shape[-1] // 3 - - if multi_query_mode: - val = vals[0] - # out_feature = local_dim + 2 * head_size; assumes local_dim equals to hidden_dim - b_q, b_kv = np.split(val, [local_dim], axis=-1) - b_q_split = np.split(b_q, split_factor, axis=-1) - split_vals = [np.concatenate((i, b_kv), axis=-1) for i in b_q_split] - else: - if use_attention_nemo_shape: - head_num = num_attention_heads // tp_size - size_per_head = local_dim // num_attention_heads - nemo_shape = (head_num, 3, size_per_head) - vals = [val.reshape(nemo_shape) for val in vals] - vals = [val.transpose(1, 0, 2) for val in vals] - - vals = 
[val.reshape(3, local_dim) for val in vals] - val = np.concatenate(vals, axis=-1) - split_vals = np.split(val, split_factor, axis=-1) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) - - elif "attention.qkv.weight" in key: - hidden_dim = vals[0].shape[0] - if local_dim is None: - local_dim = vals[0].shape[-1] // 3 - if multi_query_mode: - val = vals[0] - # out_feature = local_dim + 2 * head_size; assumes local_dim equals to hidden_dim - head_size = (val.shape[-1] - local_dim) // 2 - val = val.reshape(hidden_dim, local_dim + 2 * head_size) - w_q, w_kv = np.split(val, [local_dim], axis=-1) - w_q_split = np.split(w_q, split_factor, axis=-1) - split_vals = [np.concatenate((i, w_kv), axis=-1) for i in w_q_split] - else: - if use_attention_nemo_shape: - head_num = num_attention_heads // tp_size - size_per_head = hidden_dim // num_attention_heads - vals = [ - val.reshape(hidden_dim, head_num, 3, size_per_head) - for val in vals - ] - vals = [val.transpose(0, 2, 1, 3) for val in vals] - - vals = [val.reshape(hidden_dim, 3, local_dim) for val in vals] - cat_dim = -1 - val = np.concatenate(vals, axis=cat_dim) - split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) - if save_int8: - base_key = key.replace(".weight", "") - vals_i8 = generate_int8(val, - act_range, - is_qkv=True, - multi_query_mode=multi_query_mode) - write_int8(vals_i8, - saved_dir, - base_key, - cat_dim, - tp_rank, - split_factor, - kv_cache_only=int8_outputs == "kv_cache_only") - - elif "attention.dense.smoother" in key or "mlp.c_proj.smoother" in key: - split_vals = np.split(vals, split_factor, axis=0) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) - else: - print(f"[WARNING] {key} not handled by converter") diff --git a/examples/qwen/utils/utils.py b/examples/qwen/utils/utils.py deleted file mode 100644 index bf9f60292..000000000 --- a/examples/qwen/utils/utils.py +++ /dev/null @@ -1,99 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
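# A rough sketch of the single-turn ChatML prompt that make_context (below) assembles
# before tokenization; when history is present, earlier user/assistant turns are
# inserted between the system block and the final query.
_system, _query = "You are a helpful assistant.", "Hello"   # example inputs
_raw_text = (
    f"<|im_start|>system\n{_system}<|im_end|>"
    f"\n<|im_start|>user\n{_query}<|im_end|>\n<|im_start|>assistant\n"
)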
-from typing import List, Tuple - -from transformers import PreTrainedTokenizer - - -def make_context( - tokenizer: PreTrainedTokenizer, - query: str, - history: List[Tuple[str, str]] = None, - system: str = "You are a helpful assistant.", - max_input_length: - int = 2048, # if you want to change this, you need to change the max_input_len in tensorrt_llm_july-release-v1/examples/qwen/build.py - max_window_size: int = 6144, - chat_format: str = "chatml", -): - if history is None: - history = [] - - if chat_format == "chatml": - im_start, im_end = "<|im_start|>", "<|im_end|>" - im_start_tokens = [tokenizer.im_start_id] - im_end_tokens = [tokenizer.im_end_id] - nl_tokens = tokenizer.encode("\n") - - def _tokenize_str(role, content): - return (f"{role}\n{content}", - tokenizer.encode( - role, - allowed_special=set(), - ) + nl_tokens + tokenizer.encode( - content, - allowed_special=set(), - )) - - system_text, system_tokens_part = _tokenize_str("system", system) - system_tokens = im_start_tokens + system_tokens_part + im_end_tokens - raw_text = "" - context_tokens = [] - - for turn_query, turn_response in reversed(history): - query_text, query_tokens_part = _tokenize_str("user", turn_query) - query_tokens = im_start_tokens + query_tokens_part + im_end_tokens - - response_text, response_tokens_part = _tokenize_str( - "assistant", turn_response) - response_tokens = im_start_tokens + response_tokens_part + im_end_tokens - next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens - prev_chat = ( - f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" - ) - - current_context_size = (len(system_tokens) + - len(next_context_tokens) + - len(context_tokens)) - if current_context_size < max_window_size: - context_tokens = next_context_tokens + context_tokens - raw_text = prev_chat + raw_text - else: - break - - context_tokens = system_tokens + context_tokens - raw_text = f"{im_start}{system_text}{im_end}" + raw_text - context_tokens += (nl_tokens + im_start_tokens + - _tokenize_str("user", query)[1] + im_end_tokens + - nl_tokens + im_start_tokens + - tokenizer.encode("assistant") + nl_tokens) - raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" - - elif chat_format == "raw": - raw_text = query - context_tokens = tokenizer.encode(raw_text) - else: - raise NotImplementedError(f"Unknown chat format {chat_format!r}") - # truncate to max_input_length, truncate from the front - return raw_text, context_tokens[-max_input_length:] - - -def get_stop_words_ids(chat_format, tokenizer): - if chat_format == "raw": - stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]] - elif chat_format == "chatml": - stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]] - else: - raise NotImplementedError(f"Unknown chat format {chat_format!r}") - return stop_words_ids diff --git a/examples/qwen/weight.py b/examples/qwen/weight.py deleted file mode 100644 index 9b6a75153..000000000 --- a/examples/qwen/weight.py +++ /dev/null @@ -1,999 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import configparser -import time -from operator import attrgetter -from pathlib import Path -from typing import Union - -import numpy as np -import torch -from safetensors import safe_open -from tqdm import tqdm -from transformers import AutoModelForCausalLM - -import tensorrt_llm -from tensorrt_llm._utils import (str_dtype_to_np, str_dtype_to_torch, - torch_to_numpy) -from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models import QWenForCausalLM -from tensorrt_llm.quantization import QuantMode - - -def gen_suffix(rank, use_smooth_quant, quant_per_channel): - suffix = f"{rank}.bin" - if use_smooth_quant: - sq_prefix = "int8." - if quant_per_channel: - sq_prefix += "col." - suffix = sq_prefix + suffix - return suffix - - -def extract_layer_idx(name): - ss = name.split('.') - for s in ss: - if s.isdigit(): - return s - return None - - -def split(v: Union[np.ndarray, torch.Tensor], - tp_size: int, - tp_rank: int, - dim=0): - if tp_size == 1: - return v - assert len(v.shape) > 1 or dim == 0 - if isinstance(v, np.ndarray): - return np.ascontiguousarray( - np.split(v, tp_size, axis=dim)[tp_rank].copy()) - else: - assert v.shape[dim] % tp_size == 0, \ - 'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.' - split_size = v.shape[dim] // tp_size - return v.split(split_size, dim=dim)[tp_rank].clone().detach() - - -def parse_bin_config(ini_file): - qwen_config = configparser.ConfigParser() - qwen_config.read(ini_file) - - vocab_size = qwen_config.getint('qwen', 'vocab_size') - hidden_size = qwen_config.getint('qwen', 'hidden_size') - inter_size = qwen_config.getint('qwen', 'intermediate_size', fallback=None) - num_hidden_layers = qwen_config.getint( - "qwen", - "num_hidden_layers", - fallback=32, - ) - max_position_embeddings = qwen_config.getint("qwen", - "max_position_embeddings", - fallback=8192) - kv_channels = qwen_config.getint('qwen', 'kv_channels', fallback=128) - rotary_pct = qwen_config.getfloat('qwen', 'rotary_pct', fallback=0.0) - rotary_emb_base = qwen_config.getint('qwen', - 'rotary_emb_base', - fallback=10000) - multi_query_mode = qwen_config.getboolean('qwen', - 'multi_query_mode', - fallback=False) - return (vocab_size, hidden_size, inter_size, num_hidden_layers, kv_channels, - rotary_pct, rotary_emb_base, multi_query_mode, - max_position_embeddings) - - -def load_from_binary(tensorrt_llm_qwen: QWenForCausalLM, - dir_path, - mapping=Mapping(), - dtype='float16', - multi_query_mode=False): - tensorrt_llm.logger.info('Loading weights from FT...') - tik = time.time() - quant_mode = getattr(tensorrt_llm_qwen, 'quant_mode', QuantMode(0)) - if quant_mode.is_int8_weight_only(): - plugin_weight_only_quant_type = torch.int8 - elif quant_mode.is_int4_weight_only(): - plugin_weight_only_quant_type = torch.quint4x2 - (vocab_size, hidden_size, inter_size, num_hidden_layers, kv_channels, - rotary_pct, rotary_emb_base, multi_query_mode, - max_position_embeddings) = parse_bin_config(Path(dir_path) / 'config.ini') - np_dtype = str_dtype_to_np(dtype) - - def fromfile(dir_path, name, shape=None, dtype=np.float16): - dtype = np_dtype if dtype is None else dtype - p = 
dir_path + '/' + name - if Path(p).exists(): - t = np.fromfile(p, dtype=dtype) - if shape is not None: - t = t.reshape(shape) - return t - else: - print(f"Warning: {p} not found.") - return None - - def set_smoothquant_scale_factors( - module, - pre_scale_weight, - dir_path, - basename, - shape, - per_tok_dyn, - per_channel, - is_qkv=False, - rank=None, - ): - suffix = "bin" - if per_channel: - if rank is not None: - suffix = f"{rank}." + suffix - suffix = "col." + suffix - - col_shape = shape if (per_channel or is_qkv) else [1, 1] - if per_tok_dyn: - if pre_scale_weight is not None: - pre_scale_weight.value = np.array([1.0], dtype=np.float32) - t = fromfile(dir_path, f"{basename}scale_w_quant_orig.{suffix}", - col_shape, np.float32) - module.per_channel_scale.value = t - else: - t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1], - np.float32) - pre_scale_weight.value = t - t = fromfile(dir_path, f"{basename}scale_y_accum_quant.{suffix}", - col_shape, np.float32) - module.per_channel_scale.value = t - t = fromfile(dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1], - np.float32) - module.act_scale.value = t - - def set_smoother(module, dir_path, base_name, shape, rank): - suffix = f"{rank}.bin" - t = fromfile(dir_path, f"{base_name}.smoother.{suffix}", shape, - np.float32) - module.smoother.value = t - - # Determine the quantization mode. - quant_mode = getattr(tensorrt_llm_qwen, "quant_mode", QuantMode(0)) - # Do we use SmoothQuant? - use_smooth_quant = quant_mode.has_act_and_weight_quant() - # Do we use quantization per token? - quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling() - # Do we use quantization per channel? - quant_per_channel = quant_mode.has_per_channel_scaling() - - # Do we use INT4/INT8 weight-only? - use_weight_only = quant_mode.is_weight_only() - - # Int8 KV cache - use_int8_kv_cache = quant_mode.has_int8_kv_cache() - - # Debug - suffix = gen_suffix(mapping.tp_rank, use_smooth_quant, quant_per_channel) - # The type of weights. - w_type = np_dtype if not use_smooth_quant else np.int8 - - if mapping.is_first_pp_rank(): - tensorrt_llm_qwen.embedding.vocab_embedding.weight.value = (fromfile( - dir_path, 'vocab_embedding.weight.bin', [vocab_size, hidden_size])) - - if mapping.is_last_pp_rank(): - tensorrt_llm_qwen.ln_f.weight.value = (fromfile(dir_path, - 'ln_f.weight.bin')) - - lm_head_weight = fromfile(dir_path, 'lm_head.weight.bin', - [vocab_size, hidden_size]) - - if vocab_size % mapping.tp_size != 0: - # padding - vocab_size_padded = tensorrt_llm_qwen.lm_head.out_features * mapping.tp_size - pad_width = vocab_size_padded - vocab_size - lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), - 'constant', - constant_values=0) - if mapping.is_last_pp_rank(): - tensorrt_llm_qwen.lm_head.weight.value = np.ascontiguousarray( - split(lm_head_weight, mapping.tp_size, mapping.tp_rank)) - - num_hidden_layers = tensorrt_llm_qwen.num_layers - layers_range = mapping.pp_layers(num_hidden_layers) - for i in layers_range: - c_attn_out_dim = (3 * hidden_size // - mapping.tp_size) if not multi_query_mode else ( - hidden_size // mapping.tp_size + - (hidden_size // num_hidden_layers) * 2) - idx = i - layers_range[0] - tensorrt_llm_qwen.layers[idx].ln_1.weight.value = fromfile( - dir_path, 'model.layers.' + str(i) + '.ln_1.weight.bin') - - dst = tensorrt_llm_qwen.layers[idx].ln_2.weight - dst.value = fromfile(dir_path, - 'model.layers.' + str(i) + '.ln_2.weight.bin') - - t = fromfile( - dir_path, - 'model.layers.' + str(i) + '.attention.qkv.weight.' 
+ suffix, - [hidden_size, c_attn_out_dim], w_type) - if t is not None: - dst = tensorrt_llm_qwen.layers[idx].attention.qkv.weight - if use_smooth_quant: - dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) - set_smoothquant_scale_factors( - tensorrt_llm_qwen.layers[idx].attention.qkv, - tensorrt_llm_qwen.layers[idx].ln_1.scale_to_int, - dir_path, - 'model.layers.' + str(i) + '.attention.qkv.', - [1, c_attn_out_dim], - quant_per_token_dyn, - quant_per_channel, - rank=mapping.tp_rank, - is_qkv=True) - elif use_weight_only: - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(t), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() - scales = tensorrt_llm_qwen.layers[ - idx].attention.qkv.per_channel_scale - scales.value = torch_weight_scales.numpy() - else: - dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) - - dst = tensorrt_llm_qwen.layers[idx].attention.qkv.bias - t = fromfile( - dir_path, 'model.layers.' + str(i) + '.attention.qkv.bias.' + - str(mapping.tp_rank) + '.bin', [c_attn_out_dim]) - dst.value = np.ascontiguousarray(t) - - dst = tensorrt_llm_qwen.layers[idx].attention.dense.weight - t = fromfile( - dir_path, - 'model.layers.' + str(i) + '.attention.dense.weight.' + suffix, - [hidden_size // mapping.tp_size, hidden_size], w_type) - if use_smooth_quant: - dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) - dense_scale = getattr(tensorrt_llm_qwen.layers[idx].attention, - "quantization_scaling_factor", None) - set_smoothquant_scale_factors( - tensorrt_llm_qwen.layers[idx].attention.dense, - dense_scale, - dir_path, - 'model.layers.' + str(i) + '.attention.dense.', - [1, hidden_size], - quant_per_token_dyn, - quant_per_channel, - ) - set_smoother(tensorrt_llm_qwen.layers[idx].attention.dense, - dir_path, - 'model.layers.' + str(i) + '.attention.dense', - [1, hidden_size // mapping.tp_size], mapping.tp_rank) - - elif use_weight_only: - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(t), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() - scales = tensorrt_llm_qwen.layers[ - i].attention.dense.per_channel_scale - scales.value = torch_weight_scales.numpy() - else: - dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) - - t = fromfile(dir_path, - 'model.layers.' + str(i) + '.mlp.w1.weight.' + suffix, - [hidden_size, inter_size // mapping.tp_size // 2], w_type) - if use_smooth_quant: - tensorrt_llm_qwen.layers[ - idx].mlp.gate.weight.value = np.ascontiguousarray( - np.transpose(t, [1, 0])) - set_smoothquant_scale_factors( - tensorrt_llm_qwen.layers[idx].mlp.gate, - tensorrt_llm_qwen.layers[idx].ln_2.scale_to_int, - dir_path, - 'model.layers.' + str(i) + '.mlp.w1.', - [1, inter_size // mapping.tp_size // 2], - quant_per_token_dyn, - quant_per_channel, - rank=mapping.tp_rank) - elif use_weight_only: - dst = tensorrt_llm_qwen.layers[idx].mlp.gate.weight - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(t), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() - scales = tensorrt_llm_qwen.layers[idx].mlp.gate.per_channel_scale - scales.value = torch_weight_scales.numpy() - else: - tensorrt_llm_qwen.layers[ - idx].mlp.gate.weight.value = np.ascontiguousarray( - np.transpose(t, [1, 0])) - - t = fromfile(dir_path, - 'model.layers.' + str(i) + '.mlp.w2.weight.' 
+ suffix, - [hidden_size, inter_size // mapping.tp_size // 2], w_type) - if use_smooth_quant: - tensorrt_llm_qwen.layers[ - idx].mlp.fc.weight.value = np.ascontiguousarray( - np.transpose(t, [1, 0])) - set_smoothquant_scale_factors( - tensorrt_llm_qwen.layers[idx].mlp.fc, - tensorrt_llm_qwen.layers[idx].ln_2.scale_to_int, - dir_path, - 'model.layers.' + str(i) + '.mlp.w2.', - [1, inter_size // mapping.tp_size // 2], - quant_per_token_dyn, - quant_per_channel, - rank=mapping.tp_rank) - elif use_weight_only: - dst = tensorrt_llm_qwen.layers[idx].mlp.fc.weight - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(t), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() - scales = tensorrt_llm_qwen.layers[idx].mlp.fc.per_channel_scale - scales.value = torch_weight_scales.numpy() - else: - tensorrt_llm_qwen.layers[ - idx].mlp.fc.weight.value = np.ascontiguousarray( - np.transpose(t, [1, 0])) - - t = fromfile(dir_path, - 'model.layers.' + str(i) + '.mlp.c_proj.weight.' + suffix, - [inter_size // mapping.tp_size // 2, hidden_size], w_type) - if use_smooth_quant: - tensorrt_llm_qwen.layers[ - idx].mlp.proj.weight.value = np.ascontiguousarray( - np.transpose(t, [1, 0])) - proj_scale = getattr(tensorrt_llm_qwen.layers[idx].mlp, - "quantization_scaling_factor", None) - set_smoothquant_scale_factors( - tensorrt_llm_qwen.layers[idx].mlp.proj, proj_scale, dir_path, - 'model.layers.' + str(i) + '.mlp.c_proj.', [1, hidden_size], - quant_per_token_dyn, quant_per_channel) - set_smoother(tensorrt_llm_qwen.layers[idx].mlp.proj, dir_path, - 'model.layers.' + str(i) + '.mlp.c_proj', - [1, inter_size // mapping.tp_size // 2], - mapping.tp_rank) - elif use_weight_only: - dst = tensorrt_llm_qwen.layers[idx].mlp.proj.weight - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(t), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() - scales = tensorrt_llm_qwen.layers[idx].mlp.proj.per_channel_scale - scales.value = torch_weight_scales.numpy() - else: - tensorrt_llm_qwen.layers[ - idx].mlp.proj.weight.value = np.ascontiguousarray( - np.transpose(t, [1, 0])) - - if use_int8_kv_cache: - t = fromfile( - dir_path, 'model.layers.' + str(i) + - '.attention.qkv.scale_y_quant_orig.bin', [1], np.float32) - tensorrt_llm_qwen.layers[ - idx].attention.kv_cache_scaling_factor.value = t - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - tensorrt_llm.logger.info(f'Weights loaded. 
Total time: {t}') - - -def load_from_hf_qwen(tensorrt_llm_qwen: tensorrt_llm.models.QWenForCausalLM, - hf_qwen, - mapping=Mapping(), - dtype="float32", - multi_query_mode=False): - tensorrt_llm.logger.info('Loading weights from HF QWen...') - tik = time.time() - - quant_mode = getattr(tensorrt_llm_qwen, 'quant_mode', QuantMode(0)) - if quant_mode.is_int8_weight_only(): - plugin_weight_only_quant_type = torch.int8 - elif quant_mode.is_int4_weight_only(): - plugin_weight_only_quant_type = torch.quint4x2 - use_weight_only = quant_mode.is_weight_only() - - model_params = dict(hf_qwen.named_parameters()) - torch_dtype = str_dtype_to_torch(dtype) - - num_hidden_layers = hf_qwen.config.num_hidden_layers - layers_range = mapping.pp_layers(num_hidden_layers) - - for k, v in tqdm(model_params.items(), - total=len(model_params), - ncols=80, - desc="Converting..."): - if 'visual' in k: - continue - if isinstance(v, list): - v = [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in v] - else: - v = torch_to_numpy(v.to(torch_dtype).detach().cpu()) - if 'transformer.wte.weight' in k: - if tensorrt_llm_qwen.use_parallel_embedding: - v = split(v, mapping.tp_size, mapping.tp_rank, - tensorrt_llm_qwen.embedding_sharding_dim) - if mapping.is_first_pp_rank(): - tensorrt_llm_qwen.embedding.vocab_embedding.weight.value = v - elif 'transformer.ln_f.weight' in k: - if mapping.is_last_pp_rank(): - tensorrt_llm_qwen.ln_f.weight.value = v - elif 'lm_head.weight' in k: - if mapping.is_last_pp_rank(): - tensorrt_llm_qwen.lm_head.weight.value = np.ascontiguousarray( - split(v, mapping.tp_size, mapping.tp_rank)) - else: - layer_idx = extract_layer_idx(k) - if layer_idx is None or int(layer_idx) not in layers_range: - continue - idx = int(layer_idx) - layers_range[0] - if idx >= tensorrt_llm_qwen.num_layers: - continue - if 'ln_1.weight' in k: - tensorrt_llm_qwen.layers[idx].ln_1.weight.value = v - elif 'ln_2.weight' in k: - tensorrt_llm_qwen.layers[idx].ln_2.weight.value = v - elif 'attn.c_attn.weight' in k: - dst = tensorrt_llm_qwen.layers[idx].attention.qkv.weight - if multi_query_mode: - assert isinstance(v, list) and len(v) == 3 - wq = split(v[0], mapping.tp_size, mapping.tp_rank) - wk = split(v[1], mapping.tp_size, mapping.tp_rank) - wv = split(v[2], mapping.tp_size, mapping.tp_rank) - split_v = np.concatenate((wq, wk, wv)) - else: - q_emb = v.shape[0] // 3 - model_emb = v.shape[1] - v = v.reshape(3, q_emb, model_emb) - split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1) - split_v = split_v.reshape(3 * (q_emb // mapping.tp_size), - model_emb) - if use_weight_only: - v = np.ascontiguousarray(split_v.transpose()) - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(v), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() - scales = tensorrt_llm_qwen.layers[ - idx].attention.qkv.per_channel_scale - scales.value = torch_weight_scales.numpy() - else: - dst.value = np.ascontiguousarray(split_v) - elif 'attn.c_attn.bias' in k: - dst = tensorrt_llm_qwen.layers[idx].attention.qkv.bias - if multi_query_mode: - assert isinstance(v, list) and len(v) == 3 - wq = split(v[0], mapping.tp_size, mapping.tp_rank) - wk = split(v[1], mapping.tp_size, mapping.tp_rank) - wv = split(v[2], mapping.tp_size, mapping.tp_rank) - split_v = np.concatenate((wq, wk, wv)) - else: - q_emb = v.shape[0] // 3 - v = v.reshape(3, q_emb) - split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1) - split_v = split_v.reshape(3 * (q_emb // 
mapping.tp_size)) - dst.value = np.ascontiguousarray(split_v) - elif 'attn.c_proj.weight' in k: - dst = tensorrt_llm_qwen.layers[idx].attention.dense.weight - split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1) - if use_weight_only: - v = np.ascontiguousarray(split_v.transpose()) - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(v), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() - scales = tensorrt_llm_qwen.layers[ - idx].attention.dense.per_channel_scale - scales.value = torch_weight_scales.numpy() - else: - dst.value = np.ascontiguousarray(split_v) - elif 'mlp.w1.weight' in k: - dst = tensorrt_llm_qwen.layers[idx].mlp.gate.weight - split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0) - if use_weight_only: - v = np.ascontiguousarray(split_v.transpose()) - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(v), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() - scales = tensorrt_llm_qwen.layers[ - idx].mlp.gate.per_channel_scale - scales.value = torch_weight_scales.numpy() - else: - dst.value = np.ascontiguousarray(split_v) - elif 'mlp.w2.weight' in k: - dst = tensorrt_llm_qwen.layers[idx].mlp.fc.weight - split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0) - if use_weight_only: - v = np.ascontiguousarray(split_v.transpose()) - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(v), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() - scales = tensorrt_llm_qwen.layers[ - idx].mlp.fc.per_channel_scale - scales.value = torch_weight_scales.numpy() - else: - dst.value = np.ascontiguousarray(split_v) - elif 'mlp.c_proj.weight' in k: - dst = tensorrt_llm_qwen.layers[idx].mlp.proj.weight - split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1) - if use_weight_only: - v = np.ascontiguousarray(split_v.transpose()) - processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - torch.tensor(v), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() - scales = tensorrt_llm_qwen.layers[ - idx].mlp.proj.per_channel_scale - scales.value = torch_weight_scales.numpy() - else: - dst.value = np.ascontiguousarray(split_v) - else: - print("unknown key: ", k) - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - tensorrt_llm.logger.info(f'Weights loaded. 
Total time: {t}') - return - - -def load_from_gptq_qwen( - tensorrt_llm_qwen: QWenForCausalLM, - quant_ckpt_path, - mapping=Mapping(), - dtype="float16", -): - tensorrt_llm.logger.info( - "loading weights from groupwise gptq qwen safetensors...") - tik = time.time() - - if quant_ckpt_path.endswith(".safetensors"): - groupwise_qweight_safetensors = safe_open(quant_ckpt_path, - framework="pt", - device='cpu') - model_params = { - key: groupwise_qweight_safetensors.get_tensor(key) - for key in groupwise_qweight_safetensors.keys() - } - elif quant_ckpt_path.endswith(".pt"): - model_params = torch.load(quant_ckpt_path, - map_location=torch.device("cpu")) - else: - if Path(quant_ckpt_path).is_dir(): - model = AutoModelForCausalLM.from_pretrained( - quant_ckpt_path, device_map="auto", - trust_remote_code=True).eval().cpu() - model_params = {k: v for k, v in model.state_dict().items()} - torch.cuda.empty_cache() - del model - else: - raise ValueError("quantized checkpoint format not supported!") - - def unpack_int32_into_int8(w_packed): - # unpack inputs packed in int32/float32 into uint4 and store them in int8 format - w_packed_int4x2 = w_packed.contiguous().view(torch.uint8) - w_unpacked = torch.zeros(w_packed_int4x2.shape[0], - w_packed_int4x2.shape[1] * 2, - dtype=torch.int8) - w_unpacked[:, ::2] = w_packed_int4x2 % 16 - w_unpacked[:, 1::2] = w_packed_int4x2 // 16 - return w_unpacked.contiguous() - - def preprocess_groupwise_weight_params( - weight_name, - qweight_int32=None, - qzeros_int32=None, - scales_fp16=None, - ): - if weight_name is not None: - qweight_int32 = model_params[weight_name].cpu() - qzeros_int32 = model_params[weight_name[:-7] + "qzeros"].cpu() - scales_fp16 = model_params[weight_name[:-7] + "scales"].cpu() - - UINT4_TO_INT4_FLAG = 1 - GPTQ_FLAG = 1 - packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 - preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm - - qweight_unpacked_int8 = ( - unpack_int32_into_int8(qweight_int32.T).T.contiguous() - 8) - qweight_interleaved = preprocessor(packer(qweight_unpacked_int8), - torch.quint4x2).view(torch.float16) - # zeros = zeros * scales - qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32) - - zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 8 * UINT4_TO_INT4_FLAG - - GPTQ_FLAG) * scales_fp16 - zeros_x_scales_fp16 = zeros_x_scales_fp16.half() - - # return processed interleaved weight, original scales and zeros * scales - return ( - qweight_interleaved.contiguous(), # dtype: int8 - zeros_x_scales_fp16.contiguous(), # dtype: float16 - scales_fp16.contiguous(), # dtype: float16 - ) - - layer_ids = [ - extract_layer_idx(key) for key in model_params.keys() - if 'visual' not in key - ] #exclude 'visual' for Qwen-VL case - layer_ids = [ - int(layer_idx) for layer_idx in layer_ids if layer_idx is not None - ] - num_hidden_layers = max(layer_ids) + 1 - suffixs = ["qweight", "qzeros", "scales"] - - layers_range = mapping.pp_layers(num_hidden_layers) - torch_dtype = str_dtype_to_torch(dtype) - for layer in tqdm(layers_range, - ncols=80, - desc="loading attention weight..."): - prefix = f"transformer.h.{layer}.attn." - split_qkv_suf = [] - for suf in suffixs: - qkv_part = model_params[prefix + "c_attn." 
+ suf].cpu() - q_emb = qkv_part.shape[1] // 3 - model_emb = qkv_part.shape[0] - qkv_part = qkv_part.reshape(model_emb, 3, q_emb) - split_qkv = split(qkv_part, mapping.tp_size, mapping.rank, dim=2) - split_qkv = split_qkv.reshape(model_emb, - 3 * (q_emb // mapping.tp_size)) - # dtype: int32, int32, float16 - split_qkv_suf.append(split_qkv) - - idx = layer - layers_range[0] - th_bias = model_params[prefix + "c_attn.bias"].to( - torch_dtype).cpu().contiguous() - - q_emb = th_bias.shape[0] // 3 - th_bias = th_bias.reshape(3, q_emb) - split_v = split(th_bias, mapping.tp_size, mapping.rank, dim=1) - split_v = split_v.reshape(3 * (q_emb // mapping.tp_size)) - - tensorrt_llm_qwen.layers[ - idx].attention.qkv.bias.value = np.ascontiguousarray(split_v) - - th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( - None, - split_qkv_suf[0], - split_qkv_suf[1], - split_qkv_suf[2], - ) - tensorrt_llm_qwen.layers[ - idx].attention.qkv.weight.value = th_qweight.numpy() - tensorrt_llm_qwen.layers[idx].attention.qkv.zero.value = th_zero.to( - torch_dtype).numpy() - tensorrt_llm_qwen.layers[ - idx].attention.qkv.weights_scaling_factor.value = th_scale.to( - torch_dtype).numpy() - - for k, v in tqdm(model_params.items(), - ncols=80, - desc="loading other weight..."): - if 'visual' in k: - continue - if isinstance(v, list): - v = [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in v] - else: - v = torch_to_numpy(v.to(torch_dtype).detach().cpu()) - - if "transformer.wte.weight" in k: - if mapping.is_first_pp_rank(): - tensorrt_llm.logger.info(f"converting: {k}") - tensorrt_llm_qwen.embedding.vocab_embedding.weight.value = v - elif "transformer.ln_f.weight" in k: - if mapping.is_last_pp_rank(): - tensorrt_llm_qwen.ln_f.weight.value = v - elif "lm_head.weight" in k: - if mapping.is_last_pp_rank(): - tensorrt_llm_qwen.lm_head.weight.value = np.ascontiguousarray( - split(v, mapping.tp_size, mapping.rank)) - else: - layer_idx = extract_layer_idx(k) - if layer_idx is None: - continue - idx = int(layer_idx) - if idx not in layers_range: - continue - idx = idx - layers_range[0] - - if "ln_1.weight" in k: - tensorrt_llm_qwen.layers[idx].ln_1.weight.value = v - elif "ln_2.weight" in k: - tensorrt_llm_qwen.layers[idx].ln_2.weight.value = v - elif 'post_attention_layernorm.weight' in k: - tensorrt_llm_qwen.layers[idx].post_layernorm.weight.value = v - elif "attn.c_proj.qweight" in k: - split_v_suf = [] - for suf in suffixs: - v = model_params[k[:-7] + suf].cpu() - split_v = v.split(v.shape[0] // mapping.tp_size, - dim=0)[mapping.tp_rank] - split_v_suf.append(split_v) - th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( - None, split_v_suf[0], split_v_suf[1], split_v_suf[2]) - tensorrt_llm_qwen.layers[ - idx].attention.dense.weight.value = th_qweight.numpy() - tensorrt_llm_qwen.layers[ - idx].attention.dense.zero.value = th_zero.to( - torch_dtype).numpy() - tensorrt_llm_qwen.layers[ - idx].attention.dense.weights_scaling_factor.value = th_scale.to( - torch_dtype).numpy() - elif "mlp.w1.qweight" in k: - split_v_suf = [] - for suf in suffixs: - v = model_params[k[:-7] + suf].cpu() - split_v = v.split(v.shape[1] // mapping.tp_size, - dim=1)[mapping.tp_rank] - split_v_suf.append(split_v) - th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( - None, split_v_suf[0], split_v_suf[1], split_v_suf[2]) - tensorrt_llm_qwen.layers[ - idx].mlp.gate.weight.value = th_qweight.numpy() - tensorrt_llm_qwen.layers[idx].mlp.gate.zero.value = th_zero.to( - torch_dtype).numpy() - 
tensorrt_llm_qwen.layers[ - idx].mlp.gate.weights_scaling_factor.value = th_scale.to( - torch_dtype).numpy() - elif "mlp.c_proj.qweight" in k: - split_v_suf = [] - for suf in suffixs: - v = model_params[k[:-7] + suf].cpu() - split_v = v.split(v.shape[0] // mapping.tp_size, - dim=0)[mapping.tp_rank] - split_v_suf.append(split_v) - th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( - None, split_v_suf[0], split_v_suf[1], split_v_suf[2]) - tensorrt_llm_qwen.layers[ - idx].mlp.proj.weight.value = th_qweight.numpy() - tensorrt_llm_qwen.layers[idx].mlp.proj.zero.value = th_zero.to( - torch_dtype).numpy() - tensorrt_llm_qwen.layers[ - idx].mlp.proj.weights_scaling_factor.value = th_scale.to( - torch_dtype).numpy() - elif "mlp.w2.qweight" in k: - split_v_suf = [] - for suf in suffixs: - v = model_params[k[:-7] + suf].cpu() - split_v = v.split(v.shape[1] // mapping.tp_size, - dim=1)[mapping.tp_rank] - split_v_suf.append(split_v) - th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( - None, split_v_suf[0], split_v_suf[1], split_v_suf[2]) - tensorrt_llm_qwen.layers[ - idx].mlp.fc.weight.value = th_qweight.numpy() - tensorrt_llm_qwen.layers[idx].mlp.fc.zero.value = th_zero.to( - torch_dtype).numpy() - tensorrt_llm_qwen.layers[ - idx].mlp.fc.weights_scaling_factor.value = th_scale.to( - torch_dtype).numpy() - elif 'attn.c_attn.bias' in k: - dst = tensorrt_llm_qwen.layers[idx].attention.qkv.bias - q_emb = v.shape[0] // 3 - v = v.reshape(3, q_emb) - split_v = split(v, mapping.tp_size, mapping.rank, dim=1) - split_v = split_v.reshape(3 * (q_emb // mapping.tp_size)) - dst.value = np.ascontiguousarray(split_v) - - tok = time.time() - t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) - tensorrt_llm.logger.info(f"weights loaded. total time: {t}") - - -def load_from_awq_qwen(tensorrt_llm_qwen: QWenForCausalLM, - quant_ckpt_path, - quantize_lm_head=False, - mapping=Mapping(), - dtype="float16"): - tensorrt_llm.logger.info( - 'Loading weights from groupwise AWQ Qwen safetensors...') - tik = time.time() - - if quant_ckpt_path.endswith(".safetensors"): - groupwise_qweight_safetensors = safe_open(quant_ckpt_path, - framework="pt", - device=0) - model_params = { - key: groupwise_qweight_safetensors.get_tensor(key) - for key in groupwise_qweight_safetensors.keys() - } - elif quant_ckpt_path.endswith(".pt"): - model_params = torch.load(quant_ckpt_path, - map_location=torch.device('cpu')) - else: - assert False, "Quantized checkpoint format not supported!" 
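    # The group size below is inferred purely from tensor shapes: for example, a
    # hypothetical [4096, 4096] c_proj weight paired with a [4096, 32] _amax tensor
    # yields 4096 * 4096 // (4096 * 32) = 128 elements per quantization group.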
- - group_size = model_params["transformer.h.0.attn.c_proj.weight"].numel( - ) // model_params[ - "transformer.h.0.attn.c_proj.weight_quantizer._amax"].numel() - - awq_block_names = [ - "ln_1.weight", - "ln_2.weight", - ] - - tensorrt_llm_block_names = [ - "ln_1.weight", - "ln_2.weight", - ] - - getattr(tensorrt_llm_qwen, 'quant_mode', QuantMode(0)) - - packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 - preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm - torch_dtype = str_dtype_to_torch(dtype) - - def torch_split(v, dim): - if v.shape[dim] % mapping.tp_size != 0: - tensorrt_llm.logger.error( - "Current weight shape is invalid for mapping.tp_size=" + - str(mapping.tp_size)) - assert False, "Invalid TP size" - return v.split(v.shape[dim] // mapping.tp_size, - dim=dim)[mapping.tp_rank] - - def AWQ_quantize_pack_preprocess(weight, scale): - scale = scale.repeat_interleave(group_size, dim=0) - weight = weight / scale # fp16 -> int8 - qweight_int8 = torch.clamp(torch.round(weight.cuda()).char(), -8, 7) - int4_weight = packer(qweight_int8.cpu()) - int4_weight = preprocessor(int4_weight, - torch.quint4x2) # int8 save as uint4 - return int4_weight.view(torch.float16).cpu().numpy() - - def process_and_assign_attn_weight(model_params, mPrefix, mOp, tp_dim=0): - weight = model_params[mPrefix + ".weight"].to(torch_dtype) - q_emb = weight.shape[0] // 3 - model_emb = weight.shape[1] - weight = weight.reshape(3, q_emb, model_emb) - # [k, n] = weight.shape - split_v = split(weight, mapping.tp_size, mapping.rank, dim=tp_dim) - split_v = split_v.reshape(3 * (q_emb // mapping.tp_size), model_emb) - amax = model_params[mPrefix + ".weight_quantizer._amax"].reshape( - (q_emb * 3, int(model_emb / group_size))).to(torch_dtype) - amax = amax.reshape(3, q_emb, model_emb // group_size) - split_amax = split(amax, mapping.tp_size, mapping.rank, dim=tp_dim) - split_amax = split_amax.reshape(3 * (q_emb // mapping.tp_size), - model_emb // group_size) - split_v = split_v.T.contiguous() - split_amax = split_amax.T.contiguous() - pre_quant_scale = model_params[ - mPrefix + ".input_quantizer._pre_quant_scale"].reshape( - (1, model_emb)).to(torch_dtype) - split_scale = split_amax / 8.0 - mOp.weight.value = AWQ_quantize_pack_preprocess(split_v, split_scale) - mOp.weights_scaling_factor.value = split_scale.cpu().numpy() - mOp.prequant_scaling_factor.value = pre_quant_scale.cpu().numpy() - - def process_and_assign_weight(model_params, mPrefix, mOp, tp_dim=0): - weight = model_params[mPrefix + ".weight"].T.contiguous() - [k, n] = weight.shape - weight = torch_split(weight, tp_dim) - amax = model_params[mPrefix + ".weight_quantizer._amax"].reshape( - (n, int(k / group_size))).T.contiguous() - amax = torch_split(amax, tp_dim) - pre_quant_scale = model_params[ - mPrefix + ".input_quantizer._pre_quant_scale"].reshape((1, k)) - if tp_dim == 0: - pre_quant_scale = torch_split(pre_quant_scale, 1) - scale = amax / 8.0 - mOp.weight.value = AWQ_quantize_pack_preprocess(weight, scale) - mOp.weights_scaling_factor.value = scale.to(torch_dtype).cpu().numpy() - mOp.prequant_scaling_factor.value = pre_quant_scale.to( - torch_dtype).cpu().numpy() - - # Check if we need to pad vocab - v = model_params.get('transformer.wte.weight') - [vocab_size, k] = v.shape - pad_vocab = False - pad_vocab_size1 = vocab_size - if quantize_lm_head and vocab_size % 64 != 0: - pad_vocab = True - pad_vocab_size1 = int((vocab_size + 63) / 64) * 64 - if pad_vocab: - new_v = torch.zeros([pad_vocab_size1, k]) - new_v[:vocab_size, :] = v - v = 
new_v - if mapping.is_first_pp_rank(): - tensorrt_llm_qwen.embedding.vocab_embedding.weight.value = v.to( - torch_dtype).cpu().numpy() - - layer_ids = [extract_layer_idx(key) for key in model_params.keys()] - layer_ids = [ - int(layer_idx) for layer_idx in layer_ids if layer_idx is not None - ] - - num_hidden_layers = max(layer_ids) + 1 - layers_range = mapping.pp_layers(num_hidden_layers) - for layer_idx in tqdm(layers_range, "Loading weights..."): - prefix = "transformer.h." + str(layer_idx) + "." - for idx, awq_attr in enumerate(awq_block_names): - v = model_params[prefix + awq_attr] - layer = attrgetter(tensorrt_llm_block_names[idx])( - tensorrt_llm_qwen.layers[layer_idx]) - setattr(layer, 'value', v.to(torch_dtype).cpu().numpy()) - - mPrefix = prefix + "attn.c_attn" - mOp = tensorrt_llm_qwen.layers[layer_idx].attention.qkv - process_and_assign_attn_weight(model_params, mPrefix, mOp, 1) - - # Attention QKV Liner Bias - th_bias = model_params[prefix + "attn.c_attn.bias"].cpu().to( - torch_dtype).contiguous() - q_emb = th_bias.shape[0] // 3 - th_bias = th_bias.reshape(3, q_emb) - split_v = split(th_bias, mapping.tp_size, mapping.rank, dim=1) - split_v = split_v.reshape(3 * (q_emb // mapping.tp_size)) - tensorrt_llm_qwen.layers[ - layer_idx].attention.qkv.bias.value = np.ascontiguousarray(split_v) - - # Attention Dense (out_proj) Linear - mPrefix = prefix + "attn.c_proj" - mOp = tensorrt_llm_qwen.layers[layer_idx].attention.dense - process_and_assign_weight(model_params, mPrefix, mOp, 0) - - # MLP down_proj (mlp.gate) Linear - mPrefix = prefix + "mlp.w1" - mOp = tensorrt_llm_qwen.layers[layer_idx].mlp.gate - process_and_assign_weight(model_params, mPrefix, mOp, 1) - - # MLP up_proj (mlp.fc) Linear - mPrefix = prefix + "mlp.w2" - mOp = tensorrt_llm_qwen.layers[layer_idx].mlp.fc - process_and_assign_weight(model_params, mPrefix, mOp, 1) - - # MLP gate_proj (mlp.proj) Linear - mPrefix = prefix + "mlp.c_proj" - mOp = tensorrt_llm_qwen.layers[layer_idx].mlp.proj - process_and_assign_weight(model_params, mPrefix, mOp, 0) - - v = model_params['transformer.ln_f.weight'] - if mapping.is_last_pp_rank(): - tensorrt_llm_qwen.ln_f.weight.value = v.to(torch_dtype).cpu().numpy() - - # lm_head - if pad_vocab: - weight = model_params['lm_head.weight'] - [vocab_size, k] = weight.shape - new_weight = torch.zeros([pad_vocab_size1, k]) - new_weight[:vocab_size, :] = weight - new_weight = new_weight.T.contiguous() - amax = model_params['lm_head.weight_quantizer._amax'].reshape( - [vocab_size, k // group_size]) - new_amax = torch.ones([pad_vocab_size1, k // group_size]) - new_amax[:vocab_size, :] = amax - new_amax = new_amax.T.contiguous() - new_scale = new_amax / 8 - if mapping.is_last_pp_rank(): - tensorrt_llm_qwen.lm_head.weight.value = AWQ_quantize_pack_preprocess( - new_weight, new_scale) - tensorrt_llm_qwen.lm_head.weights_scaling_factor.value = new_scale.to( - torch_dtype).cpu().numpy() - tensorrt_llm_qwen.lm_head.prequant_scaling_factor.value = model_params[ - 'lm_head.input_quantizer._pre_quant_scale'].to( - torch_dtype).cpu().numpy() - elif quantize_lm_head: - mPrefix = "lm_head" - mOp = tensorrt_llm_qwen.lm_head - if mapping.is_last_pp_rank(): - process_and_assign_weight(model_params, mPrefix, mOp, 1) - else: - if mapping.is_last_pp_rank(): - tensorrt_llm_qwen.lm_head.weight.value = torch_split( - model_params['lm_head.weight'], - 0).to(torch_dtype).cpu().numpy() - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - tensorrt_llm.logger.info(f'Weights loaded. 
Total time: {t}') - if quant_ckpt_path.endswith(".safetensors"): - del groupwise_qweight_safetensors - del model_params - import gc - gc.collect() - return diff --git a/examples/qwenvl/README.md b/examples/qwenvl/README.md index a517990e1..d0e3e7c66 100644 --- a/examples/qwenvl/README.md +++ b/examples/qwenvl/README.md @@ -17,43 +17,30 @@ ``` This command saves the test image tensor to `image.pt` for later pipeline inference. -3. Build INT4-GPTQ Qwen TensorRT engine. -- Quantize the weights to INT4 with GPTQ +3. Build Qwen TensorRT engine. +- Convert checkpoint 1. Install packages ```bash pip install -r requirements.txt ``` - 2. Weight quantization to INT4 with GPTQ + 2. Convert ```bash - python3 gptq_convert.py --pretrained_model_dir ./Qwen-VL-Chat \ - --quantized_model_dir ./Qwen-VL-Chat-4bit + python3 ../qwen/convert_checkpoint.py --model_dir=./Qwen-VL-Chat \ + --output_dir=./tllm_checkpoint_1gpu ``` - Build TensorRT-LLM engine NOTE: `max_prompt_embedding_table_size = query_token_num * max_batch_size`, therefore, if you change `max_batch_size`, `--max_prompt_embedding_table_size` must be reset accordingly. ```bash - python3 ../qwen/build.py --model_dir=Qwen-VL-Chat \ - --quant_ckpt_path=./Qwen-VL-Chat-4bit/gptq_model-4bit-128g.safetensors \ - --dtype float16 \ - --max_batch_size 8 \ - --max_input_len 2048 \ - --max_output_len 1024 \ - --remove_input_padding \ - --use_gpt_attention_plugin float16 \ - --use_gemm_plugin float16 \ - --use_weight_only \ - --weight_only_precision int4_gptq \ - --per_group \ - --enable_context_fmha \ - --log_level verbose \ - --use_lookup_plugin float16 \ - --max_prompt_embedding_table_size 2048 \ - --output_dir=./trt_engines/Qwen-VL-7B-Chat-int4-gptq - - # --max_prompt_embedding_table_size 2048 = 256 (query_token number) * 8 (max_batch_size) + trtllm-build --checkpoint_dir=./tllm_checkpoint_1gpu \ + --gemm_plugin=float16 --gpt_attention_plugin=float16 \ + --lookup_plugin=float16 --max_input_len=2048 --max_output_len=1024 \ + --max_batch_size=8 --max_prompt_embedding_table_size=2048 \ + --remove_input_padding=enable \ + --output_dir=./trt_engines/Qwen-VL-7B-Chat ``` - The built Qwen engines are located in `./trt_engines/Qwen-VL-7B-Chat-int4-gptq`. + The built Qwen engines are located in `./trt_engines/Qwen-VL-7B-Chat`. For more information about Qwen, refer to the README.md in [`example/qwen`](../qwen). 4. Assemble everything into the Qwen-VL pipeline. 
@@ -62,7 +49,7 @@ ```bash python3 run.py \ --tokenizer_dir=./Qwen-VL-Chat \ - --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat-int4-gptq \ + --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat \ --vit_engine_dir=./plan \ --images_path='{"image": "./pics/demo.jpeg"}' \ --input_dir='{"image": "image.pt"}' @@ -71,7 +58,7 @@ ```bash python3 run_chat.py \ --tokenizer_dir=./Qwen-VL-Chat \ - --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat-int4-gptq \ + --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat \ --vit_engine_dir=./plan \ --images_path='{"image": "./pics/demo.jpeg"}' \ --input_dir='{"image": "image.pt"}' @@ -97,7 +84,7 @@ ```bash python3 run_chat.py \ --tokenizer_dir=./Qwen-VL-Chat \ - --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat-int4-gptq \ + --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat \ --vit_engine_dir=./plan \ --display \ --port=8006 @@ -110,7 +97,7 @@ ```bash python3 run_chat.py \ --tokenizer_dir=./Qwen-VL-Chat \ - --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat-int4-gptq \ + --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat \ --vit_engine_dir=./plan \ --display \ --local_machine diff --git a/examples/qwenvl/gptq_convert.py b/examples/qwenvl/gptq_convert.py deleted file mode 100644 index f1626c327..000000000 --- a/examples/qwenvl/gptq_convert.py +++ /dev/null @@ -1,62 +0,0 @@ -import argparse -import logging - -import torch -from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig -from transformers import AutoTokenizer - - -def parse_arguments(): - - parser = argparse.ArgumentParser() - parser.add_argument( - "--pretrained_model_dir", - type=str, - default="Qwen-VL-Chat", - ) - parser.add_argument( - "--quantized_model_dir", - type=str, - default="Qwen-VL-Chat-4bit", - ) - - args = parser.parse_args() - return args - - -args = parse_arguments() - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -logging.basicConfig(format="%(asctime)s %(levelname)s [%(name)s] %(message)s", - level=logging.INFO, - datefmt="%Y-%m-%d %H:%M:%S") - -tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_dir, - use_fast=True, - trust_remote_code=True) -examples = [ - tokenizer( - "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm.", - return_tensors="pt").to(device) -] - -quantize_config = BaseQuantizeConfig( - bits=4, # quantize model to 4-bit - group_size=128, # it is recommended to set the value to 128 - desc_act= - False, # set to False can significantly speed up inference but the perplexity may slightly bad -) - -# load un-quantized model, by default, the model will always be loaded into CPU memory -model = AutoGPTQForCausalLM.from_pretrained(args.pretrained_model_dir, - quantize_config, - trust_remote_code=True, - low_cpu_mem_usage=True, - device_map=device, - fp16=True) - -# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask" -model.quantize(examples) - -# save quantized model using safetensors -model.save_quantized(args.quantized_model_dir, use_safetensors=True) diff --git a/examples/qwenvl/requirements.txt b/examples/qwenvl/requirements.txt index f7422ffb5..e6da08316 100644 --- a/examples/qwenvl/requirements.txt +++ b/examples/qwenvl/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/qwenvl/run.py b/examples/qwenvl/run.py index f7c8c000c..712ab6121 100644 --- a/examples/qwenvl/run.py +++ 
b/examples/qwenvl/run.py @@ -29,11 +29,8 @@ TensorInfo) -def get_engine_name(model, dtype, tp_size, pp_size, rank): - if pp_size == 1: - return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank) - return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size, - pp_size, rank) +def get_engine_name(rank): + return 'rank{}.engine'.format(rank) def trt_dtype_to_torch(dtype): @@ -89,33 +86,40 @@ def get_model(self): else: raise Exception("unknown chat format ", chat_format) - use_gpt_attention_plugin = config['plugin_config'][ + use_gpt_attention_plugin = config['build_config']['plugin_config'][ 'gpt_attention_plugin'] - remove_input_padding = config['plugin_config']['remove_input_padding'] - dtype = config['builder_config']['precision'] - tp_size = config['builder_config']['tensor_parallel'] - pp_size = config['builder_config']['pipeline_parallel'] + remove_input_padding = config['build_config']['plugin_config'][ + 'remove_input_padding'] + dtype = config['pretrained_config']['dtype'] + tp_size = config['pretrained_config']['mapping']['tp_size'] + pp_size = config['pretrained_config']['mapping']['pp_size'] world_size = tp_size * pp_size assert world_size == tensorrt_llm.mpi_world_size(), \ f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})' - num_heads = config['builder_config']['num_heads'] // world_size - max_batch_size = config['builder_config']['max_batch_size'] - hidden_size = config['builder_config']['hidden_size'] // world_size - vocab_size = config['builder_config']['vocab_size'] - num_layers = config['builder_config']['num_layers'] - num_kv_heads = config['builder_config'].get('num_kv_heads', num_heads) - paged_kv_cache = config['plugin_config']['paged_kv_cache'] - tokens_per_block = config['plugin_config']['tokens_per_block'] - max_prompt_embedding_table_size = config['builder_config'].get( + num_heads = config['pretrained_config'][ + 'num_attention_heads'] // world_size + max_batch_size = config['build_config']['max_batch_size'] + hidden_size = config['pretrained_config']['hidden_size'] // world_size + vocab_size = config['pretrained_config']['vocab_size'] + num_layers = config['pretrained_config']['num_hidden_layers'] + num_kv_heads = config['pretrained_config'].get('num_key_value_heads', + num_heads) + paged_kv_cache = config['build_config']['plugin_config'][ + 'paged_kv_cache'] + tokens_per_block = config['build_config']['plugin_config'][ + 'tokens_per_block'] + max_prompt_embedding_table_size = config['build_config'].get( 'max_prompt_embedding_table_size', 0) - quant_mode = QuantMode(config['builder_config']['quant_mode']) - if config['builder_config'].get('multi_query_mode', False): + quant_mode = QuantMode.from_quant_algo( + config['pretrained_config']['quantization']['quant_algo'], + config['pretrained_config']['quantization']['kv_cache_quant_algo']) + if config['pretrained_config'].get('multi_query_mode', False): tensorrt_llm.logger.warning( "`multi_query_mode` config is deprecated. Please rebuild the engine." 
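The `run.py` hunk above switches to the unified engine `config.json`, whose settings moved from the old `builder_config`/`plugin_config` sections into `pretrained_config` and `build_config`. A hedged sketch of reading the fields touched in that hunk (key names are taken from the diff; error handling omitted):

```python
import json
import os

def read_engine_config(engine_dir: str) -> dict:
    """Pull the handful of fields the Qwen-VL runner needs from the new-style config.json."""
    with open(os.path.join(engine_dir, "config.json")) as f:
        config = json.load(f)
    pretrained = config["pretrained_config"]
    build = config["build_config"]
    return {
        "dtype": pretrained["dtype"],
        "world_size": pretrained["mapping"]["tp_size"] * pretrained["mapping"]["pp_size"],
        "num_attention_heads": pretrained["num_attention_heads"],
        "remove_input_padding": build["plugin_config"]["remove_input_padding"],
        "max_batch_size": build["max_batch_size"],
    }
```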
) num_kv_heads = 1 # num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size - use_custom_all_reduce = config['plugin_config'].get( + use_custom_all_reduce = config['build_config']['plugin_config'].get( 'use_custom_all_reduce', False) runtime_rank = tensorrt_llm.mpi_rank() @@ -150,8 +154,7 @@ def get_model(self): temperature=1.0, ) - engine_name = get_engine_name('qwen', dtype, tp_size, pp_size, - runtime_rank) + engine_name = get_engine_name(runtime_rank) serialize_path = os.path.join(self.qwen_engine_dir, engine_name) print(f'Loading engine from {serialize_path}') return (model_config, sampling_config, runtime_mapping, runtime_rank, diff --git a/examples/run.py b/examples/run.py index e8d3563fb..06d5735c6 100644 --- a/examples/run.py +++ b/examples/run.py @@ -221,13 +221,12 @@ def parse_input(tokenizer, elif input_file.endswith('.txt'): with open(input_file, 'r', encoding='utf-8', errors='replace') as txt_file: - input_text = txt_file.read() - input_ids = tokenizer.encode( + input_text = txt_file.readlines() + batch_input_ids = tokenizer( input_text, add_special_tokens=add_special_tokens, truncation=True, - max_length=max_input_length) - batch_input_ids.append(input_ids) + max_length=max_input_length)["input_ids"] else: print('Input file format not supported.') raise SystemExit diff --git a/examples/server/README.md b/examples/server/README.md deleted file mode 100644 index cd027012d..000000000 --- a/examples/server/README.md +++ /dev/null @@ -1,25 +0,0 @@ -# Asynchronous generation in python - -## Install the requirements - -` pip install -r examples/server/requirements.txt` - -## Directly from python, with the HL API - -Due to limitation from the HLAPI implementation, currently only LLaMA models are supported: -`python3 examples/server/async.py ` - - -## Using the server interface for TensorRT-LLM - -### Start the server - -`python3 -m examples.server.server &` - -### Send requests - -You can pass request arguments like "max_new_tokens", "top_p", "top_k" in your JSON dict: -`curl http://localhost:8000/generate -d '{"prompt": "In this example,", "max_new_tokens": 8}'` - -You can also use the streaming interface with: -`curl http://localhost:8000/generate -d '{"prompt": "In this example,", "max_new_tokens": 8, "streaming": true}' --output -` diff --git a/examples/server/async.py b/examples/server/async.py deleted file mode 100644 index 61508664f..000000000 --- a/examples/server/async.py +++ /dev/null @@ -1,20 +0,0 @@ -import argparse -from asyncio import run -from pathlib import Path - -from executor import GenerationExecutor - - -async def main(model_dir: Path, tokenizer: Path | str): - engine = GenerationExecutor(model_dir, tokenizer) - text = "deep learning is" - async for response in engine.generate(prompt=text, max_new_tokens=16): - print(response.text) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("model_dir", type=Path) - parser.add_argument("tokenizer", type=Path) - args = parser.parse_args() - run(main(args.model_dir, args.tokenizer)) diff --git a/examples/server/requirements.txt b/examples/server/requirements.txt deleted file mode 100644 index 606faaeea..000000000 --- a/examples/server/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ ---extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 -uvicorn -fastapi diff --git a/examples/skywork/requirements.txt b/examples/skywork/requirements.txt index 626b31531..87da98aac 100644 --- a/examples/skywork/requirements.txt +++ b/examples/skywork/requirements.txt @@ -1,5 +1,5 @@ 
--extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 datasets~=2.16.1 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/summarize.py b/examples/summarize.py index 142d5a4b8..33bf3e24b 100644 --- a/examples/summarize.py +++ b/examples/summarize.py @@ -21,7 +21,7 @@ import numpy as np import torch from datasets import load_dataset -from qwen.utils.utils import make_context +from qwen.convert_checkpoint import make_context from transformers import (AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig) from utils import DEFAULT_HF_MODEL_DIRS, load_tokenizer, read_model_name @@ -147,7 +147,7 @@ def _prepare_inputs(batch_input_texts, input_ids = tokenizer.encode(curr_text, return_tensors='pt').squeeze(0) input_ids = input_ids[:test_token_num] - elif model_name == 'qwen': + elif model_name == 'QWenForCausalLM': # use make_content to generate prompt system_prompt = "You are a useful assistant, please directly output the corresponding summary according to the article entered by the user." _, input_id_list = make_context( @@ -274,6 +274,12 @@ def eval_hf(datapoint, batch_input_ids = torch.stack(batch_input_ids) batch_input_ids = batch_input_ids.cuda() + # specialization for HF + if early_stopping in [0, 1]: + local_early_stopping = bool(early_stopping) + else: + local_early_stopping = "never" + with torch.no_grad(): outputs = model.generate(batch_input_ids, max_new_tokens=output_len, @@ -284,7 +290,7 @@ def eval_hf(datapoint, num_beams=num_beams, num_return_sequences=num_beams, length_penalty=length_penalty, - early_stopping=early_stopping, + early_stopping=local_early_stopping, output_scores=True, return_dict_in_generate=True) if eval_ppl and batch_size == 1: diff --git a/examples/utils.py b/examples/utils.py index f5bee0f1a..7917ec0ba 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -27,7 +27,7 @@ 'BloomForCausalLM': 'bigscience/bloom-560m', 'ChatGLMForCausalLM': 'THUDM/chatglm3-6b', 'FalconForCausalLM': 'tiiuae/falcon-rw-1b', - 'gpt': 'gpt2-medium', + 'GPTForCausalLM': 'gpt2-medium', 'GPTJForCausalLM': 'EleutherAI/gpt-j-6b', 'GPTNeoXForCausalLM': 'EleutherAI/gpt-neox-20b', 'InternLMForCausalLM': 'internlm/internlm-chat-7b', @@ -35,13 +35,13 @@ 'MPTForCausalLM': 'mosaicml/mpt-7b', 'PhiForCausalLM': 'microsoft/phi-2', 'OPTForCausalLM': 'facebook/opt-350m', - 'qwen': 'Qwen/Qwen-7B', + 'QWenForCausalLM': 'Qwen/Qwen-7B', } DEFAULT_PROMPT_TEMPLATES = { 'InternLMForCausalLM': "<|User|>:{input_text}\n<|Bot|>:", - 'qwen': + 'QWenForCausalLM': "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n", } @@ -73,7 +73,7 @@ def throttle_generator(generator, stream_interval): def load_tokenizer(tokenizer_dir: Optional[str] = None, vocab_file: Optional[str] = None, - model_name: str = 'gpt', + model_name: str = 'GPTForCausalLM', model_version: Optional[str] = None, tokenizer_type: Optional[str] = None): if vocab_file is None: @@ -103,7 +103,7 @@ def load_tokenizer(tokenizer_dir: Optional[str] = None, truncation_side='left', legacy=False) - if model_name == 'qwen': + if model_name == 'QWenForCausalLM': with open(Path(tokenizer_dir) / "generation_config.json") as f: gen_config = json.load(f) chat_format = gen_config['chat_format'] diff --git a/examples/whisper/requirements.txt b/examples/whisper/requirements.txt index 67ec9df76..2d148a4f0 100644 --- a/examples/whisper/requirements.txt +++ b/examples/whisper/requirements.txt @@ -1,5 +1,5 @@ 
--extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.9.0.dev2024031200 +tensorrt_llm==0.9.0.dev2024031900 tiktoken datasets kaldialign diff --git a/examples/whisper/run.py b/examples/whisper/run.py index cbeb13a85..5f980754f 100644 --- a/examples/whisper/run.py +++ b/examples/whisper/run.py @@ -413,7 +413,7 @@ def decode_dataset( results, enable_log=True) if args.accuracy_check and args.dataset == "hf-internal-testing/librispeech_asr_dummy" and not args.input_file: - assert total_error_rate <= 2.5, f"Word Error rate using whisper large-v3 model should be less than 2.5% but got {total_error_rate}" + assert total_error_rate <= 2.8, f"Word Error rate using whisper large-v3 model should be 2.40%, but got {total_error_rate}" rtf = elapsed / total_duration s = f"RTF: {rtf:.4f}\n" diff --git a/requirements.txt b/requirements.txt index 7cd6d86e4..4a2e757c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,10 +16,13 @@ pandas h5py sentencepiece>=0.1.99 tensorrt==9.3.0.post12.dev1 -torch>=2.1.0a,<=2.2.0a # https://github.com/pytorch/pytorch/blob/v2.1.2/version.txt still uses 2.1.0a0. -nvidia-ammo~=0.7.0; platform_machine=="x86_64" +# https://github.com/pytorch/pytorch/blob/v2.2.1/version.txt still uses 2.2.0a0. +# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html#rel-24-02 uses 2.3.0a0. +torch>=2.2.0a,<=2.3.0a +nvidia-ammo~=0.7.0 transformers==4.38.2 wheel optimum evaluate janus +setuptools diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index bf6c43edb..460d3403c 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -13,11 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy +import gc import json import math import struct import tarfile import weakref +from dataclasses import asdict from functools import partial from pathlib import Path, PosixPath from typing import Any, Dict, List, Optional, Union @@ -49,11 +51,11 @@ def torch_to_numpy(x: torch.Tensor): def numpy_to_torch(x): if x.dtype == np_bfloat16: - return torch.tensor(x.view(np.int16)).view(torch.bfloat16) + return torch.from_numpy(x.view(np.int16)).view(torch.bfloat16) elif x.dtype == np_float8: - return torch.tensor(x.view(np.int8)).view(torch.float8_e4m3fn) + return torch.from_numpy(x.view(np.int8)).view(torch.float8_e4m3fn) else: - return torch.tensor(x) + return torch.from_numpy(x) def numpy_to_dtype(x, dtype: str): @@ -424,3 +426,45 @@ def set_obj_attrs( assert not hasattr( obj, key), (f"Overwriting existing tensor attribute: {key}") setattr(obj, key, value) + + +def release_gc(): + ''' Release memory allocated by PyTorch and Python garbage collector explicitly and immediately. + This could be used when some states might be kept in memory even after the variables are deleted. 
+ ''' + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +class DictConversion: + + @classmethod + def from_dict(cls, config: Dict[str, Any]): + obj = cls() + fields = obj.__dataclass_fields__ + for key, value in config.items(): + assert hasattr(obj, key) + field_cls = fields[key].type + if (isinstance(field_cls, type) + and issubclass(field_cls, DictConversion) + and isinstance(value, dict)): + value = field_cls.from_dict(value) + setattr(obj, key, value) + return obj + + def to_dict(self): + return asdict(self) + + @classmethod + def from_json_file(cls, file): + with open(file) as f: + return cls.from_dict(json.load(f)) + + def set_defaults(self, **kwargs): + for key, default in kwargs.items(): + value = getattr(self, key) + if (value is None + or (isinstance(value, (list, dict)) and len(value) == 0)): + setattr(self, key, default) diff --git a/tensorrt_llm/auto_parallel/config.py b/tensorrt_llm/auto_parallel/config.py index b54410165..ed5ffd3bb 100644 --- a/tensorrt_llm/auto_parallel/config.py +++ b/tensorrt_llm/auto_parallel/config.py @@ -1,42 +1,12 @@ -import json -from dataclasses import asdict, dataclass, field -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Union import torch -from .utils import BaseEnum - - -class DictConversion: - - @classmethod - def from_dict(cls, config: Dict[str, Any]): - obj = cls() - fields = obj.__dataclass_fields__ - for key, value in config.items(): - assert hasattr(obj, key) - field_cls = fields[key].type - if (isinstance(field_cls, type) - and issubclass(field_cls, DictConversion) - and isinstance(value, dict)): - value = field_cls.from_dict(value) - setattr(obj, key, value) - return obj +from tensorrt_llm._utils import DictConversion +from tensorrt_llm.logger import logger - def to_dict(self): - return asdict(self) - - @classmethod - def from_json_file(cls, file): - with open(file) as f: - return cls.from_dict(json.load(f)) - - def set_defaults(self, **kwargs): - for key, default in kwargs.items(): - value = getattr(self, key) - if (value is None - or (isinstance(value, (list, dict)) and len(value) == 0)): - setattr(self, key, default) +from .utils import BaseEnum @dataclass @@ -337,7 +307,11 @@ def is_32gb(): return "V100-PCIe-32GB" else: return "V100-PCIe-16GB" - return None + + fallback_key = "A100-SXM-80GB" + logger.warning( + f"Fail to infer cluster key, use {fallback_key} as fallback.") + return fallback_key class CostModel(str, BaseEnum): diff --git a/tensorrt_llm/auto_parallel/parallelization.py b/tensorrt_llm/auto_parallel/parallelization.py index cc598befa..30dac713b 100644 --- a/tensorrt_llm/auto_parallel/parallelization.py +++ b/tensorrt_llm/auto_parallel/parallelization.py @@ -1,7 +1,7 @@ import contextlib import copy import itertools -import pickle +import pickle # nosec B403 import re from abc import ABC, abstractmethod from dataclasses import dataclass, field @@ -12,7 +12,8 @@ import torch from filelock import FileLock -from tensorrt_llm._utils import trt_dtype_to_np, trt_dtype_to_torch +from tensorrt_llm._utils import (preview_trt_version, trt_dtype_to_np, + trt_dtype_to_torch) from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping @@ -38,6 +39,8 @@ from .utils import (get_builder_flags, get_updated_plugin, to_base_class_layer, to_subclass_layer, to_trt_weights) +default_int_dtype = trt.int64 if preview_trt_version() 
else trt.int32 + @dataclass class ParallelConfig: @@ -59,7 +62,7 @@ def save(self, filename): @staticmethod def from_file(filename) -> "ParallelConfig": with open(filename, "rb") as file: - return pickle.load(file) + return pickle.load(file) # nosec B301 def print_graph_strategy(self, file=None): for index, (node_name, @@ -632,8 +635,8 @@ def cast(self, network, tensor, dtype, layer_info): return cast_layer.get_output(0) def const_int(self, network, name, value, layer_info): - const_layer = network.add_constant([1], np.array([value], - dtype=np.int32)) + const_layer = network.add_constant( + [1], np.array([value], dtype=trt_dtype_to_np(default_int_dtype))) self.register_layer(const_layer, name, *layer_info) return const_layer.get_output(0) @@ -954,10 +957,8 @@ def update_shape(self, context: ShardContext): [1]) else: input_dim = split_infos[dim].input_dim - output_dim_layer = network.add_constant([1], - np.array( - [input_dim], - dtype=np.int32)) + output_dim_layer = network.add_constant( + [1], np.array([input_dim], dtype=default_int_dtype)) self.register_layer(output_dim_layer, f"output_dim{dim}", *layer_info) output_dims.append(output_dim_layer.get_output(0)) @@ -1003,7 +1004,7 @@ def shard_slice(self, context: ShardContext): *layer_info) output_dim = self.cast(network, quotient_layer.get_output(0), - trt.DataType.INT32, layer_info) + default_int_dtype, layer_info) output_dims.append(output_dim) else: output_dims.append(output_dim_layer.get_output(0)) @@ -1071,7 +1072,7 @@ def shard_shuffle(self, context: ShardContext): updated_reshape_dims[reshape_dim] = self.cast( network, quotient_layer.get_output(0), - trt.DataType.INT32, + default_int_dtype, layer_info, ) else: diff --git a/tensorrt_llm/auto_parallel/pipeline_graph.py b/tensorrt_llm/auto_parallel/pipeline_graph.py index f021334a7..e97f09973 100644 --- a/tensorrt_llm/auto_parallel/pipeline_graph.py +++ b/tensorrt_llm/auto_parallel/pipeline_graph.py @@ -511,7 +511,9 @@ def get_input(i): if layer.precision_is_set: new_layer.precision = layer.precision for i in range(layer.num_outputs): - if layer.output_type_is_set(i): + # TODO: Remove WAR for shape layer after https://nvbugs/4557631 fixed. + if layer.output_type_is_set( + i) and layer_type != trt.LayerType.SHAPE: new_layer.set_output_type(i, layer.get_output_type(i)) output = new_layer.get_output(i) self._add_tensor(output, layer.get_output(i), prefix) diff --git a/tensorrt_llm/auto_parallel/shape_info.py b/tensorrt_llm/auto_parallel/shape_info.py index 1034d422b..b2542be3e 100644 --- a/tensorrt_llm/auto_parallel/shape_info.py +++ b/tensorrt_llm/auto_parallel/shape_info.py @@ -21,6 +21,12 @@ class ShapeType(Enum): MAX = 2 +_trt_to_type_dict = { + trt.int64: int, + trt.bool: bool, +} + + def get_shape_layers(trt_network): shape_layers = set() for i in range(trt_network.num_layers): @@ -95,16 +101,19 @@ def get_shape_network(trt_network, new_layer = shape_graph.add_layer(layer) for i in range(layer.num_outputs): output = layer.get_output(i) - if output.dtype != trt.DataType.BOOL: - shape_graph.add_output_shape(output) - else: - proxy_layer = shape_network.add_identity( - new_layer.as_trt().get_output(i)) + # TODO: Remove WAR for INT64 after https://nvbugs/4557631 fixed. 
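In `shape_info.py` above, BOOL and INT64 shape outputs are routed through an INT32 cast layer as a workaround, and the original dtype is recorded in `output_mapping` so the inferred values can be converted back afterwards. A self-contained sketch of that restoration step (names are illustrative, not the module's API):

```python
# proxy_name -> (original_name, original_dtype_key); mirrors the (name, dtype) tuples above
trt_to_python_type = {"int64": int, "bool": bool}

def restore_proxy_values(values: dict, output_mapping: dict) -> None:
    """Rename proxy outputs back to their original tensors and recover their element type."""
    for proxy_name, (orig_name, dtype_key) in output_mapping.items():
        values[orig_name] = [trt_to_python_type[dtype_key](v) for v in values.pop(proxy_name)]

vals = {"shape_proxy": [0, 1, 1]}
restore_proxy_values(vals, {"shape_proxy": ("is_padded", "bool")})
assert vals == {"is_padded": [False, True, True]}
```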
+ if output.dtype in [trt.DataType.BOOL, trt.DataType.INT64]: + proxy_layer = shape_network.add_cast( + new_layer.as_trt().get_output(i), + trt.DataType.INT32, + ) proxy_output = proxy_layer.get_output(0) - proxy_output.dtype = trt.DataType.INT32 shape_graph.register_layer(proxy_layer) shape_graph.add_output_shape(proxy_output) - output_mapping[proxy_output.name] = output.name + output_mapping[proxy_output.name] = (output.name, + output.dtype) + else: + shape_graph.add_output_shape(output) elif layer.name in layers_in_shape_network: if layer.type == trt.LayerType.CONSTANT: shape_graph.add_input(layer.get_output(0)) @@ -161,14 +170,15 @@ def get_per_layer_graph( else: is_output_shape = False if is_output_shape: - if output.dtype == trt.DataType.BOOL: + # TODO: Remove WAR for INT64 after https://nvbugs/4557631 fixed. + if output.dtype in [trt.DataType.BOOL, trt.DataType.INT64]: proxy_layer = network.add_cast( new_layer.as_trt().get_output(i), trt.DataType.INT32, ) proxy_output = proxy_layer.get_output(0) graph.register_layer(proxy_layer) - output_mapping[proxy_output.name] = output.name + output_mapping[proxy_output.name] = (output.name, output.dtype) output = proxy_output graph.add_output_shape(output) else: @@ -198,19 +208,27 @@ def infer_shapes(network, shapes, values, profile=None): if input.is_shape_tensor: value = values[input.name] context.set_shape_input(engine[input.name], value) - context.infer_shapes() - assert context.all_binding_shapes_specified for i in range(network.num_outputs): output = network.get_output(i) shape = context.get_tensor_shape(output.name) - # if len(shape) == 0: - # shape = trt.Dims([1]) shapes[output.name] = shape if output.is_shape_tensor: if shape == [0]: values[output.name] = [] else: - values[output.name] = context.get_shape(engine[output.name]) + if shape == []: + shape = [1] + value = torch.empty(list(shape), + dtype=torch.int32, + device="cpu") + values[output.name] = value + context.set_tensor_address(output.name, value.data_ptr()) + context.infer_shapes() + assert context.all_binding_shapes_specified + for i in range(network.num_outputs): + output = network.get_output(i) + if isinstance(values.get(output.name), torch.Tensor): + values[output.name] = values[output.name].tolist() @dataclass @@ -280,11 +298,13 @@ def infer_per_layer_shapes( f"values={list(cache_key[3])}") raise RuntimeError( f"infer shapes failed for layer {layer.name} ({layer_info})") from e - for proxy_output, output in output_mapping.items(): + for proxy_output, (output, dtype) in output_mapping.items(): shapes[output] = shapes[proxy_output] del shapes[proxy_output] if proxy_output in values: - values[output] = [*map(bool, values[proxy_output])] + values[output] = [ + *map(_trt_to_type_dict[dtype], values[proxy_output]) + ] del values[proxy_output] if cache is not None: logger.debug( @@ -314,9 +334,11 @@ def get_shape_info(trt_network, profile, shape_type: ShapeType = ShapeType.OPT): shape_type=shape_type) try: infer_shapes(shape_network, shapes, values, shape_profile) - for proxy_output, output in output_mapping.items(): + for proxy_output, (output, dtype) in output_mapping.items(): shapes[output] = shapes[proxy_output] - values[output] = [*map(bool, values[proxy_output])] + values[output] = [ + *map(_trt_to_type_dict[dtype], values[proxy_output]) + ] del shapes[proxy_output] del values[proxy_output] except RuntimeError: diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index d7bd19563..a6f5cc851 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ 
-16,6 +16,7 @@ import json import math import os +import shutil import time from dataclasses import dataclass from pathlib import Path @@ -23,13 +24,14 @@ import tensorrt as trt -from ._common import _is_building +from ._common import _is_building, serialize_engine from ._utils import (str_dtype_to_trt, support_strongly_type, to_dict, to_json_file) from .auto_parallel import auto_parallel from .auto_parallel.config import AutoParallelConfig from .graph_rewriting import optimize from .logger import logger +from .lora_manager import LoraBuildConfig from .models import PretrainedConfig, PretrainedModel from .models.modeling_utils import optimize_model from .network import Network, net_guard @@ -154,7 +156,7 @@ def create_builder_config(self, if use_refit and int8: # TRT folds weights into Myelin graph because network contains int8 tensor or Q/DQ nodes # These folded weights can not be refitted - logger.error(f"can't use refit and int8 mode at the same time") + logger.error("can't use refit and int8 mode at the same time") config = self.trt_builder.create_builder_config() if not self.strongly_typed: @@ -299,7 +301,7 @@ def refit_engine(self, network: Network, engine_buffer) -> trt.IHostMemory: @return: A serialized TRT engine if refit successfully, None otherwise ''' assert isinstance(network, Network) - logger.info(f'Refit TRT engine') + logger.info('Refit TRT engine') runtime = trt.Runtime(logger.trt_logger) engine = runtime.deserialize_cuda_engine(engine_buffer) @@ -316,12 +318,11 @@ def refit_engine(self, network: Network, engine_buffer) -> trt.IHostMemory: return None else: logger.error( - f'Please set named parameters before building multiple engines.' - ) + 'Please set named parameters before building multiple engines.') return None if not refitter.refit_cuda_engine(): - logger.error(f'Failed to refit engine.') + logger.error('Failed to refit engine.') return None tok = time.time() @@ -420,6 +421,9 @@ class BuildConfig: enable_debug_output: bool = False max_draft_len: int = 0 use_refit: bool = False + input_timing_cache: str = None + output_timing_cache: str = None + lora_config: LoraBuildConfig = LoraBuildConfig() auto_parallel_config: AutoParallelConfig = AutoParallelConfig() plugin_config: PluginConfig = PluginConfig() @@ -441,12 +445,11 @@ def from_dict(cls, config, plugin_config=None): enable_debug_output = config.pop('enable_debug_output', False) max_draft_len = config.pop('max_draft_len', 0) use_refit = config.pop('use_refit', False) - auto_parallel_config = config.pop('auto_parallel_config', None) - if auto_parallel_config is not None: - auto_parallel_config = AutoParallelConfig.from_dict( - auto_parallel_config) - else: - auto_parallel_config = AutoParallelConfig() + input_timing_cache = config.pop('input_timing_cache', None) + output_timing_cache = config.pop('output_timing_cache', None) + lora_config = LoraBuildConfig.from_dict(config.get('lora_config', {})) + auto_parallel_config = AutoParallelConfig.from_dict( + config.get('auto_parallel_config', {})) if plugin_config is None: plugin_config = PluginConfig() @@ -467,6 +470,9 @@ def from_dict(cls, config, plugin_config=None): enable_debug_output=enable_debug_output, max_draft_len=max_draft_len, use_refit=use_refit, + input_timing_cache=input_timing_cache, + output_timing_cache=output_timing_cache, + lora_config=lora_config, auto_parallel_config=auto_parallel_config, plugin_config=plugin_config) @@ -481,21 +487,12 @@ def to_dict(self): plugin_config = output.pop('plugin_config') plugin_config_dict = 
copy.deepcopy(plugin_config.__dict__) output['plugin_config'] = plugin_config_dict + output['lora_config'] = output['lora_config'].to_dict() output['auto_parallel_config'] = output['auto_parallel_config'].to_dict( ) return output -def serialize_engine(engine, path): - logger.info(f'Serializing engine to {path}...') - tik = time.time() - with open(path, 'wb') as f: - f.write(engine) - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - logger.info(f'Engine serialized. Total time: {t}') - - class EngineConfig: def __init__(self, pretrained_config: 'PretrainedConfig', @@ -527,6 +524,28 @@ def __init__(self, config: EngineConfig, engine: trt.IHostMemory): self.engine = engine def save(self, engine_dir: str): + os.makedirs(engine_dir, exist_ok=True) + lora_config = self.config.build_config.lora_config + lora_dirs = lora_config.lora_dir + root_lora_dir = os.path.join(engine_dir, 'lora') + if len(lora_dirs) > 0: + os.makedirs(root_lora_dir, exist_ok=True) + for index, lora_dir in enumerate(lora_dirs): + if lora_config.lora_ckpt_source == 'hf': + target_lora_dir = f"{root_lora_dir}/{index}" + os.makedirs(target_lora_dir, exist_ok=True) + shutil.copy2(os.path.join(lora_dir, 'adapter_config.json'), + target_lora_dir) + shutil.copy2(os.path.join(lora_dir, 'adapter_model.bin'), + target_lora_dir) + lora_config.lora_dir[index] = f"lora/{index}" + elif lora_config.lora_ckpt_source == 'nemo': + target_lora_file = f"{root_lora_dir}/{index}.nemo" + shutil.copyfile(lora_dir, target_lora_file) + lora_config.lora_dir[index] = f"lora/{index}.nemo" + else: + if os.path.exists(root_lora_dir) and os.path.isdir(root_lora_dir): + shutil.rmtree(root_lora_dir) if self.config.pretrained_config.mapping.rank == 0: with open(os.path.join(engine_dir, 'config.json'), "w", @@ -563,10 +582,20 @@ def get_engine_version(engine_dir: str) -> Union[None, str]: def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: + if build_config.plugin_config.lora_plugin is not None: + # TODO(yuxianq): remove this check after TopModelMixin merged into PretrainedModel + assert hasattr(model, 'use_lora'), "This model does not support LoRA" + model = optimize_model( + model, + use_lora=True, + max_lora_rank=model.lora_config.max_lora_rank, + ) + build_config.lora_config = model.lora_config builder = Builder() builder_config = builder.create_builder_config( precision=model.config.dtype, use_refit=build_config.use_refit, + timing_cache=build_config.input_timing_cache, int8=(model.config.quant_mode.has_act_or_weight_quant() and not model.config.quant_mode.has_per_group_scaling()) or model.config.quant_mode.has_int8_kv_cache(), @@ -574,14 +603,6 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: opt_level=build_config.builder_opt, profiling_verbosity=build_config.profiling_verbosity, quant_mode=model.config.quant_mode, - lora_target_modules=model.config.lora_target_modules if hasattr( - model.config, 'lora_target_modules') else [], - hf_modules_to_trtllm_modules=model.config.lora_target_modules - if hasattr(model.config, 'hf_modules_to_trtllm_modules') else [], - trtllm_modules_to_hf_modules=model.config.lora_target_modules - if hasattr(model.config, 'trtllm_modules_to_hf_modules') else [], - max_lora_rank=model.config.max_lora_rank if hasattr( - model.config, 'max_lora_rank') else 64, ) network = builder.create_network() @@ -632,8 +653,7 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: max_draft_len=build_config.max_draft_len, 
gather_context_logits=build_config.gather_context_logits, gather_generation_logits=build_config.gather_generation_logits, - lora_target_modules=model.config.lora_target_modules if hasattr( - model.config, 'lora_target_modules') else []) + lora_target_modules=build_config.lora_config.lora_target_modules) model(**inputs) if build_config.enable_debug_output: @@ -655,4 +675,9 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: engine = builder.build_engine(network, builder_config) engine_config = EngineConfig(model.config, build_config, __version__) + if build_config.output_timing_cache is not None and model.config.mapping.rank == 0: + ok = builder.save_timing_cache(builder_config, + build_config.output_timing_cache) + assert ok, "Failed to save timing cache." + return Engine(engine_config, engine) diff --git a/tensorrt_llm/commands/build.py b/tensorrt_llm/commands/build.py index de7318127..360717931 100644 --- a/tensorrt_llm/commands/build.py +++ b/tensorrt_llm/commands/build.py @@ -20,21 +20,22 @@ from concurrent.futures import ProcessPoolExecutor, as_completed from importlib.machinery import SourceFileLoader from multiprocessing import get_context -from typing import Dict, Union +from typing import Union import safetensors import torch from .._common import check_max_num_tokens -from .._utils import str_dtype_to_torch from ..auto_parallel.config import _cluster_infos, infer_cluster_key from ..builder import BuildConfig, Engine, build from ..logger import logger +from ..lora_manager import LoraBuildConfig from ..models import MODEL_MAP, PretrainedConfig -from ..models.modeling_utils import WEIGHT_LOADER_MODELS, optimize_model +from ..models.modeling_utils import (WEIGHT_LOADER_MODELS, optimize_model, + preprocess_weights) from ..plugin import PluginConfig, add_plugin_argument from ..quantization import QuantMode -from ..quantization.mode import FP8, W4A8_AWQ, W4A16, W4A16_AWQ, W8A16 +from ..quantization.mode import FP8, W4A16, W8A16 def parse_arguments(): @@ -45,12 +46,16 @@ def parse_arguments(): parser.add_argument('--model_cls_file', type=str, default=None) parser.add_argument('--model_cls_name', type=str, default=None) parser.add_argument( - '--timing_cache', + '--input_timing_cache', type=str, - default='model.cache', + default=None, help= - 'The path of to read timing cache from, will be ignored if the file does not exist' + 'The path to read timing cache, will be ignored if the file does not exist' ) + parser.add_argument('--output_timing_cache', + type=str, + default='model.cache', + help='The path to write timing cache') parser.add_argument('--log_level', type=str, default='info') parser.add_argument( '--profiling_verbosity', @@ -67,9 +72,7 @@ def parse_arguments(): '--output_dir', type=str, default='engine_outputs', - help= - 'The path to save the serialized engine files, timing cache file and model configs' - ) + help='The path to save the serialized engine files and model configs') parser.add_argument('--workers', type=int, default='1', @@ -95,6 +98,8 @@ def parse_arguments(): action='store_true', help= 'Enable horizontal fusion in GatedMLP, reduces layer input traffic and potentially improves performance. ' + 'For FP8 PTQ, the downside is slight reduction of accuracy because one of the quantization scaling factors are discarded. ' + '(An example for reference only: 0.45734 vs 0.45755 for LLaMA-v2 7B using `ammo/examples/hf/instruct_eval/mmlu.py`).' 
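The build flow above gains `--input_timing_cache`/`--output_timing_cache`: the output path is written after a successful build (rank 0 only), and the input path is read on later builds and silently ignored if the file is missing. A hedged sketch of how a caller might wire the two flags together (helper name and paths are ours):

```python
import os

def timing_cache_args(cache_path: str = "model.cache") -> list:
    """CLI fragments for trtllm-build: always write the cache, reuse it once it exists."""
    args = ["--output_timing_cache", cache_path]
    if os.path.exists(cache_path):   # optional: --input_timing_cache is ignored anyway if missing
        args += ["--input_timing_cache", cache_path]
    return args

# e.g. subprocess.run(["trtllm-build", "--checkpoint_dir", ckpt_dir, *timing_cache_args()])
```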
) parser.add_argument( '--gather_all_token_logits', @@ -109,7 +114,15 @@ def parse_arguments(): action='store_true', default=False, help='Gather generation logits') - parser.add_argument('--strongly_typed', action='store_true', default=False) + parser.add_argument( + '--strongly_typed', + action='store_true', + default=False, + help= + 'This option is introduced with TensorRT 9.1.0.1+ and will reduce the engine building time. ' + 'It\'s not expected to see performance or accuracy regression after enable this flag. ' + 'Note that, we may remove this flag in the future, and enable the feature by default.' + ) parser.add_argument('--builder_opt', type=int, default=None) parser.add_argument('--logits_dtype', type=str, @@ -126,7 +139,43 @@ def parse_arguments(): help= 'Maximum lengths of draft tokens for speculative decoding target model.' ) - parser.add_argument('--world_size', + parser.add_argument( + '--lora_dir', + type=str, + default=None, + nargs="+", + help="The directory of LoRA weights. " + "Use config from the first directory if multiple directories are provided." + ) + parser.add_argument('--lora_ckpt_source', + type=str, + default="hf", + choices=["hf", "nemo"], + help="The source of lora checkpoint.") + parser.add_argument( + '--lora_target_modules', + nargs='+', + default=None, + choices=[ + "attn_qkv", + "attn_q", + "attn_k", + "attn_v", + "attn_dense", + "mlp_h_to_4h", + "mlp_gate", + "mlp_4h_to_h", + ], + help= + "Add lora in which modules. Only be activated when use_lora_plugin is enabled." + ) + parser.add_argument( + '--max_lora_rank', + type=int, + default=64, + help='maximum lora rank for different lora modules. ' + 'It is used to compute the workspace size of lora plugin.') + parser.add_argument('--auto_parallel', type=int, default=1, help='MPI world size for auto parallel.') @@ -183,13 +232,12 @@ def build_model(build_config: BuildConfig, preprocess_model_config(model_config, **kwargs) - logits_dtype = kwargs.pop('logits_dtype', None) + logits_dtype = kwargs.get('logits_dtype') if logits_dtype is not None: model_config.logits_dtype = logits_dtype model_config.use_prompt_tuning = build_config.max_prompt_embedding_table_size > 0 - - weight_only_precision = kwargs.pop('weight_only_precision', None) + weight_only_precision = kwargs.get('weight_only_precision', None) if model_config.quant_mode == QuantMode( 0) and weight_only_precision is not None: if weight_only_precision == 'int4': @@ -222,7 +270,7 @@ def build_model(build_config: BuildConfig, model = model_cls.from_config(rank_config) if ckpt_dir is not None: if model_config.architecture in WEIGHT_LOADER_MODELS: - model_path = os.path.join(ckpt_dir, f'rank0.safetensors') + model_path = os.path.join(ckpt_dir, 'rank0.safetensors') else: model_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors') @@ -246,11 +294,24 @@ def build_model(build_config: BuildConfig, if hasattr(model.config, 'max_medusa_token_len'): build_config.max_draft_len = model.config.max_medusa_token_len - use_fused_mlp = kwargs.pop('use_fused_mlp', False) + if build_config.plugin_config.lora_plugin is not None: + lora_config = LoraBuildConfig( + lora_dir=kwargs['lora_dir'] or [], + lora_ckpt_source=kwargs['lora_ckpt_source'], + max_lora_rank=kwargs['max_lora_rank']) + if kwargs['lora_target_modules'] is not None: + # command line options is preferred over the modules in the lora dir + lora_config.lora_target_modules = kwargs['lora_target_modules'] + # TODO(yuxianq): remove this check after TopModelMixin merged into PretrainedModel + assert hasattr(model, 
'use_lora'), "This model does not support LoRA" + model.use_lora(lora_config) + + use_fused_mlp = kwargs.get('use_fused_mlp', False) use_auto_parallel = build_config.auto_parallel_config.enabled - model = optimize_model(model, - use_fused_mlp=use_fused_mlp - and not use_auto_parallel) + model = optimize_model( + model, + use_fused_mlp=(use_fused_mlp and not use_auto_parallel), + use_prompt_tuning=(build_config.max_prompt_embedding_table_size > 0)) if use_auto_parallel: model.config.mapping.rank = real_rank @@ -258,125 +319,6 @@ def build_model(build_config: BuildConfig, return build(model, build_config) -def preprocess_weights( - weights: Dict[str, torch.Tensor], - model_config: PretrainedConfig) -> Dict[str, torch.Tensor]: - quant_algo = model_config.quantization.quant_algo - kv_cache_quant_algo = model_config.quantization.kv_cache_quant_algo - - # INT4_AWQ - if quant_algo == W4A8_AWQ or quant_algo == W4A16_AWQ: - preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm - for name, param in weights.items(): - if name.endswith('weight') and param.dtype == torch.int8: - dtype = torch.float16 - if model_config.dtype == "bfloat16": - dtype = torch.bfloat16 - weights[name] = preprocessor(param.T.contiguous(), - torch.quint4x2).view(dtype) - if name.endswith('weights_scaling_factor'): - weights[name] = param.T.contiguous().to( - str_dtype_to_torch(model_config.dtype)) - if name.endswith('prequant_scaling_factor'): - weights[name] = param.reshape(1, -1) - if model_config.mapping.tp_rank > 0: - if name.endswith('attention.dense.bias') or name.endswith( - 'mlp.proj.bias'): - weights[name] = torch.zeros_like(param) - - if quant_algo == W4A8_AWQ: - for name in list(weights): - if name.endswith('weights_scaling_factor'): - activation_scaling_factor = weights.pop( - name.replace('weights_scaling_factor', - 'activation_scaling_factor')) - weights_scaling_factor_2 = weights.pop( - name.replace('weights_scaling_factor', - 'weights_scaling_factor_2')) - weights[name] /= weights_scaling_factor_2 - weights[name.replace( - 'weights_scaling_factor', - 'prequant_scaling_factor')] /= activation_scaling_factor - weights[name.replace( - 'weights_scaling_factor', 'alpha' - )] = activation_scaling_factor * weights_scaling_factor_2 - - # FP8 - elif quant_algo == FP8: - for name, param in weights.items(): - if name.endswith('weight') and param.dtype == torch.int8: - weights[name] = param.view(torch.float8_e4m3fn) - # lm_head is not quantized to FP8 - if "lm_head.weight" in weights: - assert weights['lm_head.weight'].dtype == str_dtype_to_torch( - model_config.dtype) - weights.pop('lm_head.weights_scaling_factor', None) - weights.pop('lm_head.activation_scaling_factor', None) - - # Weight only 4bit - elif quant_algo == W4A16: - for name in list(weights): - if any([ - _name in name for _name in [ - 'qkv.weight', 'dense.weight', 'fc.weight', - 'proj.weight', 'gate.weight' - ] - ]) and weights[name].dtype != torch.int8: - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - weights[name].t().contiguous(), torch.quint4x2) - weights[name] = processed_torch_weights - weights[name.replace( - '.weight', '.per_channel_scale')] = torch_weight_scales - - # Weight only 8bit - elif quant_algo == W8A16: - for name in list(weights): - if any([ - _name in name for _name in [ - 'qkv.weight', 'dense.weight', 'fc.weight', - 'proj.weight', 'gate.weight' - ] - ]) and weights[name].dtype != torch.int8: - processed_torch_weights, torch_weight_scales = \ - 
torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - weights[name].t().contiguous(), torch.int8) - weights[name] = processed_torch_weights - weights[name.replace( - '.weight', '.per_channel_scale')] = torch_weight_scales - - # FP8 kv_cache_scaling_factor is always 1.0 - if kv_cache_quant_algo == FP8: - for name, param in weights.items(): - if name.endswith('kv_cache_scaling_factor'): - weights[name] = torch.tensor([1.0], dtype=torch.float32) - - # If layer_norm bias is None. (For MPT) - if model_config.architecture == 'MPTForCausalLM': - update_dict = {} - for name, param in weights.items(): - if 'input_layernorm.weight' in name and name.replace( - 'weight', 'bias') not in weights: - update_dict[name.replace('weight', - 'bias')] = torch.zeros_like(param) - if 'post_layernorm.weight' in name and name.replace( - 'weight', 'bias') not in weights: - update_dict[name.replace('weight', - 'bias')] = torch.zeros_like(param) - if 'ln_f.weight' in name and name.replace('weight', - 'bias') not in weights: - update_dict[name.replace('weight', - 'bias')] = torch.zeros_like(param) - weights.update(update_dict) - - # Parallel block rowlinear should not have duplicate bias. - if model_config.architecture == 'GPTJForCausalLM': - if model_config.mapping.tp_rank > 0: - for name, param in weights.items(): - if 'attention.dense.bias' in name or 'mlp.proj.bias' in name: - weights[name] = torch.zeros_like(param) - - def build_and_save(rank, gpu_id, ckpt_dir, build_config, output_dir, log_level, model_config, model_cls, **kwargs): torch.cuda.set_device(gpu_id) @@ -463,6 +405,17 @@ def main(): workers = min(torch.cuda.device_count(), args.workers) plugin_config = PluginConfig.from_arguments(args) + kwargs = { + 'logits_dtype': args.logits_dtype, + 'use_fused_mlp': args.use_fused_mlp, + 'weight_only_precision': args.weight_only_precision, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + 'lora_dir': args.lora_dir, + 'lora_ckpt_source': args.lora_ckpt_source, + 'max_lora_rank': args.max_lora_rank, + 'lora_target_modules': args.lora_target_modules, + } if args.build_config is None: args.max_num_tokens = check_max_num_tokens( max_num_tokens=args.max_num_tokens, @@ -487,9 +440,11 @@ def main(): 'profiling_verbosity': args.profiling_verbosity, 'enable_debug_output': args.enable_debug_output, 'max_draft_len': args.max_draft_len, + 'input_timing_cache': args.input_timing_cache, + 'output_timing_cache': args.output_timing_cache, 'auto_parallel_config': { 'world_size': - args.world_size, + args.auto_parallel, 'gpus_per_node': args.gpus_per_node, 'cluster_key': @@ -509,13 +464,6 @@ def main(): plugin_config=plugin_config) source = args.checkpoint_dir if args.checkpoint_dir is not None else args.model_config - kwargs = { - 'logits_dtype': args.logits_dtype, - 'use_fused_mlp': args.use_fused_mlp, - 'weight_only_precision': args.weight_only_precision, - 'tp_size': args.tp_size, - 'pp_size': args.pp_size, - } parallel_build(source, build_config, args.output_dir, workers, args.log_level, model_cls, **kwargs) diff --git a/tensorrt_llm/engine.py b/tensorrt_llm/engine.py deleted file mode 100644 index 8162eb55d..000000000 --- a/tensorrt_llm/engine.py +++ /dev/null @@ -1,183 +0,0 @@ -from dataclasses import asdict -from pathlib import Path -from typing import Any, Iterable, Optional, Union - -import janus -import torch - -import tensorrt_llm.bindings as tllm - -from .hlapi.tokenizer import TokenizerBase -from .hlapi.utils import GenerationOutput -from .logger import logger -from .runtime import SamplingConfig - - 
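`commands/build.py` above folds the new `--lora_dir`, `--lora_ckpt_source`, `--max_lora_rank`, and `--lora_target_modules` options into a `LoraBuildConfig`, with the command-line module list taking precedence over whatever the LoRA directory declares. A hedged sketch of the equivalent Python-side construction (the checkpoint path and module choices are placeholders):

```python
from tensorrt_llm.lora_manager import LoraBuildConfig

# Field names follow the construction in commands/build.py; the directory is a placeholder.
lora_config = LoraBuildConfig(
    lora_dir=["./my-hf-lora-checkpoint"],
    lora_ckpt_source="hf",        # or "nemo" for NeMo-format checkpoints
    max_lora_rank=64,             # used to size the LoRA plugin workspace
)
# Command-line modules are preferred over the modules recorded in the LoRA dir.
lora_config.lora_target_modules = ["attn_q", "attn_k", "attn_v"]
```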
-class AsyncLLMEngine: - TERMINATE_REQUEST_ID = 0 - - def __init__(self, - engine_dir: Path, - tokenizer: Union[str, Path, TokenizerBase], - max_beam_width: int = 1) -> None: - self.requests: list[tllm.InferenceRequest] = [] - self.results: dict[int, janus.Queue] = {} - self.stop_set: set[int] = set() - self.stats: Optional[janus.LifoQueue] = None - - self.tokenizer = tokenizer - if not isinstance(tokenizer, TokenizerBase): - from transformers import AutoTokenizer - self.tokenizer = AutoTokenizer.from_pretrained( - tokenizer, - legacy=False, - padding_side='left', - truncation_side='left', - trust_remote_code=True, - use_fast=True) - opt_params = tllm.TrtGptModelOptionalParams() - # TODO[chunweiy]: Expose the runtime configs - self.engine = tllm.GptManager( - engine_dir, tllm.TrtGptModelType.InflightFusedBatching, - max_beam_width, tllm.SchedulerPolicy.GUARANTEED_NO_EVICT, - self._fetch_requests_callback, self._handle_response_callback, - self._get_stop_set_callback, self._handle_stats_callback, - opt_params, AsyncLLMEngine.TERMINATE_REQUEST_ID) - - self._next_request_id = AsyncLLMEngine.TERMINATE_REQUEST_ID + 1 - - # TODO[chunweiy]: support token-ids as prompt when Tokenizer is disabled in LLM() - # TODO[chunweiy]: Align the keys between SamplingConfig and gptManager - async def generate( - self, - prompt: str, - streaming: bool = True, - sampling_config: Optional[SamplingConfig] = None - ) -> Iterable[GenerationOutput]: - - sampling_options: dict = asdict( - sampling_config) if sampling_config is not None else dict() - if sampling_options: - sampling_options["max_new_tokens"] = [ - sampling_options['max_new_tokens'] - ] - - tllm_request = self.add_request({ - "prompt": prompt, - "streaming": streaming, - **sampling_options - }) - request_id = tllm_request.request_id - tllm_request.input_ids[0].numpy().tolist() - - finished = False - while not finished: - output, finished = await self.get_response(request_id) - diff_ids = output.numpy().tolist() - diff_str = self.tokenizer.decode(diff_ids) - - output = GenerationOutput( - diff_str, - diff_ids, - # TODO[chunweiy]: return the probs as well - ) - yield output - - @property - def next_request_id(self) -> int: - # underlying type is uint64 - uint64_max = 2**64 - 1 - request_id = self._next_request_id - self._next_request_id = (request_id + 1) % uint64_max - return request_id - - @staticmethod - def create_inference_request( - req_id: int, parameters: dict[str, Any]) -> tllm.InferenceRequest: - - def set_property(name: str, dtype: torch.dtype = torch.int32): - if name in parameters and parameters[name] is not None: - setattr(request, name, - torch.tensor([parameters[name]], dtype=dtype)) - - request = tllm.InferenceRequest(req_id) - request.input_ids = parameters["input_ids"] - set_property("end_id") - set_property("pad_id") - set_property("max_new_tokens") - set_property("min_length") - set_property("temperature", torch.float32) - set_property("runtime_top_k", torch.float32) - set_property("runtime_top_p", torch.float32) - set_property("random_seed", torch.int64) - if "streaming" in parameters: - request.is_streaming = parameters["streaming"] - - return request - - def add_request(self, request_dict: dict[str, - Any]) -> tllm.InferenceRequest: - ids = self.tokenizer(request_dict.pop("prompt"), - return_tensors="pt", - return_attention_mask=False) - request_dict["input_ids"] = ids["input_ids"].to(torch.int32) - request_dict["end_id"] = self.tokenizer.eos_token_id - if getattr(self.tokenizer, "pad_token_id") is not None: - request_dict["pad_id"] 
= self.tokenizer.pad_token_id - else: - request_dict["pad_id"] = request_dict["end_id"] - - request = AsyncLLMEngine.create_inference_request( - self.next_request_id, request_dict) - - self.results[request.request_id] = janus.Queue() - self.requests.append(request) - - return request - - async def get_response(self, - request_id: int) -> tuple[dict[str, Any], bool]: - outputs, finished = None, False - while outputs is None: - outputs, finished = await self.results[request_id].async_q.get() - - last_idx = outputs["sequence_length"][0, 0].item() - output = outputs["output_ids"][0, 0, :last_idx] - - if finished: - self.results.pop(request_id) - - return output, finished - - # Callbacks for BatchManager - - def _fetch_requests_callback( - self, max_num_sequences) -> list[tllm.InferenceRequest]: - fetched = [] - for _ in range(max_num_sequences): - if len(self.requests) == 0: - break - fetched.append(self.requests.pop()) - return fetched - - def _handle_response_callback(self, req_id: int, - tensors: list[tllm.NamedTensor], is_ok: bool, - err_msg: str) -> None: - if err_msg: - logger.error(f"AsyncLLMEngine process request failed: {err_msg}") - - self.results[req_id].sync_q.put( - [{t.name: t.tensor - for t in tensors}, is_ok] if not err_msg else err_msg) - - def _get_stop_set_callback(self) -> set[int]: - return self.stop_set - - def _handle_stats_callback(self, stats: str): - if self.stats is None: - self.stats = janus.LifoQueue() - - while self.stats.sync_q.full(): - self.stats.sync_q.get() - - self.stats.sync_q.put(stats) diff --git a/tensorrt_llm/executor.py b/tensorrt_llm/executor.py index bf236987a..1f8ead4cc 100644 --- a/tensorrt_llm/executor.py +++ b/tensorrt_llm/executor.py @@ -1,19 +1,23 @@ import asyncio -import time +import secrets +from abc import ABC, abstractmethod +from multiprocessing.managers import BaseManager from pathlib import Path -from queue import Queue +from queue import Empty, Queue +from threading import Lock, Semaphore, Thread from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union +import numpy as np import torch from janus import Queue as AsyncQueue -from transformers import AutoTokenizer +from mpi4py import MPI -import tensorrt_llm.bindings as tllm -from tensorrt_llm._utils import mpi_broadcast, mpi_rank, mpi_world_size -from tensorrt_llm.hlapi.mpi_session import MpiSession, NodeSession, SocketClient -from tensorrt_llm.hlapi.tokenizer import TokenizerBase -from tensorrt_llm.hlapi.utils import GenerationOutput, print_traceback_on_error -from tensorrt_llm.logger import logger +from tensorrt_llm._utils import mpi_rank, mpi_world_size +from tensorrt_llm.hlapi.mpi_session import MpiSession, find_free_port +from tensorrt_llm.hlapi.tokenizer import TokenizerBase, tokenizer_factory +from tensorrt_llm.hlapi.utils import GenerationOutput + +from . 
import bindings as tllm def has_event_loop() -> bool: @@ -27,39 +31,61 @@ def has_event_loop() -> bool: class GenerationRequest: def __init__(self, - req_id: int, - ids: torch.Tensor, - end_id: int, - pad_id: int, + ids_or_prompt: Union[torch.Tensor, np.ndarray, list, str], streaming: bool = True, - digit_input=False, + tokenizer: Optional[TokenizerBase] = None, **kwargs): - self.prompt = None - self.ids = ids + if isinstance(ids_or_prompt, str): + assert tokenizer is not None, "GenerationRequest constructor with str prompt requires a tokenizer argument" + self.input_ids = (tokenizer.encode(ids_or_prompt, + return_tensors="pt", + return_attention_mask=False).to( + torch.int32).numpy()) + else: + if isinstance(ids_or_prompt, list): + self.input_ids = np.array(ids_or_prompt, dtype="int32") + elif isinstance(ids_or_prompt, torch.Tensor): + self.input_ids = ids_or_prompt.to(torch.int32).numpy() + elif isinstance(ids_or_prompt, np.ndarray): + self.input_ids = ids_or_prompt + else: + raise ValueError( + f"ids_or_prompt (={ids_or_prompt}) should be an instance of str, torch.Tensor, np.ndarray or list" + ) + + self.tokenizer = tokenizer self.streaming = streaming - self.kwargs = kwargs - self.end_id = end_id - self.pad_id = pad_id - self.digit_input = digit_input - self._id = req_id - - def get_inference_request(self) -> tllm.InferenceRequest: - ir = tllm.InferenceRequest(self._id) - ir.input_ids = self.ids.to(dtype=torch.int32) + self.options = kwargs + if tokenizer is not None: + end_id, pad_id = tokenizer.eos_token_id, tokenizer.pad_token_id + self.options.setdefault("end_id", end_id) + self.options.setdefault("pad_id", + pad_id if pad_id is not None else end_id) + + self.id = -1 + + def set_id(self, id): + self.id = id + return self + + def as_inference_request(self) -> tllm.InferenceRequest: + ir = tllm.InferenceRequest(self.id) + ir.input_ids = torch.from_numpy(self.input_ids) ir.is_streaming = self.streaming def set_property(name: str, dtype: torch.dtype = torch.int32, default: Any = None): - if name in self.kwargs or default is not None: - value = self.kwargs.get(name, default) + if name in self.options or default is not None: + value = self.options.get(name, default) setattr(ir, name, torch.tensor([value], dtype=dtype)) - set_property("max_new_tokens", default=[8]) - - set_property("end_id", default=self.end_id) - set_property("pad_id", default=self.pad_id) - + if "max_new_tokens" in self.options: + self.options["max_new_tokens"] = [self.options["max_new_tokens"]] + set_property("beam_width") + set_property("max_new_tokens", default=[32]) + set_property("end_id") + set_property("pad_id") set_property("min_length") set_property("temperature", torch.float32) set_property("runtime_top_k", torch.float32) @@ -74,100 +100,238 @@ class GenerationResult(GenerationOutput): def __init__(self, generation_request: GenerationRequest, tokenizer: Optional[TokenizerBase] = None) -> None: - self.running = True - self.done = False + self._done = False + self._cancelled = False self.generation_request = generation_request self.tokenizer = tokenizer + self.streaming = generation_request.streaming if has_event_loop(): - self._base_queue = AsyncQueue() - self.queue = self._base_queue.sync_q - self.aqueue = self._base_queue.async_q + aqueue = AsyncQueue() + self.queue = aqueue.sync_q + self.aqueue = aqueue.async_q else: - self._base_queue = Queue() - self.queue = self._base_queue + self.queue = Queue() self.aqueue = None - self.generation: Optional[torch.Tensor] - if generation_request.streaming: - 
self.generation = generation_request.ids - else: - self.generation = None + beam_width = generation_request.options.get("beam_width", 1) + self.beam_search_enabled = beam_width > 1 + self._token_ids = [[] for _ in range(beam_width)] - # TODO: fill the following fields from GenerationOutput - self.token_ids = [] self.logprobs = [] + self.last_text = "" - def enqueue(self, msg: Tuple[Union[str, Dict[str, torch.Tensor]], bool]): - self.queue.put(msg) - - def handle_generation_msg(self, msg: Union[str, Dict[str, torch.Tensor]]): - if isinstance(msg, str): - raise RuntimeError(msg) - - # TODO[chunweiy]: Unify the msg format for parallel and non-parallel mode - if isinstance(msg, dict): - self.token_ids = msg["output_ids"][0][0] - else: - # this is for parallel mode - assert isinstance(msg, list) - self.token_ids = msg[0] - - @staticmethod - def process_generation(msg: dict): - token_ids = msg["output_ids"][0] - # TODO: add other fields if needed - return token_ids + @property + def token_ids(self): + if not self.beam_search_enabled: + return self._token_ids[0] + return self._token_ids + + def handle_generation_msg(self, tensors: Dict[str, np.ndarray], error: str): + if error: + raise RuntimeError(error) + new_ids = tensors["output_ids"].squeeze(0).tolist() + for idx, beam_ids in enumerate(new_ids): + self._token_ids[idx] += beam_ids + + def result_step(self, timeout: Optional[float] = None): + _, tensors, self._done, error = self.queue.get(timeout=timeout) + self.handle_generation_msg(tensors, error) + + async def aresult_step(self): + assert self.aqueue is not None + _, tensors, self._done, error = await self.aqueue.get() + self.handle_generation_msg(tensors, error) - def wait_step(self, timeout: Optional[float] = None): - msg, self.done = self.queue.get(timeout=timeout) - self.handle_generation_msg(msg) + @property + def text_diff(self) -> str: + assert self.streaming is not None + assert not self.beam_search_enabled, "text_diff is not supported with beam_search" - async def await_step(self): - assert self.aqueue is not None - msg, self.done = await self.aqueue.get() - self.handle_generation_msg(msg) + new_txt = self.text + diff = new_txt[len(self.last_text):] + self.last_text = new_txt + return diff @property - def text(self) -> str: + def text(self) -> Union[str, List[str]]: if self.tokenizer is None: return '' - return self.tokenizer.decode(self.token_ids) + texts = self.tokenizer.batch_decode(self._token_ids) + if not self.beam_search_enabled: + return texts[0] + return texts def result(self, timeout: Optional[float] = None) -> "GenerationResult": - while not self.done: - self.wait_step(timeout) + while not self._done: + self.result_step(timeout) return self async def aresult(self) -> "GenerationResult": - while not self.done: - await self.await_step() + while not self._done: + await self.aresult_step() return self def __iter__(self): return self def __next__(self): - if self.done: + if self._done: raise StopIteration - self.wait_step() + self.result_step() return self def __aiter__(self): return self async def __anext__(self): - if self.done: + if self._done: raise StopAsyncIteration - await self.await_step() + await self.aresult_step() return self + def running(self) -> bool: + return not self._done + + def cancelled(self) -> bool: + return self._cancelled + + def cancel(self): + raise NotImplementedError + + def done(self) -> bool: + return self._done -class GenerationExecutor: + def exception(self, timeout: Optional[float] = None): + try: + self.result(timeout) + except RuntimeError 
as e: + return e + + +class GenerationExecutor(ABC): TERMINATE_REQUEST_ID = 0 + def __init__(self): + self.id_counter = GenerationExecutor.TERMINATE_REQUEST_ID + 1 + self.tokenizer = None + + def generate_id(self) -> int: + gen_id = self.id_counter + + # underlying C type is uint64 + uint64_max = 2**64 - 1 + self.id_counter = (self.id_counter + 1) % uint64_max + + if self.id_counter == GenerationExecutor.TERMINATE_REQUEST_ID: + self.id_counter += 1 + + return gen_id + + @abstractmethod + def submit(self, request: GenerationRequest) -> GenerationResult: + pass + + def generate_async( + self, prompt: Union[str, List[int], List[str], + List[List[int]]], streaming: bool, + **kwargs: Any) -> Union[GenerationResult, List[GenerationResult]]: + unbatched = isinstance(prompt, str) or (isinstance(prompt, list) + and isinstance(prompt[0], int)) + string_input = isinstance( + prompt, str) or (not unbatched and isinstance(prompt[0], str)) + tokenizer = self.tokenizer if string_input else None + + if unbatched: + results = self.submit( + GenerationRequest(prompt, streaming, tokenizer, **kwargs)) + else: + results = [] + for idx, p in enumerate(prompt): + request_kwargs = { + k: v[idx] if isinstance(v, list) else v + for k, v in kwargs.items() + } + results.append( + self.submit( + GenerationRequest(p, streaming, tokenizer, + **request_kwargs))) + return results + + def generate( + self, + prompt: Union[str, List[int], List[str], List[List[int]]], + streaming: bool = False, + **kwargs: Any) -> Union[GenerationResult, List[GenerationResult]]: + futures = self.generate_async(prompt, streaming=streaming, **kwargs) + if isinstance(futures, GenerationRequest): + futures.result() + else: + for future in futures: + future.result() + return futures + + @abstractmethod + def shutdown(self): + pass + + @abstractmethod + def get_stats(self): + pass + + @abstractmethod + async def aget_stats(self): + pass + + @staticmethod + def create( + engine_dir: Path, + tokenizer: Union[str, Path, TokenizerBase], + max_beam_width: int = 1, + executor_type: tllm.TrtGptModelType = tllm.TrtGptModelType. + InflightBatching, + executor_policy: tllm.SchedulerPolicy = tllm.SchedulerPolicy. + GUARANTEED_NO_EVICT, + executor_config: tllm.TrtGptModelOptionalParams = tllm. + TrtGptModelOptionalParams(), + model_world_size: int = 1, + world_size: int = 0, + mpi_session: Optional[MpiSession] = None, + ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]: + + if world_size == 0: + world_size = mpi_world_size() + + if world_size > 1 and world_size < model_world_size: + raise RuntimeError( + "Cannot instantiate Generator for engine built " + f"for {model_world_size} ranks, while currently running " + f"on {world_size} ranks.") + + worker_kwargs = { + "engine_dir": engine_dir, + "tokenizer": tokenizer, + "max_beam_width": max_beam_width, + "executor_type": executor_type, + "executor_policy": executor_policy, + "executor_config": executor_config, + } + + if world_size == 1 and model_world_size > 1: + return GenerationExecutorProxy(worker_kwargs, + model_world_size=model_world_size, + mpi_session=mpi_session) + + return GenerationExecutorWorker(**worker_kwargs) + + +class GenerationExecutorWorker(GenerationExecutor): + + class WorkerExit(GeneratorExit): + pass + def __init__( self, engine_dir: Path, @@ -180,24 +344,16 @@ def __init__( executor_config: tllm.TrtGptModelOptionalParams = tllm. 
TrtGptModelOptionalParams(), ) -> None: + super().__init__() - self.active_requests = 0 - - self.tokenizer = tokenizer - if tokenizer is not None and not isinstance(tokenizer, TokenizerBase): - self.tokenizer = AutoTokenizer.from_pretrained( - tokenizer, - legacy=False, - padding_side='left', - truncation_side='left', - trust_remote_code=True, - use_fast=True) + self.engine = None + self.tokenizer = tokenizer_factory(tokenizer) # NOTE: underscore variables are used for communication with the C++ runtime self._requests: List[tllm.InferenceRequest] = [] self._results: Dict[int, GenerationResult] = {} self._cancelled_ids: Set[int] = set() - self._completed: Queue = Queue() + self._pending: set = set() if has_event_loop(): self._stats = AsyncQueue() self.stats_queue = self._stats.sync_q @@ -206,6 +362,25 @@ def __init__( self._stats = Queue() self.stats_queue = self._stats self.stats_aqueue = None + """ + Note: in single-node only (when using .block_subordinates()) the termination + process is as follow: + 0. Nodes > 0 (main threads) directly wait on termination_ack. Node 0 continues execution. + 1. Node 0 (main thread) is finishing and must close GptManager. + 2. Node 0 (main thread) sets _termination_requested and wait on termination_ack + 3. Node 0 (BatchManager thread) exchange _termination_requested via MPI.bcast with all other nodes. + 4. All nodes (BatchManager threads) signal the _termination_ack semaphore and set _termination_pending to avoid fetching new requests. + 5. All nodes (main threads) go through _termination_ack and ask BatchManager to join its threads. + """ + self._block_subordinates = False + self._termination_requested = False + self._termination_pending = False + self._termination_ack = Semaphore(0) + self._termination_lock = Lock() + self.result_queue = None + + self.comm = MPI.COMM_WORLD + self.rank = mpi_rank() self.engine = tllm.GptManager(engine_dir, executor_type, max_beam_width, executor_policy, self.fetch_requests, @@ -214,106 +389,51 @@ def __init__( executor_config, GenerationExecutor.TERMINATE_REQUEST_ID) - self._next_request_id = GenerationExecutor.TERMINATE_REQUEST_ID + 1 + def shutdown(self): + if self.engine is not None: + self.engine.shutdown() + self.engine = None - def submit(self, request: GenerationRequest) -> GenerationResult: - """ - Low-level API to the executor. Return a "future" GenerationResult which can be waited. 
- """ + def block_subordinates(self): + self._block_subordinates = True + if self.rank != 0: + self._termination_ack.acquire() + self.shutdown() + raise self.WorkerExit( + "block_subordinates() should be used in a `with GenerationExecutorWorker() as ...:` block" + ) - inference_request = request.get_inference_request() + def __enter__(self): + return self - tokenizer = self.tokenizer if not request.digit_input else None - result = GenerationResult(request, tokenizer) - self._results[inference_request.request_id] = result + def __exit__(self, exc_type, exc_value, traceback) -> bool: + del exc_value, traceback # unused arguments - self.active_requests += 1 - self._requests.append(inference_request) + if self._block_subordinates and self.rank == 0: + if self.rank == 0: + self._termination_lock.acquire() + self._termination_requested = True + self._termination_lock.release() - return result + self._termination_ack.acquire() - def get_next_request_id(self) -> int: - # underlying type is uint64 - uint64_max = 2**64 - 1 - request_id = self._next_request_id - self._next_request_id = (request_id + 1) % uint64_max - return request_id + self.shutdown() - def generate_async( - self, - prompt: Union[str, List[int], List[str], List[List[int]]], - streaming: bool, - max_new_tokens: Union[int, List[int]], - end_id: int = -1, - pad_id: int = -1 - ) -> Union[GenerationResult, List[GenerationResult]]: - batched = False - digit_input = False - if isinstance(prompt, list): - if isinstance(prompt[0], str): # List[str] - batched = True - if isinstance(max_new_tokens, int): - max_new_tokens = [max_new_tokens] * len(prompt) - elif isinstance(prompt[0], int): # List[int] - digit_input = True - prompt = [prompt] - if not isinstance(max_new_tokens, list): - max_new_tokens = [max_new_tokens] - # List[List[int]] - elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int): - batched = True - digit_input = True - if not isinstance(max_new_tokens, list): - max_new_tokens = [max_new_tokens] * len(prompt) - else: # str - prompt = [prompt] - if not isinstance(max_new_tokens, list): - max_new_tokens = [max_new_tokens] - - def get_ids(prompt: str | List[int]) -> torch.Tensor: - if digit_input: - return torch.tensor([prompt], dtype=torch.int32) - return self.tokenizer.encode(prompt, - return_tensors="pt", - return_attention_mask=False) - - if end_id == -1: - assert self.tokenizer is not None, "Please specify end_id if tokenizer is not provided" - end_id = self.tokenizer.eos_token_id - pad_id = getattr(self.tokenizer, "pad_token_id", end_id) - - results = [ - self.submit( - GenerationRequest(req_id=self.get_next_request_id(), - ids=get_ids(p), - streaming=streaming, - max_new_tokens=[m], - pad_id=pad_id, - end_id=end_id, - digit_input=digit_input)) - for p, m in zip(prompt, max_new_tokens) - ] - if not batched: - results = results[0] - return results + return exc_type is None or exc_type == GenerationExecutorWorker.WorkerExit - def generate( - self, - prompt: Union[str, List[str]], - max_new_tokens: Union[int, List[int]], - end_id: int = -1, - pad_id: int = -1 - ) -> Union[GenerationResult, List[GenerationResult]]: - results = self.generate_async(prompt, - False, - max_new_tokens, - end_id=end_id, - pad_id=pad_id) - result_list = [results] if isinstance(results, - GenerationRequest) else results - for result in result_list: - result.result() - return results + def submit(self, request: GenerationRequest) -> GenerationResult: + """ + Low-level API to the executor. Return a "future" GenerationResult which can be waited. 
+ """ + result = GenerationResult(request, request.tokenizer) + req_id = self.generate_id() + + request.set_id(req_id) + self._results[req_id] = result + self._pending.add(req_id) + self._requests.append(request.as_inference_request()) + + return result def get_stats(self): return self.stats_queue.get() @@ -325,38 +445,74 @@ async def aget_stats(self): def wait_first_completed( self, futures: List[GenerationResult] ) -> Generator[GenerationResult, None, None]: - wait_set = set(f.generation_request._id for f in futures) + wait_set = set(f.generation_request.id for f in futures) # clear already-finished requests for f in futures: - if f.done: - wait_set.remove(f.generation_request._id) + if f._done: + wait_set.remove(f.generation_request.id) yield f # wait remaining active requests while len(wait_set) > 0: - req_id = self._completed.get() - if req_id in wait_set: - wait_set.remove(req_id) + req_id = wait_set.pop() + + if req_id not in self._pending: yield self._results[req_id] + else: + wait_set.add(req_id) + + def set_result_queue(self, queue): + self.result_queue = queue + + def return_queue(self, req_id: int): + """ If a centralized result queue is registered (used for communication with the proxy) + send the message there. + Otherwise, push the result directly in the GenerationResult queue. + """ + + if self.result_queue is not None: + return self.result_queue + return self._results[req_id].queue # Callbacks for BatchManager def fetch_requests(self, max_num_sequences) -> List[tllm.InferenceRequest]: + if self._termination_pending: + return [] + fetched = [] - for _ in range(max_num_sequences): - if len(self._requests) == 0: - break - fetched.append(self._requests.pop()) + if not self._block_subordinates or self.rank == 0: + for _ in range(max_num_sequences): + if len(self._requests) == 0: + break + fetched.append(self._requests.pop()) + + if self._block_subordinates: + self._termination_lock.acquire() + self._termination_requested = self.comm.bcast( + self._termination_requested) + + if self._termination_requested: + self._termination_ack.release() + self._termination_pending = True + else: + fetched = self.comm.bcast(fetched) + + self._termination_lock.release() + return fetched def handle_response(self, req_id: int, tensors: List[tllm.NamedTensor], finished: bool, err: str) -> None: - self._results[req_id].enqueue( - ({t.name: t.tensor - for t in tensors - if t.tensor is not None} if not err else err, finished)) + if self._block_subordinates and self.rank != 0: + return + + self.return_queue(req_id).put((req_id, { + t.name: t.tensor.numpy() + for t in tensors if t.tensor is not None + }, finished, err)) if finished: - self._completed.put(req_id) + self._pending.remove(req_id) def get_cancelled_ids(self) -> Set[int]: return self._cancelled_ids @@ -367,221 +523,146 @@ def handle_stats(self, stats: str): self.stats_queue.put(stats) - def __enter__(self): - self.engine.__enter__() - return self - - def __exit__(self, exc_type, exc_value, traceback): - if self.engine is not None: - self.engine.__exit__(exc_type, exc_value, traceback) - self.engine = None - def __del__(self): - self.__exit__(None, None, None) + self.shutdown() + +class GenerationExecutorProxy(GenerationExecutor): -class ParallelGenerationExecutor(GenerationExecutor): - ''' GenerationExecutor with MPI enabled. 
''' + class ExecutorManager(BaseManager): + pass def __init__( self, - world_size: int, - engine_dir: Path, - tokenizer: Union[str, Path, TokenizerBase, None], - max_beam_width: int = 1, - executor_type: tllm.TrtGptModelType = tllm.TrtGptModelType. - InflightFusedBatching, - executor_policy: tllm.SchedulerPolicy = tllm.SchedulerPolicy. - GUARANTEED_NO_EVICT, - executor_config: tllm.TrtGptModelOptionalParams = tllm. - TrtGptModelOptionalParams(), - socket_client: Optional[SocketClient] = None, + workers_kwargs, + model_world_size: int = 1, + mpi_session: Optional[MpiSession] = None, ) -> None: + super().__init__() + + self.workers_started = False + self.tokenizer = tokenizer_factory(workers_kwargs["tokenizer"]) + + manager_address = ("localhost", find_free_port()) + manager_secret = secrets.token_bytes(512) + self.manager = GenerationExecutorProxy.ExecutorManager( + manager_address, manager_secret) + request_queue, result_queue = Queue(), Queue() + GenerationExecutorProxy.ExecutorManager.register( + "request_queue", lambda: request_queue) + GenerationExecutorProxy.ExecutorManager.register( + "result_queue", lambda: result_queue) + self.manager.start() + self._results: Dict[int, GenerationResult] = {} - self.on_PMP = mpi_world_size() == 1 - self.on_MPI = mpi_world_size() > 1 - - self._terminated = False - self._terminated_sync = False - - self.active_requests = 0 - - self.tokenizer = tokenizer - if tokenizer is not None and not isinstance(tokenizer, TokenizerBase): - self.tokenizer = AutoTokenizer.from_pretrained( - tokenizer, - legacy=False, - padding_side='left', - truncation_side='left', - trust_remote_code=True, - use_fast=True) - - # NOTE: underscore variables are used for communication with the C++ runtime - self._requests: list[tllm.InferenceRequest] = [] - self._results: dict[int, GenerationResult] = {} - self._cancelled_ids: set[int] = set() - self._completed: Queue = Queue() - if has_event_loop(): - self._stats = AsyncQueue() - self.stats_queue = self._stats.sync_q - self.stats_aqueue = self._stats.async_q - else: - self._stats = Queue() - self.stats_queue = self._stats - self.stats_aqueue = None - - self._next_request_id = GenerationExecutor.TERMINATE_REQUEST_ID + 1 - self.socket_client = socket_client - - if self.on_PMP: - # initialize the executor on each MPI node - assert isinstance(self.tokenizer, - TokenizerBase), "tokenizer not initialized" - - self.mpi_session = MpiSession( - n_workers=world_size, - async_callback=self._async_listener_callback) - self.socket_client = self.mpi_session.get_socket_client() - - self.mpi_session.submit_sync( - ParallelGenerationExecutor._node_init_executor_task, engine_dir, - self.tokenizer, max_beam_width, executor_type, executor_policy, - executor_config, self.socket_client) + if mpi_session is None: + self.mpi_session = MpiSession(n_workers=model_world_size) else: - self.engine = tllm.GptManager( - engine_dir, executor_type, max_beam_width, executor_policy, - self.fetch_requests_on_mpi_node, - self.handle_response_on_mpi_node, self.get_cancelled_ids, - self.handle_stats, executor_config, - GenerationExecutor.TERMINATE_REQUEST_ID) - - def submit(self, request: GenerationRequest) -> GenerationResult: - # submit on the PMP - inference_request = request.get_inference_request() - result = GenerationResult(request, self.tokenizer) - self._results[inference_request.request_id] = result - - self.active_requests += 1 + self.mpi_session = mpi_session + self.model_world_size = model_world_size - self.mpi_session.submit_sync( - 
ParallelGenerationExecutor._node_add_request_task, - inference_request) + self.workers_kwargs = workers_kwargs + self.workers_kwargs.update({ + "manager_address": manager_address, + "manager_secret": manager_secret + }) + self.dispatcher = Thread(target=self.dispatcher_thread) - return result - - @print_traceback_on_error @staticmethod - def _node_add_request_task(inference_request): - executor: GenerationExecutor = NodeSession.state - assert isinstance(executor, - GenerationExecutor), 'executor not initialized' - executor._requests.append(inference_request) - - @print_traceback_on_error - @staticmethod - def _node_init_executor_task( - engine_dir: Path, - tokenizer: TokenizerBase, - max_beam_width: int, - executor_type: tllm.TrtGptModelType, - executor_policy: tllm.SchedulerPolicy, - executor_config: tllm.TrtGptModelOptionalParams, - socket_client: Optional[SocketClient], - ): - ''' Create a local GenerationExecutor instance for each MPI process. ''' - assert not NodeSession.is_initialized(), 'executor already initialized' - - logger.info(f'Initializing executor on MPI node #{mpi_rank()}') - - world_size = mpi_world_size() - NodeSession.state = ParallelGenerationExecutor( - world_size, - engine_dir, - tokenizer, - max_beam_width, - executor_type, - executor_policy, - executor_config=executor_config, - socket_client=socket_client) - - # Callbacks for BatchManager - - @print_traceback_on_error - def fetch_requests_on_mpi_node( - self, max_num_sequences) -> List[tllm.InferenceRequest]: - if mpi_rank() != 0 or self._terminated_sync: - if self._terminated: - return [] - - terminated = mpi_broadcast(self._terminated, 0) - if terminated: - logger.warning(f'#node{mpi_rank()} to terminate') - self._terminated_sync = True - self._terminated = True - - if terminated: - return [] + def workers_main(engine_dir: Path, + tokenizer: Union[str, Path, TokenizerBase], + max_beam_width: int = 1, + executor_type: tllm.TrtGptModelType = tllm.TrtGptModelType. + InflightBatching, + executor_policy: tllm.SchedulerPolicy = tllm. + SchedulerPolicy.GUARANTEED_NO_EVICT, + executor_config: tllm.TrtGptModelOptionalParams = tllm. + TrtGptModelOptionalParams(), + manager_address: Tuple[str, int] = ("", 0), + manager_secret: bytes = b"") -> None: + + with GenerationExecutorWorker(engine_dir, tokenizer, max_beam_width, + executor_type, executor_policy, + executor_config) as executor: + executor.block_subordinates() + + manager = GenerationExecutorProxy.ExecutorManager( + manager_address, manager_secret) + GenerationExecutorProxy.ExecutorManager.register("request_queue") + GenerationExecutorProxy.ExecutorManager.register("result_queue") + manager.connect() + request_queue = manager.request_queue() + manager.result_queue().put(True) + + executor.set_result_queue(manager.result_queue()) + while (req := request_queue.get()) is not None: + executor.submit(req) + + def dispatcher_thread(self): + """ Collect centralized results from Manager's result queue and dispatch them in the + correct GenerationResult queues. 
""" + + while (res := self.manager.result_queue().get()) is not None: + id, tensors, finished, err = res + self._results[id].queue.put( + (id, + {name: torch.tensor(value) + for name, value in tensors.items()}, finished, err)) + + def start(self): + self.mpi_futures = self.mpi_session.submit( + GenerationExecutorProxy.workers_main, **self.workers_kwargs) + self.workers_started = True + + while True: + ack = False + try: + ack = self.manager.result_queue().get(timeout=0.5) + except Empty: + pass + if not ack: + if any(f.done() for f in self.mpi_futures): + self.shutdown() + raise RuntimeError( + "GenerationExecutorWorker has exited unexpectedly") + else: + break - batch_size = 0 - fetched = [] - if mpi_rank() == 0: - batch_size = min(len(self._requests), max_num_sequences) - batch_size = mpi_broadcast(batch_size, 0) + self.dispatcher.start() - for _ in range(batch_size): - # the MPIPoolExecutor will always submit the same input to every worker, sometimes they arrive at slightly different time - while len(self._requests) == 0: - time.sleep(0.05) - fetched.append(self._requests.pop()) + def shutdown(self): + if not self.workers_started: + return + self.manager.request_queue().put(None) + self.manager.result_queue().put(None) + for f in self.mpi_futures: + f.result() + self.dispatcher.join() + self.workers_started = False - return fetched + def submit(self, request: GenerationRequest) -> GenerationResult: + """ + Low-level API to the executor. Return a "future" GenerationResult which can be waited. + Forwards the request to the workers through the Manager's request queue. + """ + if not self.workers_started: + self.start() - def handle_response_on_mpi_node(self, req_id: int, - tensors: List[tllm.NamedTensor], - finished: bool, err: str) -> None: - if mpi_rank() != 0: - return + req_id = self.generate_id() + request.set_id(req_id) - tensor_dic = {t.name: t.tensor for t in tensors if t.tensor is not None} - output = GenerationResult.process_generation( - tensor_dic) if not err else err - - self.socket_client.send( - dict( - req_id=req_id, - output=output if isinstance(output, str) else output.tolist(), - finished=finished, - )) - - def _async_listener_callback(self, data: Dict[str, Any]): - req_id = data['req_id'] - output = data['output'] - finished = data['finished'] - self._results[req_id].enqueue((output, finished)) - if finished: - self._completed.put(req_id) + result = GenerationResult(request, request.tokenizer) + self._results[req_id] = result - @print_traceback_on_error - @staticmethod - def _node_quit_task(): - executor: GenerationExecutor = NodeSession.state - assert isinstance(executor, - GenerationExecutor), 'executor not initialized' - if mpi_rank() == 0: - executor._terminated = True + self.manager.request_queue().put(request) - time.sleep(1) - executor.engine.__exit__(None, None, None) - NodeSession.state = None + return result - def _shutdown_mpi_nodes(self): - self.mpi_session.submit_sync(ParallelGenerationExecutor._node_quit_task) + def get_stats(self): + pass - def shutdown(self): - if self.on_PMP and self.mpi_session is not None: - self._shutdown_mpi_nodes() - self.mpi_session.shutdown() - self.mpi_session = None + async def aget_stats(self): + pass def __del__(self): self.shutdown() diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index e3793227b..77fea2e03 100644 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -4492,8 +4492,9 @@ def lora_plugin( def selective_scan(input: Tensor, state: Tensor, delta: Tensor, delta_bias: Tensor, A: 
Tensor, B: Tensor, C: Tensor, - D: Tensor, z: Tensor, host_request_types: Tensor, dim: int, - dstate: int, is_variable_B: bool, is_variable_C: bool, + D: Tensor, z: Tensor, host_request_types: Tensor, + last_token_ids: Tensor, dim: int, dstate: int, + is_variable_B: bool, is_variable_C: bool, delta_softplus: bool, dtype: str): ''' Parameters: @@ -4524,6 +4525,10 @@ def selective_scan(input: Tensor, state: Tensor, delta: Tensor, z : Tensor (On GPU) The z tensor. Its shape is [batch_size, seq_len, dim] + last_token_ids : Tensor (On GPU) + The inclusive prefix-sum of the lengths or the lengths of the + sequences in the batch. + host_request_types : Tensor (On CPU) The tensor on the host that indicates if a request is in context or generation phase. Its shape is [batch_size]. See Inflight Batching @@ -4583,7 +4588,8 @@ def selective_scan(input: Tensor, state: Tensor, delta: Tensor, "selective_scan", pfc) plug_inputs = [ - input, state, delta, delta_bias, A, B, C, D, z, host_request_types + input, state, delta, delta_bias, A, B, C, D, z, host_request_types, + last_token_ids ] plug_inputs = [i.trt_tensor for i in plug_inputs] diff --git a/tensorrt_llm/hlapi/llm.py b/tensorrt_llm/hlapi/llm.py index 381382dae..f485f0b67 100644 --- a/tensorrt_llm/hlapi/llm.py +++ b/tensorrt_llm/hlapi/llm.py @@ -12,24 +12,22 @@ import tensorrt as trt import torch -import tensorrt_llm.bindings as tllm -from tensorrt_llm.bindings import KvCacheConfig, SchedulerPolicy - -from .._utils import mpi_rank +from .. import bindings as tllm +from .._utils import mpi_rank, release_gc from ..auto_parallel.config import AutoParallelConfig, infer_cluster_key +from ..bindings import KvCacheConfig, SchedulerPolicy from ..builder import (BuildConfig, Engine, EngineConfig, PluginConfig, QuantMode, build) -from ..executor import (GenerationExecutor, GenerationResult, - ParallelGenerationExecutor) +from ..executor import GenerationExecutor, GenerationResult from ..logger import logger from ..mapping import Mapping from ..models.modeling_utils import PretrainedConfig from ..module import Module from ..runtime import SamplingConfig -from .mpi_session import MpiSession, NodeSession +from .mpi_session import MPINodeState, MpiSession from .tokenizer import TokenizerBase, TransformersTokenizer from .utils import (GenerationOutput, file_with_suffix_exists, get_device_count, - print_colored, print_traceback_on_error, release_gc) + print_colored, print_traceback_on_error) @dataclass @@ -265,6 +263,7 @@ def __init__(self, self.decoding_mode = decoding_mode self.scheduling_policy = scheduling_policy + self.mpi_session = None # TODO[chunweiy]: Support more models and gpus self._extra_build_config = ModelLoader.load_extra_build_configs_from_engine( @@ -285,7 +284,8 @@ def __init__(self, self.mpi_session = MpiSession(n_workers=self.config.world_size) # Due to the gptManager can only accept a engine path, we need to save the engine to a directory - self._engine_dir: Union[tempfile.TemporaryDirectory, str, Path] = None + self._engine_dir: Union[tempfile.TemporaryDirectory, str, Path, + None] = None self._executor: Optional[GenerationExecutor] = None self.runtime_context: Optional[_ModelRuntimeContext] = None @@ -390,21 +390,28 @@ def save(self, engine_dir: str): if src_engine_dir != engine_dir: shutil.copytree(src_engine_dir, engine_dir, dirs_exist_ok=True) - def __enter__(self): - return self + def shutdown(self): + if self._executor is not None: + self._executor.shutdown() - def __exit__(self, exc_type, exc_value, traceback): if self.mpi_session is 
not None: self.mpi_session.shutdown() self.mpi_session = None - if hasattr(self, "_executor") and self._executor is not None: - self._executor.__exit__(exc_type, exc_value, traceback) - self._executor = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + del exc_value, traceback + self.shutdown() + return exc_type is not None def _save_engine(self, engine_dir: str): logger.info(f"Save model to {engine_dir}") if self.config.is_multi_gpu: + if self._executor is not None: + self._executor.shutdown() self.mpi_session.submit_sync(LLM._node_save_task, engine_dir, self.config.model_dir) else: @@ -443,6 +450,9 @@ def get_engine_dir(): if model_format is not _ModelFormatKind.TLLM_ENGINE: + if self._executor is not None: + self._executor.shutdown() + self._engine_dir = self.async_engine_tmp_dir if self._engine_dir is None: self._engine_dir = tempfile.TemporaryDirectory() @@ -464,7 +474,7 @@ def get_engine_dir(): runtime_context = model_loader() - # TODO[chunweiy]: Make GptManager support in-memory engine-buffer to save disk loading lantenecy + # TODO[chunweiy]: Make GptManager support in-memory engine-buffer to save disk loading latency ModelLoader.save(runtime_context, self.config.model_dir, engine_dir=get_engine_dir(), @@ -488,42 +498,32 @@ def get_engine_dir(): executor_config.decoding_mode = self.decoding_mode.to_cpp( ) if self.decoding_mode else None - if self.config.is_multi_gpu: - self._executor = ParallelGenerationExecutor( - world_size=self.config.world_size, - engine_dir=get_engine_dir(), - tokenizer=tokenizer, - max_beam_width=self.config.max_beam_width, - executor_policy=self.scheduling_policy, - executor_config=executor_config, - ) - else: - - self._executor = GenerationExecutor( - get_engine_dir(), - tokenizer=tokenizer, - max_beam_width=self.config.max_beam_width, - executor_config=executor_config, - executor_policy=self.scheduling_policy, - ) + self._executor = GenerationExecutor.create( + get_engine_dir(), + tokenizer, + max_beam_width=self.config.max_beam_width, + executor_config=executor_config, + executor_policy=self.scheduling_policy, + model_world_size=self.config.world_size, + mpi_session=self.mpi_session) @print_traceback_on_error @staticmethod def _node_build_task(config: ModelConfig, tokenizer: TokenizerBase = None) -> bool: - assert not NodeSession.is_initialized() + assert not MPINodeState.is_initialized() with ModelLoader(config, tokenizer=tokenizer) as model_loader: runtime_context = model_loader() # Hold the model builder for later use - NodeSession.state = runtime_context + MPINodeState.state = runtime_context return True @print_traceback_on_error @staticmethod def _node_save_task(engine_dir: str, model_dir: str): - runtime_context: _ModelRuntimeContext = NodeSession.state + runtime_context: _ModelRuntimeContext = MPINodeState.state assert isinstance(runtime_context, _ModelRuntimeContext), "Model is not built yet." 
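The hunk above replaces the previous `ParallelGenerationExecutor` / `GenerationExecutor` split with a single `GenerationExecutor.create(...)` factory call that routes multi-rank engines through the proxy path. A minimal sketch of calling that factory directly, assuming a hypothetical two-rank (TP=2) engine directory and tokenizer name (placeholders, not taken from this change):

```
from pathlib import Path

from tensorrt_llm.executor import GenerationExecutor

# Placeholder paths; model_world_size must match the world size the engine
# was built for.
executor = GenerationExecutor.create(
    Path("./llama-13b-tp2-engine"),
    tokenizer="meta-llama/Llama-2-13b-hf",
    model_world_size=2,
)

# When the launching process is a single rank and model_world_size > 1,
# create() returns a GenerationExecutorProxy that spawns the MPI workers on
# the first submitted request.
result = executor.generate("The capital of France is", max_new_tokens=16)
print(result.text)

executor.shutdown()
```
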
@@ -535,7 +535,7 @@ def _node_save_task(engine_dir: str, model_dir: str): @print_traceback_on_error @staticmethod def _node_free_state_task(): - NodeSession.state = None + MPINodeState.state = None # release the large resource explicitly and immediately, since the following LLM pipeline may need a lot of memory release_gc() @@ -543,7 +543,7 @@ def __getstate__(self): raise RuntimeError("LLM object can not be pickled.") def __del__(self): - self.__exit__(None, None, None) + self.shutdown() class _ModelFormatKind(Enum): @@ -811,6 +811,8 @@ def _load_model_from_hf(self): mapping=self.mapping, quant_mode=self.config.quant_config.quant_mode, quantize_lm_head=self.config.quant_config.quantize_lm_head, + load_model_on_cpu= + True, # TODO:TRTLLM-195 to enhance the weights loading memory usage and chose best location ) self.pretrained_config = self.model.config self._model_info = _ModelInfo.from_pretrained_config( diff --git a/tensorrt_llm/hlapi/mpi_session.py b/tensorrt_llm/hlapi/mpi_session.py index 22192499c..94672c1e7 100644 --- a/tensorrt_llm/hlapi/mpi_session.py +++ b/tensorrt_llm/hlapi/mpi_session.py @@ -1,23 +1,21 @@ import pickle # nosec B403 import socket import sys -import threading -import time from concurrent.futures import Future -from typing import Any, Callable, List, Optional +from typing import Any, List, Optional from mpi4py.futures import MPIPoolExecutor -class NodeSession: - ''' NodeSession Act as a central global state shares between tasks on MPI node. +class MPINodeState: + ''' MPINodeState acts as a central global state shares between tasks on MPI node. An example: def task(): - if NodeSession.state is None: - NodeSession.state = 0 - NodeSession.state += 1 - return NodeSession.state + if MPINodeState.state is None: + MPINodeState.state = 0 + MPINodeState.state += 1 + return MPINodeState.state n_workers = 4 with MPIPoolExecutor(max_workers=n_workers) as executor: @@ -33,30 +31,26 @@ def task(): @staticmethod def is_initialized() -> bool: - return NodeSession.state is not None + return MPINodeState.state is not None class MpiSession: - def __init__(self, - n_workers: int, - async_callback: Callable[[Any], None] = None): + def __init__(self, n_workers: int): self.n_workers = n_workers self.mpi_pool: Optional[MPIPoolExecutor] = None - self.async_callback = async_callback self._start_mpi_pool() - if self.async_callback: - self._socket_listener = SocketListener(callback=async_callback) - - def submit(self, task: (...), *args) -> List[Future]: + def submit(self, task: (...), *args, **kwargs) -> List[Future]: return [ - self.mpi_pool.submit(task, *args) for i in range(self.n_workers) + self.mpi_pool.submit(task, *args, **kwargs) + for i in range(self.n_workers) ] - def submit_sync(self, task: (...), *args) -> List[Any]: + def submit_sync(self, task: (...), *args, **kwargs) -> List[Any]: futures = [ - self.mpi_pool.submit(task, *args) for i in range(self.n_workers) + self.mpi_pool.submit(task, *args, **kwargs) + for i in range(self.n_workers) ] return [future.result() for future in futures] @@ -65,23 +59,6 @@ def shutdown(self): self.mpi_pool.shutdown() self.mpi_pool = None - if self.async_callback is not None and self._socket_listener is not None: - self._socket_listener.shutdown() - self._socket_listener = None - - def _start(self): - assert not self.mpi_pool, 'MPI session already started' - - self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers, - path=sys.path) - - @property - def async_enabled(self) -> bool: - return hasattr(self, '_socket_listener') - - def 
get_socket_client(self) -> "SocketClient": - return self._socket_listener.get_client() - def _start_mpi_pool(self): assert not self.mpi_pool, 'MPI session already started' @@ -95,72 +72,6 @@ def __reduce__(self): raise TypeError('cannot pickle MPI session') -class SocketClient: - - def __init__(self, port): - self.port = port - - def send(self, data: Any): - # TODO[chunweiy]: reuse socket - client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - client_socket.connect((SocketListener.IP, self.port)) - client_socket.send(pickle.dumps(data)) - client_socket.close() - - -class SocketListener: - IP = 'localhost' - - def __init__(self, - callback: Optional[Callable[[Any], Any]], - buf_size: int = 4096): - self.buf_size = buf_size - self.callback = callback - self.port = -1 - self.server_socket = None - - self._start_service() - - def _start_service(self): - self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = find_free_port() - self.server_socket.bind((SocketListener.IP, self.port)) - - def loop(): - self.server_socket.listen(5) - try: - while True: - client_socket, address = self.server_socket.accept() - received_data = client_socket.recv(self.buf_size) - real_data = pickle.loads(received_data) # nosec B301 - if real_data is None: - # get the quit signal - break - - self.callback(real_data) - - finally: - self.server_socket.close() - - self.thread = threading.Thread(target=loop) - self.thread.start() - - def get_client(self) -> SocketClient: - return SocketClient(self.port) - - def shutdown(self): - if self.server_socket is not None: - client = self.get_client() - client.send(None) - time.sleep(0.1) - self.server_socket = None - - def __del__(self): - self.shutdown() - - self.thread.join() - - def find_free_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(('', 0)) diff --git a/tensorrt_llm/hlapi/tokenizer.py b/tensorrt_llm/hlapi/tokenizer.py index 90b84553e..99cbbc699 100644 --- a/tensorrt_llm/hlapi/tokenizer.py +++ b/tensorrt_llm/hlapi/tokenizer.py @@ -1,5 +1,7 @@ -from typing import Any, List +from pathlib import Path +from typing import Any, List, Union +from transformers import AutoTokenizer, PreTrainedTokenizerBase from transformers.tokenization_utils_base import PreTrainedTokenizerBase TokenIdsTy = List[int] @@ -42,3 +44,22 @@ def decode(self, token_ids: TokenIdsTy, *args, **kwargs) -> str: def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict: return self.tokenizer.batch_encode_plus(texts, *args, **kwargs) + + +def tokenizer_factory( + obj: Union[str, Path, TokenizerBase, PreTrainedTokenizerBase, None], + **kwargs) -> Union[TokenizerBase, PreTrainedTokenizerBase, None]: + if obj is None: + return None + if isinstance(obj, (str, Path)): + default_kwargs = { + 'legacy': False, + 'padding_side': 'left', + 'truncation_side': 'left', + 'trust_remote_code': True, + 'use_fast': True, + } + default_kwargs.update(kwargs) + return AutoTokenizer.from_pretrained(obj, **kwargs) + + return obj diff --git a/tensorrt_llm/hlapi/utils.py b/tensorrt_llm/hlapi/utils.py index a1603b295..6e425f4e3 100644 --- a/tensorrt_llm/hlapi/utils.py +++ b/tensorrt_llm/hlapi/utils.py @@ -1,10 +1,9 @@ -import gc import sys import traceback from dataclasses import dataclass, field from functools import wraps from pathlib import Path -from typing import List +from typing import List, Union import torch @@ -12,7 +11,7 @@ @dataclass class GenerationOutput: text: str = "" - token_ids: List[int] = field(default_factory=list) + 
token_ids: Union[List[int], List[List[int]]] = field(default_factory=list) logprobs: List[float] = field(default_factory=list) @@ -59,11 +58,3 @@ def get_device_count() -> int: def get_total_gpu_memory(device: int) -> float: return torch.cuda.get_device_properties(device).total_memory - - -def release_gc(): - ''' Release memory allocated by PyTorch and Python garbage collector explicitly and immediately. - This could be used when some states might be kept in memory even after the variables are deleted. - ''' - gc.collect() - torch.cuda.empty_cache() diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py index 58e8420a5..5beea950e 100644 --- a/tensorrt_llm/layers/attention.py +++ b/tensorrt_llm/layers/attention.py @@ -35,7 +35,7 @@ from ..quantization.functional import dequantize, quantize from ..quantization.layers import FP8Linear, FP8RowLinear from .linear import ColumnLinear, QKVColumnLinear, RowLinear -from .lora import Lora, LoraRuntimeParams +from .lora import LoraRuntimeParams class RopeEmbeddingUtils: @@ -446,7 +446,6 @@ def __init__( max_distance=0, num_buckets=0, dense_bias=None, - max_lora_rank=None, clip_qkv=None, alibi_bias_max=8, skip_cross_qkv=False, @@ -463,7 +462,7 @@ def __init__( self.num_attention_kv_heads = ( num_kv_heads + tp_size - 1 ) // tp_size if num_kv_heads is not None else self.num_attention_heads - self.hidden_size = hidden_size // tp_size + self.hidden_size = hidden_size self.attention_hidden_size = self.attention_head_size * self.num_attention_heads self.max_position_embeddings = max_position_embeddings self.bias = bias @@ -539,16 +538,14 @@ def __init__( dtype=dtype, tp_group=tp_group, tp_size=tp_size, - gather_output=False, - max_lora_rank=max_lora_rank) + gather_output=False) self.dense = FP8RowLinear(tp_size * self.num_attention_heads * self.attention_head_size, hidden_size, bias=dense_bias, dtype=dtype, tp_group=tp_group, - tp_size=tp_size, - max_lora_rank=max_lora_rank) + tp_size=tp_size) else: # out dim is not necessarily hidden_size + kv specific size (in MQA/GQA), but num_heads * heads_size # example: d_model != num_heads * head_size in Flan-T5 @@ -561,16 +558,14 @@ def __init__( dtype=dtype, tp_group=tp_group, tp_size=tp_size, - gather_output=False, - max_lora_rank=max_lora_rank) + gather_output=False) self.dense = RowLinear(tp_size * self.num_attention_heads * self.attention_head_size, hidden_size, bias=dense_bias, dtype=dtype, tp_group=tp_group, - tp_size=tp_size, - max_lora_rank=max_lora_rank) + tp_size=tp_size) # per-layer relative attention table if relative_attention: @@ -578,21 +573,6 @@ def __init__( tp_size, num_buckets), dtype=dtype) - if max_lora_rank is None: - max_lora_rank = min( - hidden_size, - self.num_attention_heads * self.attention_head_size, - self.num_attention_kv_heads * self.attention_head_size) - self.qkv_lora = Lora( - in_hidden_size=hidden_size, - out_hidden_sizes=[ - self.num_attention_heads * self.attention_head_size, - self.num_attention_kv_heads * self.attention_head_size, - self.num_attention_kv_heads * self.attention_head_size - ], - max_low_rank=max_lora_rank, - ) - if clip_qkv is not None: self.clip_qkv = fp32_array([clip_qkv]) else: @@ -1181,8 +1161,7 @@ def __init__(self, tp_rank=0, relative_attention=False, max_distance=0, - num_buckets=0, - max_lora_rank=None): + num_buckets=0): super().__init__() self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size @@ -1190,7 +1169,7 @@ def __init__(self, self.num_attention_kv_heads = ( 
num_kv_heads + tp_size - 1 ) // tp_size if num_kv_heads is not None else self.num_attention_heads - self.hidden_size = hidden_size // tp_size + self.hidden_size = hidden_size self.attention_hidden_size = self.attention_head_size * self.num_attention_heads self.max_position_embeddings = max_position_embeddings self.norm_factor = math.sqrt(self.attention_head_size) @@ -1222,16 +1201,14 @@ def __init__(self, dtype=dtype, tp_group=tp_group, tp_size=tp_size, - gather_output=False, - max_lora_rank=max_lora_rank) + gather_output=False) self.dense = RowLinear(tp_size * self.num_attention_heads * self.attention_head_size, hidden_size, bias=bias, dtype=dtype, tp_group=tp_group, - tp_size=tp_size, - max_lora_rank=max_lora_rank) + tp_size=tp_size) # per-layer relative attention table if relative_attention: @@ -1239,21 +1216,6 @@ def __init__(self, tp_size, num_buckets), dtype=dtype) - if max_lora_rank is None: - max_lora_rank = min( - hidden_size, - self.num_attention_heads * self.attention_head_size, - self.num_attention_kv_heads * self.attention_head_size) - self.qkv_lora = Lora( - in_hidden_size=hidden_size, - out_hidden_sizes=[ - self.num_attention_heads * self.attention_head_size, - self.num_attention_kv_heads * self.attention_head_size, - self.num_attention_kv_heads * self.attention_head_size - ], - max_low_rank=max_lora_rank, - ) - def forward(self, hidden_states: Tensor, attention_mask=None, diff --git a/tensorrt_llm/layers/embedding.py b/tensorrt_llm/layers/embedding.py index fb93c90b5..2a15f9551 100644 --- a/tensorrt_llm/layers/embedding.py +++ b/tensorrt_llm/layers/embedding.py @@ -52,6 +52,7 @@ def __init__(self, self.tp_group = tp_group self.sharding_dim = sharding_dim self.tp_rank = tp_rank + self.dtype = dtype if sharding_dim == 1: self.weight = Parameter(shape=(self.num_embeddings, diff --git a/tensorrt_llm/layers/linear.py b/tensorrt_llm/layers/linear.py index a0337f8ac..22b3ef6c6 100644 --- a/tensorrt_llm/layers/linear.py +++ b/tensorrt_llm/layers/linear.py @@ -27,7 +27,7 @@ from ..module import Module from ..parameter import Parameter from ..plugin import TRT_LLM_PLUGIN_NAMESPACE -from .lora import Lora, LoraRuntimeParams +from .lora import LoraRuntimeParams def _gemm_plugin(input: Tensor, @@ -77,8 +77,7 @@ def __init__(self, tp_size=1, gather_output=True, share_weight=None, - strict_dtype=False, - max_lora_rank=None): + strict_dtype=False): super().__init__() self.in_features = in_features self.out_features = out_features // tp_size @@ -107,14 +106,6 @@ def __init__(self, else: self.register_parameter('bias', None) - if max_lora_rank is None: - max_lora_rank = min(self.in_features, self.out_features) - self.lora = Lora( - in_hidden_size=self.in_features, - out_hidden_sizes=[self.out_features], - max_low_rank=max_lora_rank, - ) - def multiply_gather(self, x, weight, @@ -216,8 +207,7 @@ def __init__(self, use_fp8=False, tp_group=None, tp_size=1, - strict_dtype: bool = False, - max_lora_rank=None): + strict_dtype: bool = False): super().__init__() self.in_features = in_features // tp_size self.out_features = out_features @@ -237,14 +227,6 @@ def __init__(self, self.tp_group = tp_group self.tp_size = tp_size - - if max_lora_rank is None: - max_lora_rank = min(self.in_features, self.out_features) - self.lora = Lora( - in_hidden_size=self.in_features, - out_hidden_sizes=[self.out_features], - max_low_rank=max_lora_rank, - ) self.strict_dtype = self.dtype if strict_dtype else None def multiply_reduce(self, diff --git a/tensorrt_llm/layers/mlp.py b/tensorrt_llm/layers/mlp.py index 
5095532ee..753ff25a9 100644 --- a/tensorrt_llm/layers/mlp.py +++ b/tensorrt_llm/layers/mlp.py @@ -18,7 +18,7 @@ from ..quantization import QuantMode from ..quantization.layers import FP8Linear, FP8RowLinear from .linear import ColumnLinear, RowLinear -from .lora import Lora, LoraRuntimeParams +from .lora import LoraRuntimeParams class MLP(Module): @@ -33,7 +33,6 @@ def __init__( tp_group=None, tp_size=1, quant_mode=QuantMode(0), - max_lora_rank=None, ): super().__init__() if hidden_act not in ACT2FN: @@ -49,15 +48,13 @@ def __init__( dtype=dtype, tp_group=tp_group, tp_size=tp_size, - gather_output=False, - max_lora_rank=max_lora_rank) + gather_output=False) self.proj = FP8RowLinear(ffn_hidden_size, hidden_size, bias=bias, dtype=dtype, tp_group=tp_group, - tp_size=tp_size, - max_lora_rank=max_lora_rank) + tp_size=tp_size) else: self.fc = ColumnLinear(hidden_size, fc_output_size, @@ -65,15 +62,13 @@ def __init__( dtype=dtype, tp_group=tp_group, tp_size=tp_size, - gather_output=False, - max_lora_rank=max_lora_rank) + gather_output=False) self.proj = RowLinear(ffn_hidden_size, hidden_size, bias=bias, dtype=dtype, tp_group=tp_group, - tp_size=tp_size, - max_lora_rank=max_lora_rank) + tp_size=tp_size) self.hidden_act = hidden_act self.dtype = dtype @@ -108,7 +103,6 @@ def __init__( tp_group=None, tp_size=1, quant_mode=QuantMode(0), - max_lora_rank=None, ): super().__init__(hidden_size, ffn_hidden_size, @@ -127,7 +121,6 @@ def __init__( self.tp_group = tp_group self.tp_size = tp_size self.quant_mode = quant_mode - self.max_lora_rank = max_lora_rank if self.use_fp8_qdq: self.gate = FP8Linear(hidden_size, @@ -136,8 +129,7 @@ def __init__( dtype=dtype, tp_group=tp_group, tp_size=tp_size, - gather_output=False, - max_lora_rank=max_lora_rank) + gather_output=False) else: self.gate = ColumnLinear(hidden_size, ffn_hidden_size, @@ -145,8 +137,7 @@ def __init__( dtype=dtype, tp_group=tp_group, tp_size=tp_size, - gather_output=False, - max_lora_rank=max_lora_rank) + gather_output=False) def forward(self, hidden_states, lora_layer_params=None): @@ -186,7 +177,6 @@ def __init__( tp_group=None, tp_size=1, quant_mode=QuantMode(0), - max_lora_rank=None, ): super().__init__() self.hidden_size = hidden_size @@ -214,8 +204,7 @@ def __init__( bias=bias, dtype=dtype, tp_group=tp_group, - tp_size=tp_size, - max_lora_rank=max_lora_rank) + tp_size=tp_size) else: self.fused_fc = ColumnLinear( self.hidden_size, @@ -231,18 +220,7 @@ def __init__( bias=bias, dtype=dtype, tp_group=tp_group, - tp_size=tp_size, - max_lora_rank=max_lora_rank) - - if max_lora_rank is None: - max_lora_rank = min(hidden_size, ffn_hidden_size // tp_size) - self.mlp_in_lora = Lora( - in_hidden_size=hidden_size, - out_hidden_sizes=[ - ffn_hidden_size // tp_size, ffn_hidden_size // tp_size - ], - max_low_rank=max_lora_rank, - ) + tp_size=tp_size) def forward(self, hidden_states, lora_layer_params=None): # Combine the following pattern diff --git a/tensorrt_llm/layers/moe.py b/tensorrt_llm/layers/moe.py index 4df2efed4..9f7c1d6dd 100644 --- a/tensorrt_llm/layers/moe.py +++ b/tensorrt_llm/layers/moe.py @@ -196,8 +196,7 @@ def __init__(self, tp_group: List[int] = None, tp_size: int = 1, tp_rank: int = 0, - quant_mode=QuantMode(0), - max_lora_rank=None): + quant_mode=QuantMode(0)): super().__init__() self.moe_config = moe_config @@ -309,7 +308,7 @@ def __init__(self, self.experts = [ ClsMLP(self.hidden_size, ffn_hidden_size, non_gated_version(self.hidden_act), bias, dtype, tp_group, - tp_size, quant_mode, 0) for _ in range(self.experts_per_node) + 
tp_size, quant_mode) for _ in range(self.experts_per_node) ] def set_ootb_weight(self): diff --git a/tensorrt_llm/layers/ssm.py b/tensorrt_llm/layers/ssm.py index 1ec2a9d81..3186a3701 100644 --- a/tensorrt_llm/layers/ssm.py +++ b/tensorrt_llm/layers/ssm.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from ..functional import (ACT2FN, Tensor, concat, selective_scan, shape, slice, +from ..functional import (ACT2FN, Tensor, concat, gather, selective_scan, shape, split) from ..module import Module from ..parameter import Parameter @@ -92,7 +92,8 @@ def __init__(self, gather_output=False) def forward(self, hidden_states: Tensor, conv_state: Tensor, - ssm_state: Tensor, host_request_types: Tensor): + ssm_state: Tensor, host_request_types: Tensor, + conv_indices: Tensor, last_token_ids: Tensor): ''' Parameters: hidden_states: [B, L, D] @@ -107,14 +108,10 @@ def forward(self, hidden_states: Tensor, conv_state: Tensor, # In context phase, conv_state is a zero tensor, and it is used for padding # In generation phase, conv_state is a tensor of the past x - slice_shape = concat([shape(x, 0), self.d_inner, self.d_conv - 1]) - past_conv_state = slice(conv_state, - concat([0, 0, shape(conv_state, 2) - 3]), - slice_shape) - x_pad = concat([past_conv_state, x], dim=2) + x_pad = concat([conv_state, x], dim=2) # Update conv_state - conv_state = x_pad + conv_state = gather(x_pad, 2, conv_indices) # Convolution x_pad = x_pad.view( @@ -145,6 +142,7 @@ def forward(self, hidden_states: Tensor, conv_state: Tensor, self.D.value, z, host_request_types, + last_token_ids, self.d_inner, self.d_state, is_variable_B=True, diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index f827efb4d..54dd0042b 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -1,16 +1,21 @@ import json import os import re +from dataclasses import dataclass, field from pathlib import Path +from typing import Dict, List import numpy as np import torch -from ._utils import str_dtype_to_torch, torch_to_numpy, unpack_nemo_weights +from ._utils import (DictConversion, pad_vocab_size, str_dtype_to_torch, + torch_to_numpy, unpack_nemo_weights) +from .layers.linear import ColumnLinear +from .models.convert_utils import split_matrix_tp def get_all_nemo_lora_weights(num_layers, lora_weights): - layer_weights = [{} for _ in range(2 * num_layers)] + layer_weights = [{} for _ in range(num_layers)] adapter_key = "self_attention.adapter_layer.lora_kqv_adapter" layer_pattern = re.compile(r'.*\.layers\.([0-9]+)\..*') for key, weights in lora_weights.items(): @@ -27,22 +32,198 @@ def get_all_nemo_lora_weights(num_layers, lora_weights): return layer_weights -class LoraConfig(object): - LORA_MODULE_IDS = { - "attn_qkv": 0, - "attn_q": 1, - "attn_k": 2, - "attn_v": 3, - "attn_dense": 4, - "mlp_h_to_4h": 5, - "mlp_4h_to_h": 6, - "mlp_gate": 7, - "cross_attn_qkv": 8, - "cross_attn_q": 9, - "cross_attn_k": 10, - "cross_attn_v": 11, - "cross_attn_dense": 12, +@dataclass +class LoraBuildConfig(DictConversion): + lora_dir: List[str] = field(default_factory=list) + lora_ckpt_source: str = 'hf' + max_lora_rank: int = 64 + lora_target_modules: List[str] = field(default_factory=list) + trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict) + + def __post_init__(self): + assert self.lora_ckpt_source in [ + 'hf', 'nemo' + ], f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}" + + +class HfLoraLoader: + + def __init__(self, lora_dirs: List[str]): + 
self.lora_target_modules = [] + self.is_valid = False + self.lm_head = None + self.embed_tokens = None + self.vocab_size = 0 + + if len(lora_dirs) == 0: + return + + for lora_dir in lora_dirs: + for filename in ["adapter_config.json", "adapter_model.bin"]: + path = Path(f"{lora_dir}/{filename}") + if not path.exists(): + raise ValueError(f"{path} does not exist") + if not path.is_file(): + raise ValueError(f"{path} is not a file") + self.is_valid = True + + lora_dir = lora_dirs[0] + with open(f"{lora_dir}/adapter_config.json") as f: + adapter_config = json.load(f) + self.lora_target_modules = adapter_config["target_modules"] + + lora_weight = torch.load(f"{lora_dir}/adapter_model.bin") + if adapter_config["modules_to_save"] is not None: + if "lm_head" in adapter_config["modules_to_save"]: + self.lm_head = lora_weight["base_model.model.lm_head.weight"] + self.vocab_size = self.lm_head.shape[0] + + if "embed_tokens" in adapter_config["modules_to_save"]: + self.embed_tokens = lora_weight[ + "base_model.model.model.embed_tokens.weight"] + + def get_target_modules(self, trtllm_modules_to_hf_modules): + hf_modules_to_trtllm_modules = { + v: k + for k, v in trtllm_modules_to_hf_modules.items() + } + lora_target_modules = [] + if self.is_valid: + # lora_target_modules[m] can be either a string or a list of strings + for m in self.lora_target_modules: + trtllm_module = hf_modules_to_trtllm_modules[m] + if isinstance(trtllm_module, list): + lora_target_modules.extend(trtllm_module) + else: + lora_target_modules.append(trtllm_module) + return lora_target_modules + + + class NemoLoraLoader: + + def __init__(self, lora_dirs: List[str]): + self.lora_target_modules = [] + self.is_valid = False + + if len(lora_dirs) == 0: + return + + for lora_file in lora_dirs: + path = Path(lora_file) + if not path.exists(): + raise ValueError(f"{path} does not exist") + if not path.is_file(): + raise ValueError(f"{path} is not a file") + self.is_valid = True + # Hardcoded since LoraManager only supports this case now + self.lora_target_modules = ["attn_qkv"] + + + def load_nemo_lora(model, lora_config: LoraBuildConfig): + lora_loader = NemoLoraLoader(lora_config.lora_dir) + if len(lora_config.lora_target_modules) == 0: + lora_config.lora_target_modules = lora_loader.lora_target_modules + + + def load_hf_lora( + model, + lora_config: LoraBuildConfig, + trtllm_modules_to_hf_modules: Dict[str, str] = None, + ): + trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules or { + "attn_q": "q_proj", + "attn_k": "k_proj", + "attn_v": "v_proj", + "attn_dense": "o_proj", + "mlp_h_to_4h": "gate_proj", + "mlp_4h_to_h": "down_proj", + "mlp_gate": "up_proj", } + lora_config.trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules + + lora_loader = HfLoraLoader(lora_config.lora_dir) + + if len(lora_config.lora_target_modules) == 0: + lora_config.lora_target_modules = lora_loader.get_target_modules( + trtllm_modules_to_hf_modules) + + config = model.config + if lora_loader.is_valid: + # the lora checkpoint might finetune the embedding + if lora_loader.vocab_size != 0: + config.vocab_size = lora_loader.vocab_size + mapping = config.mapping + if mapping.is_first_pp_rank() and lora_loader.embed_tokens is not None: + weight = lora_loader.embed_tokens + if config.use_parallel_embedding: + weight = split_matrix_tp( + weight, + mapping.tp_size, + mapping.tp_rank, + dim=config.embedding_sharding_dim, + ) + if model.transformer.vocab_embedding.weight.raw_value.shape != weight.shape: + model.transformer.vocab_embedding = 
model.transformer.vocab_embedding.__class__( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + dtype=config.dtype, + tp_size=mapping.tp_size + if config.use_parallel_embedding else 1, + tp_group=mapping.tp_group + if config.use_parallel_embedding else None, + sharding_dim=config.embedding_sharding_dim, + tp_rank=mapping.tp_rank, + ) + model.transformer.vocab_embedding.weight.value = weight + if mapping.is_last_pp_rank() and lora_loader.lm_head is not None: + weight = lora_loader.lm_head + vocab_size = lora_loader.vocab_size + if vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) + pad_width = vocab_size_padded - vocab_size + + weight = torch.from_numpy( + np.pad(weight.detach().cpu().numpy(), + ((0, pad_width), (0, 0)), + 'constant', + constant_values=0)) + else: + vocab_size_padded = vocab_size + if model.lm_head.weight.raw_value.shape != weight.shape: + model.lm_head = ColumnLinear( + config.hidden_size, + vocab_size_padded, + bias=False, + dtype=config.dtype, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + gather_output=True, + ) + model.lm_head.weight.value = split_matrix_tp( + weight, + mapping.tp_size, + mapping.tp_rank, + dim=0, + ) + + +def use_lora( + model, + lora_config: LoraBuildConfig, + trtllm_modules_to_hf_modules: Dict[str, str] = None, +): + model.lora_config = lora_config + if lora_config.lora_ckpt_source == "nemo": + load_nemo_lora(model, lora_config) + elif lora_config.lora_ckpt_source == "hf": + load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules) + else: + raise ValueError( + f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}") + + +class LoraConfig(object): def __init__(self, hf_lora_dir: str = None, @@ -121,6 +302,21 @@ def from_hf(cls, hf_lora_dir, hf_modules_to_trtllm_modules, class LoraManager(object): + LORA_MODULE_IDS = { + "attn_qkv": 0, + "attn_q": 1, + "attn_k": 2, + "attn_v": 3, + "attn_dense": 4, + "mlp_h_to_4h": 5, + "mlp_4h_to_h": 6, + "mlp_gate": 7, + "cross_attn_qkv": 8, + "cross_attn_q": 9, + "cross_attn_k": 10, + "cross_attn_v": 11, + "cross_attn_dense": 12, + } def __init__(self): self._lora_uid_to_key = {} @@ -239,8 +435,7 @@ def load_from_nemo(self, model_files, model_config, runtime_mapping): t_out.flatten()])) self._lora_weight_config[uid].append( np.array([ - LoraConfig.LORA_MODULE_IDS[lora_module], - layer_idx, + self.LORA_MODULE_IDS[lora_module], layer_idx, int(rank) ], dtype=np.int32)) @@ -421,7 +616,7 @@ def load_from_hf(self, model_dirs, model_config, runtime_mapping): t_out.flatten()])) self._lora_weight_config[uid].append( np.array([ - LoraConfig.LORA_MODULE_IDS[lora_module], layer_idx, + self.LORA_MODULE_IDS[lora_module], layer_idx, int(hf_config['r']) ], dtype=np.int32)) @@ -603,7 +798,7 @@ def load_from_hf_bart(self, component, model_dirs, model_config, t_out.flatten()])) self._lora_weight_config[uid].append( np.array([ - LoraConfig.LORA_MODULE_IDS[lora_module], layer_idx, + self.LORA_MODULE_IDS[lora_module], layer_idx, int(hf_config['r']) ], dtype=np.int32)) diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index a8eb9a29f..36ee10786 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -20,7 +20,7 @@ from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder from .falcon.model import FalconForCausalLM, FalconModel from .gemma.model import GemmaForCausalLM -from .gpt.model import GPTLMHeadModel, GPTModel +from .gpt.model import GPTForCausalLM, 
GPTModel from .gptj.model import GPTJForCausalLM, GPTJModel from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel from .llama.model import LLaMAForCausalLM, LLaMAModel @@ -43,7 +43,7 @@ 'FalconForCausalLM', 'FalconModel', 'GPTModel', - 'GPTLMHeadModel', + 'GPTForCausalLM', 'OPTForCausalLM', 'OPTModel', 'LLaMAForCausalLM', @@ -73,6 +73,7 @@ ] MODEL_MAP = { + 'GPTForCausalLM': GPTForCausalLM, 'OPTForCausalLM': OPTForCausalLM, 'BloomForCausalLM': BloomForCausalLM, 'FalconForCausalLM': FalconForCausalLM, @@ -90,4 +91,5 @@ 'BaichuanForCausalLM': BaichuanForCausalLM, 'SkyworkForCausalLM': LLaMAForCausalLM, 'GemmaForCausalLM': GemmaForCausalLM, + 'QWenForCausalLM': QWenForCausalLM, } diff --git a/tensorrt_llm/models/baichuan/model.py b/tensorrt_llm/models/baichuan/model.py index a0fb62b16..7904e9022 100644 --- a/tensorrt_llm/models/baichuan/model.py +++ b/tensorrt_llm/models/baichuan/model.py @@ -15,7 +15,7 @@ from ..._utils import pad_vocab_size from ...functional import Tensor from ...layers import (Attention, AttentionMaskType, ColumnLinear, Embedding, - GatedMLP, PromptTuningEmbedding, RmsNorm) + GatedMLP, RmsNorm) from ...module import Module from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, PretrainedConfig) @@ -104,10 +104,7 @@ def __init__(self, config: PretrainedConfig): super().__init__() hidden_size = config.hidden_size dtype = config.dtype - self.use_prompt_tuning = config.use_prompt_tuning - - EmbeddingCls = PromptTuningEmbedding if config.use_prompt_tuning else Embedding - self.vocab_embedding = EmbeddingCls( + self.vocab_embedding = Embedding( config.vocab_size, config.hidden_size, dtype=config.dtype, @@ -134,7 +131,7 @@ def forward(self, prompt_tasks=None, prompt_vocab_size=None): args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size - ] if self.use_prompt_tuning else [] + ] if prompt_embedding_table is not None else [] hidden_states = self.vocab_embedding(input_ids, *args) hidden_states = self.layers(hidden_states, diff --git a/tensorrt_llm/models/bert/model.py b/tensorrt_llm/models/bert/model.py index e34b9f7d4..122a32247 100644 --- a/tensorrt_llm/models/bert/model.py +++ b/tensorrt_llm/models/bert/model.py @@ -17,9 +17,9 @@ import numpy as np from ..._common import default_net -from ...functional import (ACT2FN, bert_attention, concat, constant, expand, - expand_mask, matmul, select, shape, slice, softmax, - split, unsqueeze) +from ...functional import (ACT2FN, bert_attention, cast, concat, constant, + expand, expand_mask, matmul, select, shape, slice, + softmax, split, unsqueeze) from ...layers import MLP, ColumnLinear, Embedding, LayerNorm, Linear, RowLinear from ...mapping import Mapping from ...module import Module, ModuleList @@ -111,6 +111,7 @@ def transpose_for_scores(x): attention_scores = attention_scores / self.norm_factor if attention_mask is not None: + attention_mask = cast(attention_mask, attention_scores.dtype) attention_scores = attention_scores + attention_mask attention_probs = softmax(attention_scores, dim=-1) diff --git a/tensorrt_llm/models/convert_utils.py b/tensorrt_llm/models/convert_utils.py new file mode 100644 index 000000000..2339c9f35 --- /dev/null +++ b/tensorrt_llm/models/convert_utils.py @@ -0,0 +1,34 @@ +import torch + + +def split(v, tp_size, idx, dim=0): + if tp_size == 1: + return v + if len(v.shape) == 1: + return torch.chunk(v, tp_size)[idx].contiguous() + else: + return torch.chunk(v, tp_size, dim=dim)[idx].clone() + + +def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank): + """ + Splits the 
QKV matrix according to tensor parallelism + """ + v = v.reshape(3, n_hidden, n_hidden) + split_v = split(v, tensor_parallel, rank, dim=1) + split_v = split_v.reshape(3 * (n_hidden // tensor_parallel), n_hidden) + return split_v.clone() + + +def split_qkv_bias_tp(v, n_head, n_hidden, tensor_parallel, rank): + """ + Splits the QKV bias according to tensor parallelism + """ + v = v.reshape(3, n_hidden) + split_v = split(v, tensor_parallel, rank, dim=1) + split_v = split_v.reshape(3 * (n_hidden // tensor_parallel)) + return split_v.clone() + + +def split_matrix_tp(v, tensor_parallel, rank, dim): + return split(v, tensor_parallel, rank, dim=dim) diff --git a/tensorrt_llm/models/enc_dec/model.py b/tensorrt_llm/models/enc_dec/model.py index 3ea0e7699..95fc041ec 100644 --- a/tensorrt_llm/models/enc_dec/model.py +++ b/tensorrt_llm/models/enc_dec/model.py @@ -32,6 +32,7 @@ LoraParams, PromptTuningEmbedding, RmsNorm) from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.generation_mixin import GenerationMixin +from tensorrt_llm.models.modeling_utils import optimize_model from tensorrt_llm.module import Module, ModuleList from tensorrt_llm.parameter import Parameter from tensorrt_llm.plugin.plugin import current_all_reduce_helper @@ -175,8 +176,7 @@ def __init__(self, relative_attention=False, max_distance=0, num_buckets=0, - fp16_clamping=False, - max_lora_rank=None): + fp16_clamping=False): super().__init__() # e.g. BART regular, T5 RMS @@ -201,8 +201,7 @@ def __init__(self, dtype=dtype, relative_attention=relative_attention, max_distance=max_distance, - num_buckets=num_buckets, - max_lora_rank=max_lora_rank) + num_buckets=num_buckets) self.attention_layernorm = ln_type(normalized_shape=hidden_size, eps=layernorm_eps, @@ -219,7 +218,6 @@ def __init__(self, tp_group=mapping.tp_group, tp_size=mapping.tp_size, dtype=dtype, - max_lora_rank=max_lora_rank, ) self.mlp_layernorm = ln_type(normalized_shape=hidden_size, @@ -313,7 +311,6 @@ def __init__(self, max_distance=0, num_buckets=0, fp16_clamping=False, - max_lora_rank=None, skip_cross_qkv=False): super().__init__() @@ -344,9 +341,7 @@ def __init__(self, max_distance=max_distance, num_buckets=num_buckets, position_embedding_type=PositionEmbeddingType.relative - if relative_attention else PositionEmbeddingType.learned_absolute, - max_lora_rank=max_lora_rank, - skip_cross_qkv=skip_cross_qkv) + if relative_attention else PositionEmbeddingType.learned_absolute) self.self_attention_layernorm = ln_type(normalized_shape=hidden_size, eps=layernorm_eps, @@ -379,7 +374,7 @@ def __init__(self, max_distance=max_distance, num_buckets=num_buckets, position_embedding_type=PositionEmbeddingType.learned_absolute, - max_lora_rank=max_lora_rank) + skip_cross_qkv=skip_cross_qkv) self.cross_attention_layernorm = ln_type(normalized_shape=hidden_size, eps=layernorm_eps, @@ -396,7 +391,6 @@ def __init__(self, tp_group=mapping.tp_group, tp_size=mapping.tp_size, dtype=dtype, - max_lora_rank=max_lora_rank, ) self.mlp_layernorm = ln_type(normalized_shape=hidden_size, @@ -615,8 +609,7 @@ def __init__(self, relative_attention=relative_attention, max_distance=max_distance, num_buckets=num_buckets, - fp16_clamping=fp16_clamping, - max_lora_rank=max_lora_rank) + fp16_clamping=fp16_clamping) for _ in self.mapping.pp_layers(self.total_num_layers) ]) @@ -626,6 +619,9 @@ def __init__(self, eps=layernorm_eps, dtype=dtype) + if max_lora_rank is not None: + optimize_model(self, use_lora=True, max_lora_rank=max_lora_rank) + def forward(self, input_ids: Tensor, input_lengths=None, 
@@ -777,8 +773,7 @@ def prepare_inputs(self, ) if use_custom_all_reduce and self.mapping.tp_size > 1: - current_all_reduce_helper().set_workspace_tensor( - self.mapping, False) + current_all_reduce_helper().set_workspace_tensor(self.mapping, 1) input_lengths = Tensor( name="input_lengths", @@ -1044,7 +1039,6 @@ def __init__(self, max_distance=max_distance, num_buckets=num_buckets, fp16_clamping=fp16_clamping, - max_lora_rank=max_lora_rank, skip_cross_qkv=skip_cross_qkv, ) for layer_idx in layers_range ]) @@ -1065,6 +1059,9 @@ def __init__(self, gather_output=True, ) + if max_lora_rank is not None: + optimize_model(self, use_lora=True, max_lora_rank=max_lora_rank) + def forward(self, decoder_input_ids: Tensor, encoder_output: Tensor, @@ -1441,8 +1438,7 @@ def prepare_inputs( ) if use_custom_all_reduce and self.mapping.tp_size > 1: - current_all_reduce_helper().set_workspace_tensor( - self.mapping, False) + current_all_reduce_helper().set_workspace_tensor(self.mapping, 1) layers_range = self.mapping.pp_layers(self.total_num_layers) num_pp_layers = len(layers_range) @@ -1607,7 +1603,9 @@ def prepare_inputs( ('boolean', [1]), ])) cross_qkv_reuse = None - cross_qkv_out_dim = self.num_heads * self.head_size + 2 * self.num_kv_heads * self.head_size + num_heads = (self.num_heads + self.mapping.tp_size - + 1) // self.mapping.tp_size + cross_qkv_out_dim = num_heads * self.head_size + 2 * num_kv_heads * self.head_size if self.skip_cross_qkv: if remove_input_padding: cross_qkv_reuse = Tensor( diff --git a/tensorrt_llm/models/gemma/model.py b/tensorrt_llm/models/gemma/model.py index 66483391f..54dea07be 100644 --- a/tensorrt_llm/models/gemma/model.py +++ b/tensorrt_llm/models/gemma/model.py @@ -17,8 +17,7 @@ from ..._utils import pad_vocab_size from ...functional import Tensor, recv, send from ...layers import (Attention, AttentionMaskType, ColumnLinear, Embedding, - GatedMLP, PositionEmbeddingType, PromptTuningEmbedding, - RmsNorm) + GatedMLP, PositionEmbeddingType, RmsNorm) from ...mapping import Mapping from ...module import Module from ...quantization import QuantMode @@ -118,10 +117,8 @@ def __init__(self, config) -> None: super().__init__() self.mapping = config.mapping - self.use_prompt_tuning = config.use_prompt_tuning - EmbeddingCls = PromptTuningEmbedding if config.use_prompt_tuning else Embedding if self.mapping.is_first_pp_rank(): - self.vocab_embedding = EmbeddingCls( + self.vocab_embedding = Embedding( num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, dtype=config.dtype, @@ -160,7 +157,7 @@ def forward(self, ptuning_args = [ prompt_embedding_table, prompt_tasks, prompt_vocab_size - ] if self.use_prompt_tuning else [] + ] if prompt_embedding_table is not None else [] if self.mapping.is_first_pp_rank(): hidden_states = self.vocab_embedding(input_ids, *ptuning_args) @@ -265,7 +262,6 @@ def from_hugging_face(cls, 'use_parallel_embedding': kwargs.get("use_parallel_embedding", False), 'embedding_sharding_dim': kwargs.get("embedding_sharding_dim", 0), - 'use_prompt_tuning': kwargs.get("use_prompt_tuning", False), 'use_fused_mlp': kwargs.get("use_fused_mlp", False), } diff --git a/tensorrt_llm/models/generation_mixin.py b/tensorrt_llm/models/generation_mixin.py index be9835938..908cbdd67 100644 --- a/tensorrt_llm/models/generation_mixin.py +++ b/tensorrt_llm/models/generation_mixin.py @@ -26,10 +26,10 @@ class GenerationMixin: @staticmethod - def has_two_optimization_profiles(use_gpt_attention_plugin: bool, - use_gemm_plugin: bool, - remove_input_padding: bool, - paged_kv_cache: 
bool) -> bool: + def has_ctx_gen_opt_profiles(use_gpt_attention_plugin: bool, + use_gemm_plugin: bool, + remove_input_padding: bool, + paged_kv_cache: bool) -> bool: res = False if not use_gpt_attention_plugin or not use_gemm_plugin: use_in_flight_batching = use_gpt_attention_plugin and remove_input_padding and paged_kv_cache @@ -41,6 +41,26 @@ def default_range(max_range, offset=0): result = [1, (max_range + 1) // 2, max_range] return [elem + offset for elem in result] + @staticmethod + def split_num_tokens_range(max_num_tokens): + split_point = [64, 128, 256, 512, 1024] + num_tokens_ranges = [] + for i, p in enumerate(split_point): + if i == 0 and max_num_tokens <= p: + return [0, max_num_tokens, max_num_tokens] + elif max_num_tokens <= p: + num_tokens_ranges.append( + [split_point[i - 1], max_num_tokens, max_num_tokens]) + return num_tokens_ranges + elif i == 0 and max_num_tokens > p: + num_tokens_ranges = [[0, 64, 64]] + else: + num_tokens_ranges.append( + [split_point[i - 1], split_point[i], split_point[i]]) + num_tokens_ranges.append( + [split_point[-1], max_num_tokens, max_num_tokens]) + return num_tokens_ranges + def prepare_attention_inputs(self, *, max_batch_size, @@ -51,9 +71,10 @@ def prepare_attention_inputs(self, head_size, num_layers, kv_dtype, + num_profiles=1, + enable_ctx_gen_opt_profiles=False, remove_input_padding=False, use_gpt_attention_plugin=False, - use_gemm_plugin=False, paged_kv_cache=False, tokens_per_block=64, mapping=Mapping(), @@ -70,26 +91,21 @@ def prepare_attention_inputs(self, _kv_cache_range_gen = default_range(max_seq_len, -1) _kv_cache_range = default_range(max_seq_len) - enable_two_optimization_profiles = GenerationMixin.has_two_optimization_profiles( - use_gpt_attention_plugin, use_gemm_plugin, remove_input_padding, - paged_kv_cache) - if enable_two_optimization_profiles: + if enable_ctx_gen_opt_profiles: + assert num_profiles == 2 bb_range = [bb_range_cxt, bb_range_gen] - bs_range = [_bs_range, _bs_range] - beam_width_range = [_beam_width_range, _beam_width_range] - max_len_range = [_max_len_range, _max_len_range] mask_len_range = [_mask_len_ctx, _max_len_range] if use_gpt_attention_plugin: kv_cache_range = [_kv_cache_range, _kv_cache_range] else: kv_cache_range = [_kv_cache_range_ctx, _kv_cache_range_gen] else: - bb_range = [bb_range_gen] - bs_range = [_bs_range] - beam_width_range = [_beam_width_range] - max_len_range = [_max_len_range] - mask_len_range = [_max_len_range] - kv_cache_range = [_kv_cache_range] + bb_range = [bb_range_gen] * num_profiles + mask_len_range = [_max_len_range] * num_profiles + kv_cache_range = [_kv_cache_range] * num_profiles + bs_range = [_bs_range] * num_profiles + beam_width_range = [_beam_width_range] * num_profiles + max_len_range = [_max_len_range] * num_profiles num_kv_heads = (num_kv_heads + mapping.tp_size - 1) // mapping.tp_size layers_range = mapping.pp_layers(num_layers) @@ -102,13 +118,10 @@ def prepare_attention_inputs(self, for i in layers_range: kv_dim_range = OrderedDict([ ('batch_size_beam_width', bb_range), - ('kv', - [2, 2] if enable_two_optimization_profiles else [2]), - ('num_heads', [num_kv_heads, num_kv_heads] if - enable_two_optimization_profiles else [num_kv_heads]), + ('kv', [2] * num_profiles), + ('num_heads', [num_kv_heads] * num_profiles), ('past_key_len', kv_cache_range), - ('head_size', [head_size, head_size] - if enable_two_optimization_profiles else [head_size]), + ('head_size', [head_size] * num_profiles), ]) kv = Tensor(name=f'past_key_value_{i}', dtype=kv_dtype, @@ -116,7 +129,7 @@ def 
prepare_attention_inputs(self, dim_range=kv_dim_range) past_key_value.append(kv) else: - if enable_two_optimization_profiles: + if enable_ctx_gen_opt_profiles: max_blocks_per_seq_range = [ [ math.ceil(kv_cache_range[0][0] / tokens_per_block), @@ -129,29 +142,21 @@ def prepare_attention_inputs(self, math.ceil(kv_cache_range[1][2] / tokens_per_block) ] ] - max_blocks_per_seq_range = [[ - x for x in max_blocks_per_seq_range[0] - ], [x for x in max_blocks_per_seq_range[1]]] else: max_blocks_per_seq_range = [[ math.ceil(kv_cache_range[0][0] / tokens_per_block), math.ceil(kv_cache_range[0][1] / tokens_per_block), math.ceil(kv_cache_range[0][2] / tokens_per_block) - ]] - max_blocks_per_seq_range = [[ - x for x in max_blocks_per_seq_range[0] - ]] + ]] * num_profiles kv_cache_block_pointers = Tensor( name=f'kv_cache_block_pointers', dtype=trt.int64, shape=[num_pp_layers, -1, 2, -1], dim_range=OrderedDict([ - ('num_layers', [num_pp_layers, num_pp_layers] if - enable_two_optimization_profiles else [num_pp_layers]), + ('num_layers', [num_pp_layers] * num_profiles), ('batch_size_beam_width', bb_range), - ('kv', - [2, 2] if enable_two_optimization_profiles else [2]), + ('kv', [2] * num_profiles), ('max_blocks_per_seq', max_blocks_per_seq_range), ])) host_kv_cache_block_pointers = Tensor( @@ -159,11 +164,9 @@ def prepare_attention_inputs(self, dtype=trt.int64, shape=[num_pp_layers, -1, 2, -1], dim_range=OrderedDict([ - ('num_layers', [num_pp_layers, num_pp_layers] if - enable_two_optimization_profiles else [num_pp_layers]), + ('num_layers', [num_pp_layers] * num_profiles), ('batch_size_beam_width', bb_range), - ('kv', - [2, 2] if enable_two_optimization_profiles else [2]), + ('kv', [2] * num_profiles), ('max_blocks_per_seq', max_blocks_per_seq_range), ])) @@ -234,19 +237,15 @@ def prepare_attention_inputs(self, name=f'host_max_attention_window_sizes', dtype=trt.int32, shape=[num_pp_layers], - dim_range=OrderedDict([ - ('num_layers', [num_pp_layers, num_pp_layers] - if enable_two_optimization_profiles else [num_pp_layers]) - ])) + dim_range=OrderedDict([('num_layers', + [num_pp_layers] * num_profiles)])) - host_sink_token_length = Tensor( - name='host_sink_token_length', - dtype=trt.int32, - shape=[1], - dim_range=OrderedDict([ - ('scalar', - [1, 1] if enable_two_optimization_profiles else [1]) - ])) + host_sink_token_length = Tensor(name='host_sink_token_length', + dtype=trt.int32, + shape=[1], + dim_range=OrderedDict([ + ('scalar', [1] * num_profiles) + ])) if use_cache: cache_indirection = Tensor( @@ -301,7 +300,8 @@ def prepare_basic_inputs(self, position_encoding_2d=False, use_lora_plugin: bool = False, lora_target_modules: List[str] = None, - max_draft_len=0): + max_draft_len=0, + multiple_profiles: bool = False): default_range = GenerationMixin.default_range last_token_range = [1, max_draft_len + 1, max_draft_len + 1] @@ -318,14 +318,15 @@ def prepare_basic_inputs(self, inlen_range_cxt = default_range(max_input_len) inlen_range_gen = [1, 1, max_draft_len + 1] - enable_two_optimization_profiles = GenerationMixin.has_two_optimization_profiles( + enable_ctx_gen_opt_profiles = GenerationMixin.has_ctx_gen_opt_profiles( use_gpt_attention_plugin, use_gemm_plugin, remove_input_padding, paged_kv_cache) if max_num_tokens is None: max_num_tokens = max( max_input_len * max_batch_size, max_beam_width * (max_draft_len + 1) * max_batch_size) - if enable_two_optimization_profiles: + if enable_ctx_gen_opt_profiles: + num_profiles = 2 bb_range = [bb_range_cxt, bb_range_gen] bbd_range = [bbd_range_ctx, bbd_range_gen] 
inlen_range = [inlen_range_cxt, inlen_range_gen] @@ -334,16 +335,20 @@ def prepare_basic_inputs(self, num_tokens_range_gen = default_range( max_batch_size * (max_draft_len + 1) * max_beam_width) num_tokens_range = [num_tokens_range_ctx, num_tokens_range_gen] - last_token_range = [last_token_range, last_token_range] else: - bb_range = [bb_range_gen] - bbd_range = [bbd_range_gen] - last_token_range = [last_token_range] - inlen_range = [[1, 1, max_input_len]] - position_ids_inlen_range = [[1, 1, max_input_len]] - num_tokens_range = [[ - 1, max_batch_size * max_beam_width, max_num_tokens - ]] + if multiple_profiles: + num_tokens_range = GenerationMixin.split_num_tokens_range( + max_num_tokens) + else: + num_tokens_range = [[ + 1, max_batch_size * max_beam_width, max_num_tokens + ]] + num_profiles = len(num_tokens_range) + bb_range = [bb_range_gen] * num_profiles + bbd_range = [bbd_range_gen] * num_profiles + inlen_range = [[1, 1, max_input_len]] * num_profiles + position_ids_inlen_range = [[1, 1, max_input_len]] * num_profiles + last_token_range = [last_token_range] * num_profiles position_ids_num_tokens_range = num_tokens_range input_ids = None @@ -363,8 +368,7 @@ def prepare_basic_inputs(self, dtype=trt.int32, shape=[2, -1], dim_range=OrderedDict([ - ('2', [2, 2] - if enable_two_optimization_profiles else [2]), + ('2', [2] * num_profiles), ('position_ids_num_tokens_range', position_ids_num_tokens_range), ]), @@ -388,10 +392,7 @@ def prepare_basic_inputs(self, shape=[-1, head_size * num_heads], dim_range=OrderedDict([ ('num_tokens', num_tokens_range), - ('hidden_size', - [head_size * num_heads, head_size * - num_heads] if enable_two_optimization_profiles else - [head_size * num_heads]), + ('hidden_size', [head_size * num_heads] * num_profiles), ])) else: @@ -410,8 +411,7 @@ def prepare_basic_inputs(self, shape=[-1, 2, -1], dim_range=OrderedDict([ ('batch_size_beam_width', bb_range), - ('2', [2, 2] - if enable_two_optimization_profiles else [2]), + ('2', [2] * num_profiles), ('position_ids_inlen_range', position_ids_inlen_range), ]), @@ -437,15 +437,12 @@ def prepare_basic_inputs(self, dim_range=OrderedDict([ ('batch_size_beam_width', bb_range), ('input_len', inlen_range), - ('hidden_size', - [head_size * num_heads, head_size * - num_heads] if enable_two_optimization_profiles else - [head_size * num_heads]), + ('hidden_size', [head_size * num_heads] * num_profiles), ])) if use_custom_all_reduce and mapping.tp_size > 1: current_all_reduce_helper().set_workspace_tensor( - mapping, enable_two_optimization_profiles) + mapping, num_profiles) prompt_embedding_table = None tasks = None @@ -456,20 +453,17 @@ def prepare_basic_inputs(self, _p_embedding_range = [ 1, prompt_embedding_table_size // 2, prompt_embedding_table_size ] - if enable_two_optimization_profiles: - p_embedding_range = [_p_embedding_range, _p_embedding_range] - else: - p_embedding_range = [_p_embedding_range] - - prompt_embedding_table = Tensor( - name='prompt_embedding_table', - dtype=dtype, - shape=[-1, hidden_size], - dim_range=OrderedDict([ - ('prompt_embedding_table_size', p_embedding_range), - ('hidden_size', [hidden_size, hidden_size] - if enable_two_optimization_profiles else [hidden_size]), - ])) + p_embedding_range = [_p_embedding_range] * num_profiles + + prompt_embedding_table = Tensor(name='prompt_embedding_table', + dtype=dtype, + shape=[-1, hidden_size], + dim_range=OrderedDict([ + ('prompt_embedding_table_size', + p_embedding_range), + ('hidden_size', + [hidden_size] * num_profiles), + ])) if remove_input_padding: 
tasks = Tensor(name='tasks', dtype=trt.int32, @@ -478,23 +472,19 @@ def prepare_basic_inputs(self, ('input_len_task', num_tokens_range), ])) else: - tasks = Tensor( - name='tasks', - dtype=trt.int32, - shape=[-1, 1], - dim_range=OrderedDict([ - ('batch_size_beam_width', bb_range), - ('broadcast_dim', - [1, 1] if enable_two_optimization_profiles else [1]), - ])) - prompt_vocab_size = Tensor( - name='prompt_vocab_size', - dtype=trt.int32, - shape=[1], - dim_range=OrderedDict([ - ('size', - [1, 1] if enable_two_optimization_profiles else [1]) - ])) + tasks = Tensor(name='tasks', + dtype=trt.int32, + shape=[-1, 1], + dim_range=OrderedDict([ + ('batch_size_beam_width', bb_range), + ('broadcast_dim', [1] * num_profiles), + ])) + prompt_vocab_size = Tensor(name='prompt_vocab_size', + dtype=trt.int32, + shape=[1], + dim_range=OrderedDict([ + ('size', [1] * num_profiles) + ])) lora_weights_pointers = None lora_ranks = None @@ -513,8 +503,7 @@ def prepare_basic_inputs(self, shape=[-1, 2], dim_range=OrderedDict([ ('batch_size_beam_width', bb_range), - ('in_out', [2, 2] - if enable_two_optimization_profiles else [2]), + ('in_out', [2] * num_profiles), ])) lora_weight_pointer_dict.update({ f"{lora_module}_lora_weights_pointers": @@ -577,9 +566,10 @@ def prepare_basic_inputs(self, head_size=head_size, num_layers=num_layers, kv_dtype=kv_dtype, + num_profiles=num_profiles, + enable_ctx_gen_opt_profiles=enable_ctx_gen_opt_profiles, remove_input_padding=remove_input_padding, use_gpt_attention_plugin=use_gpt_attention_plugin, - use_gemm_plugin=use_gemm_plugin, paged_kv_cache=paged_kv_cache, tokens_per_block=tokens_per_block, mapping=mapping) diff --git a/tensorrt_llm/models/gpt/model.py b/tensorrt_llm/models/gpt/model.py index e8fdd5559..f88dbdedf 100644 --- a/tensorrt_llm/models/gpt/model.py +++ b/tensorrt_llm/models/gpt/model.py @@ -13,22 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List - -import tensorrt as trt - -from ..._common import default_net -from ..._utils import pad_vocab_size, str_dtype_to_trt -from ...functional import (Tensor, gather_last_token_logits, - is_gated_activation, non_gated_version) -from ...layers import (MLP, MOE, Attention, AttentionMaskType, AttentionParams, - ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams, - LayerNorm, LoraParams, MoeConfig, PositionEmbeddingType, - PromptTuningEmbedding) -from ...mapping import Mapping -from ...module import Module, ModuleList +from ..._utils import pad_vocab_size +from ...functional import Tensor, is_gated_activation, non_gated_version +from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear, + Embedding, GatedMLP, LayerNorm, MoeConfig, + PositionEmbeddingType) +from ...lora_manager import LoraBuildConfig, use_lora +from ...module import Module from ...quantization import QuantMode -from ..generation_mixin import GenerationMixin +from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, + PretrainedConfig) def MLPFactory(hidden_size, @@ -40,8 +34,7 @@ def MLPFactory(hidden_size, tp_group=None, tp_size=1, tp_rank=0, - quant_mode=QuantMode(0), - max_lora_rank=None): + quant_mode=QuantMode(0)): if moe_config.has_moe(): return MOE(moe_config, hidden_size, @@ -52,99 +45,82 @@ def MLPFactory(hidden_size, tp_group, tp_size, tp_rank, - quant_mode=quant_mode, - max_lora_rank=max_lora_rank) + quant_mode=quant_mode) MLPClass = GatedMLP if is_gated_activation(hidden_act) else MLP hidden_act = non_gated_version(hidden_act) - return MLPClass(hidden_size, - ffn_hidden_size, - hidden_act, - bias, - dtype, - tp_group, - tp_size, - quant_mode, - max_lora_rank=max_lora_rank) + return MLPClass( + hidden_size, + ffn_hidden_size, + hidden_act, + bias, + dtype, + tp_group, + tp_size, + quant_mode, + ) class GPTDecoderLayer(Module): - def __init__(self, - *, - local_layer_idx, - hidden_size, - num_attention_heads, - max_position_embeddings, - num_layers, - dtype=None, - apply_query_key_layer_scaling=False, - attention_mask_type=AttentionMaskType.causal, - hidden_act='relu', - position_embedding_type=PositionEmbeddingType.learned_absolute, - quant_mode=QuantMode(0), - rotary_embedding_percentage=1.0, - rotary_base=10000.0, - rotary_scaling=None, - inter_size=None, - bias=True, - num_kv_heads=None, - moe_config: MoeConfig = MoeConfig(), - tp_group=None, - tp_size=1, - tp_rank=0, - max_lora_rank=None): + def __init__(self, config: PretrainedConfig, layer_idx: int): super().__init__() - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.max_position_embeddings = max_position_embeddings - self.num_layers = num_layers - self.dtype = dtype - self.apply_query_key_layer_scaling = apply_query_key_layer_scaling - self.attention_mask_type = attention_mask_type - self.hidden_act = hidden_act - self.position_embedding_type = position_embedding_type - self.tp_group = tp_group - self.tp_size = tp_size - self.input_layernorm = LayerNorm(normalized_shape=hidden_size, - dtype=dtype) + self.layer_idx = layer_idx + self.config = config + tp_group = config.mapping.tp_group + tp_size = config.mapping.tp_size + tp_rank = config.mapping.tp_rank + + self.input_layernorm = LayerNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) + + layers_range = config.mapping.pp_layers(config.num_hidden_layers) + local_layer_idx = layer_idx - layers_range[0] self.attention = Attention( local_layer_idx=local_layer_idx, - 
hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_heads=num_kv_heads, - max_position_embeddings=max_position_embeddings, - num_layers=num_layers, - apply_query_key_layer_scaling=apply_query_key_layer_scaling, - dtype=dtype, - attention_mask_type=attention_mask_type, - position_embedding_type=position_embedding_type, - rotary_embedding_percentage=rotary_embedding_percentage, - rotary_embedding_base=rotary_base, - rotary_embedding_scaling=rotary_scaling, - bias=bias, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + max_position_embeddings=config.max_position_embeddings, + num_layers=config.num_hidden_layers, + apply_query_key_layer_scaling=config.apply_query_key_layer_scaling, + dtype=config.dtype, + attention_mask_type=AttentionMaskType.causal, + position_embedding_type=config.position_embedding_type, + rotary_embedding_percentage=config.rotary_pct, + rotary_embedding_base=config.rotary_base, + rotary_embedding_scaling=config.rotary_scaling, + bias=config.bias, tp_group=tp_group, tp_size=tp_size, tp_rank=tp_rank, - quant_mode=quant_mode, - max_lora_rank=max_lora_rank) - - if inter_size is None: - inter_size = hidden_size * 4 - - self.mlp = MLPFactory(hidden_size=hidden_size, - ffn_hidden_size=inter_size, - hidden_act=hidden_act, - dtype=dtype, - bias=bias, + quant_mode=config.quant_mode) + + mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size + + moe_config = MoeConfig() + if config.moe_num_experts > 1: + moe_config = MoeConfig( + config.moe_num_experts, + config.moe_top_k, + config.moe_tp_mode, + config.moe_normalization_mode, + ) + self.mlp = MLPFactory(hidden_size=config.hidden_size, + ffn_hidden_size=mlp_hidden_size, + hidden_act=config.hidden_act, + dtype=config.dtype, + bias=config.bias, moe_config=moe_config, tp_group=tp_group, tp_size=tp_size, tp_rank=tp_rank, - quant_mode=quant_mode, - max_lora_rank=max_lora_rank) - self.post_layernorm = LayerNorm(normalized_shape=hidden_size, - dtype=dtype) + quant_mode=config.quant_mode) + + self.post_layernorm = LayerNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) def forward(self, hidden_states: Tensor, @@ -186,77 +162,33 @@ def forward(self, class GPTModel(Module): - def __init__(self, - num_layers, - num_heads, - hidden_size, - vocab_size, - hidden_act, - max_position_embeddings, - dtype=None, - mapping=Mapping(), - apply_query_key_layer_scaling=False, - position_embedding_type=PositionEmbeddingType.learned_absolute, - rotary_embedding_percentage=1.0, - rotary_base=10000.0, - rotary_scaling=None, - inter_size=None, - bias=True, - quant_mode=QuantMode(0), - num_kv_heads=None, - use_prompt_tuning=False, - use_parallel_embedding=False, - embedding_sharding_dim=0, - moe_config=MoeConfig(), - max_lora_rank=None): + def __init__(self, config: PretrainedConfig): super().__init__() - self.mapping = mapping - self.use_prompt_tuning = use_prompt_tuning - self.position_embedding_type = position_embedding_type - - EmbeddingCls = PromptTuningEmbedding if use_prompt_tuning else Embedding - self.vocab_embedding = EmbeddingCls( - vocab_size, - hidden_size, - dtype=dtype, - tp_size=mapping.tp_size if use_parallel_embedding else 1, - tp_group=mapping.tp_group if use_parallel_embedding else None, - sharding_dim=embedding_sharding_dim, - tp_rank=mapping.tp_rank) - if position_embedding_type == PositionEmbeddingType.learned_absolute: - self.position_embedding = 
Embedding(max_position_embeddings, - hidden_size, - dtype=dtype) - - layers_range = self.mapping.pp_layers(num_layers) - self.layers = ModuleList([ - GPTDecoderLayer( - local_layer_idx=layer_idx - layers_range[0], - hidden_size=hidden_size, - num_attention_heads=num_heads, - max_position_embeddings=max_position_embeddings, - num_layers=num_layers, - dtype=dtype, - apply_query_key_layer_scaling=apply_query_key_layer_scaling, - attention_mask_type=AttentionMaskType.causal, - hidden_act=hidden_act, - position_embedding_type=position_embedding_type, - rotary_embedding_percentage=rotary_embedding_percentage, - rotary_base=rotary_base, - rotary_scaling=rotary_scaling, - num_kv_heads=num_kv_heads, - tp_group=mapping.tp_group, - tp_size=mapping.tp_size, - tp_rank=mapping.tp_rank, - inter_size=inter_size, - bias=bias, - quant_mode=quant_mode, - moe_config=moe_config, - max_lora_rank=max_lora_rank, - ) for layer_idx in layers_range - ]) - - self.ln_f = LayerNorm(normalized_shape=hidden_size, dtype=dtype) + self.mapping = config.mapping + self.position_embedding_type = config.position_embedding_type + + tp_group = config.mapping.tp_group + tp_size = config.mapping.tp_size + tp_rank = config.mapping.tp_rank + + self.vocab_embedding = Embedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + dtype=config.dtype, + tp_size=tp_size if config.use_parallel_embedding else 1, + tp_group=tp_group if config.use_parallel_embedding else None, + sharding_dim=config.embedding_sharding_dim, + tp_rank=tp_rank) + if config.position_embedding_type == PositionEmbeddingType.learned_absolute: + self.position_embedding = Embedding( + num_embeddings=config.max_position_embeddings, + embedding_dim=config.hidden_size, + dtype=config.dtype) + self.layers = DecoderLayerList(GPTDecoderLayer, config) + + self.ln_f = LayerNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) def forward(self, input_ids, @@ -270,48 +202,27 @@ def forward(self, prompt_vocab_size=None, lora_params=None): - args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size - ] if self.use_prompt_tuning else [] - hidden_states = self.vocab_embedding(input_ids, *args) - if self.position_embedding_type == PositionEmbeddingType.learned_absolute: - hidden_states = hidden_states + self.position_embedding( - position_ids) - kv_cache_params.fill_none_tensor_list(len(self.layers)) if use_cache: presents = [] - for layer_idx, (layer, past) in enumerate( - zip(self.layers, kv_cache_params.past_key_value)): - - lora_layer_params = None - if lora_params.lora_ranks is not None: - lora_layer_params = lora_params.get_layer_params(layer_idx) - - hidden_states = layer( - hidden_states, - use_cache=use_cache, - attention_mask=attention_mask, - kv_cache_params=KeyValueCacheParams( - past_key_value=[past], - host_past_key_value_lengths=kv_cache_params. - host_past_key_value_lengths, - host_max_attention_window_sizes=kv_cache_params. - host_max_attention_window_sizes, - host_sink_token_length=kv_cache_params. - host_sink_token_length, - kv_cache_block_pointers=kv_cache_params. - kv_cache_block_pointers, - host_kv_cache_block_pointers=kv_cache_params. 
- host_kv_cache_block_pointers, - cache_indirection=kv_cache_params.cache_indirection), - attention_params=attention_params, - lora_layer_params=lora_layer_params) - - if use_cache: - presents.append(hidden_states[1]) - hidden_states = hidden_states[0] + ptuning_args = [ + prompt_embedding_table, prompt_tasks, prompt_vocab_size + ] if prompt_embedding_table is not None else [] + hidden_states = self.vocab_embedding(input_ids, *ptuning_args) + if self.position_embedding_type == PositionEmbeddingType.learned_absolute: + hidden_states = hidden_states + self.position_embedding( + position_ids) + + hidden_states = self.layers(hidden_states, + use_cache=use_cache, + attention_mask=attention_mask, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + lora_params=lora_params) + if use_cache: + hidden_states, presents = hidden_states hidden_states = self.ln_f(hidden_states) @@ -320,232 +231,49 @@ def forward(self, return hidden_states -class GPTLMHeadModel(GPTModel, GenerationMixin): - - def __init__(self, - num_layers, - num_heads, - hidden_size, - vocab_size, - hidden_act, - max_position_embeddings, - dtype, - logits_dtype='float32', - mapping=Mapping(), - apply_query_key_layer_scaling=False, - position_embedding_type=PositionEmbeddingType.learned_absolute, - rotary_embedding_percentage=1.0, - rotary_base=10000.0, - rotary_scaling=None, - inter_size=None, - bias=True, - quant_mode=QuantMode(0), - num_kv_heads=None, - use_prompt_tuning=False, - use_parallel_embedding=False, - embedding_sharding_dim=0, - moe_config=MoeConfig(), - share_embedding_table=False, - max_lora_rank=None): - - if isinstance(dtype, str): - self._kv_dtype = str_dtype_to_trt(dtype) - else: - assert isinstance(dtype, trt.DataType) - self._kv_dtype = dtype - - if share_embedding_table and mapping.tp_size > 1: - if (not use_parallel_embedding) or (use_parallel_embedding and - embedding_sharding_dim == 1): +class GPTForCausalLM(DecoderModelForCausalLM): + + def __init__(self, config: PretrainedConfig): + self.check_config(config) + transformer = GPTModel(config) + vocab_size_padded = pad_vocab_size(config.vocab_size, + config.mapping.tp_size) + + if config.share_embedding_table and config.mapping.tp_size > 1: + if (not config.use_parallel_embedding) or ( + config.use_parallel_embedding + and config.embedding_sharding_dim == 1): raise NotImplementedError( 'For multiple-processes cases, sharing the embedding table must set use_parallel_embedding=True and embedding_sharding_dim = 0' ) - self._dtype = self._kv_dtype - self.quant_mode = quant_mode - if quant_mode.has_int8_kv_cache(): - self._kv_dtype = str_dtype_to_trt('int8') - elif quant_mode.has_fp8_kv_cache(): - self._kv_dtype = str_dtype_to_trt('fp8') - - if isinstance(logits_dtype, str): - self._logits_dtype = str_dtype_to_trt(logits_dtype) - else: - assert isinstance(logits_dtype, trt.DataType) - self._logits_dtype = logits_dtype - - self._num_layers = num_layers - self._num_heads = num_heads - self._hidden_size = hidden_size - self._vocab_size = vocab_size - self._tp_size = mapping.tp_size - self._num_kv_heads = num_kv_heads if num_kv_heads else num_heads - - super().__init__( - num_layers=num_layers, - num_heads=num_heads, - hidden_size=hidden_size, - vocab_size=vocab_size, - hidden_act=hidden_act, - max_position_embeddings=max_position_embeddings, - dtype=dtype, - mapping=mapping, - apply_query_key_layer_scaling=apply_query_key_layer_scaling, - position_embedding_type=position_embedding_type, - rotary_embedding_percentage=rotary_embedding_percentage, - 
rotary_base=rotary_base, - rotary_scaling=rotary_scaling, - inter_size=inter_size, - bias=bias, - quant_mode=quant_mode, - num_kv_heads=num_kv_heads, - use_prompt_tuning=use_prompt_tuning, - use_parallel_embedding=use_parallel_embedding, - embedding_sharding_dim=embedding_sharding_dim, - moe_config=moe_config, - max_lora_rank=max_lora_rank, - ) - vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) - share_weight = None - if share_embedding_table: - share_weight = self.vocab_embedding.weight - self.lm_head = ColumnLinear(hidden_size, - vocab_size_padded, - bias=False, - dtype=dtype, - tp_group=mapping.tp_group, - tp_size=mapping.tp_size, - gather_output=True, - share_weight=share_weight) - - def forward(self, - input_ids: Tensor, - position_ids=None, - use_cache=False, - last_token_ids=None, - attention_mask=None, - kv_cache_params=None, - attention_params=None, - prompt_embedding_table=None, - prompt_tasks=None, - prompt_vocab_size=None, - lora_params=None): - - hidden_states = super().forward(input_ids, position_ids, use_cache, - attention_mask, kv_cache_params, - attention_params, - prompt_embedding_table, prompt_tasks, - prompt_vocab_size, lora_params) - - if use_cache: - hidden_states, presents = hidden_states - - hidden_states = gather_last_token_logits( - hidden_states, last_token_ids, - default_net().plugin_config.remove_input_padding) - - # [batch_size, hidden_size] -> [batch_size, vocab_size] - lm_logits = self.lm_head(hidden_states) - lm_logits.mark_output('logits', self._logits_dtype) - - if use_cache: - if not default_net().plugin_config.paged_kv_cache: - for i, present in enumerate(presents): - present.mark_output(f'present_key_value_{i}', - self._kv_dtype) - return (lm_logits, presents) - - return lm_logits - - def prepare_inputs(self, - max_batch_size, - max_input_len, - max_seq_len, - use_cache, - max_beam_width: int = 1, - max_num_tokens: int = None, - prompt_embedding_table_size: int = 0, - gather_context_logits: bool = False, - gather_generation_logits: bool = False, - max_draft_len: int = 0, - lora_target_modules: List[str] = None): - '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the - ranges of the dimensions of when using TRT dynamic shapes. 
- - @return: a list contains values which can be fed into the self.forward() - ''' - - # Prepare inputs - head_size = self._hidden_size // self._num_heads - num_heads_kv = self._num_kv_heads - remove_input_padding = default_net().plugin_config.remove_input_padding - use_gpt_attention_plugin = default_net( - ).plugin_config.gpt_attention_plugin - use_gemm_plugin = default_net().plugin_config.gemm_plugin - paged_kv_cache = default_net().plugin_config.paged_kv_cache - tokens_per_block = default_net().plugin_config.tokens_per_block - use_custom_all_reduce = default_net( - ).plugin_config.use_custom_all_reduce - use_lora_plugin = default_net().plugin_config.lora_plugin - - model_inputs = self.prepare_basic_inputs( - max_batch_size=max_batch_size, - max_beam_width=max_beam_width, - max_input_len=max_input_len, - max_seq_len=max_seq_len, - num_kv_heads=num_heads_kv, - head_size=head_size, - num_layers=self._num_layers, - kv_dtype=self._kv_dtype, - num_heads=self._num_heads, - dtype=self._dtype, - remove_input_padding=remove_input_padding, - use_gpt_attention_plugin=use_gpt_attention_plugin, - use_gemm_plugin=use_gemm_plugin, - use_custom_all_reduce=use_custom_all_reduce, - paged_kv_cache=paged_kv_cache, - tokens_per_block=tokens_per_block, - gather_context_logits=gather_context_logits, - gather_generation_logits=gather_generation_logits, - mapping=self.mapping, - max_num_tokens=max_num_tokens, - prompt_embedding_table_size=prompt_embedding_table_size, - use_lora_plugin=use_lora_plugin, - max_draft_len=max_draft_len, - lora_target_modules=lora_target_modules) - - return ( - model_inputs['input_ids'], - model_inputs['position_ids'], - True, - model_inputs['last_token_ids'], - model_inputs['attention_mask'], - KeyValueCacheParams( - past_key_value=model_inputs['past_key_value'], - host_past_key_value_lengths=model_inputs[ - 'host_past_key_value_lengths'], - host_max_attention_window_sizes=model_inputs[ - 'host_max_attention_window_sizes'], - host_sink_token_length=model_inputs['host_sink_token_length'], - kv_cache_block_pointers=model_inputs['kv_cache_block_pointers'], - host_kv_cache_block_pointers=model_inputs[ - 'host_kv_cache_block_pointers'], - cache_indirection=model_inputs['cache_indirection'], - ), - AttentionParams( - sequence_length=model_inputs['sequence_length'], - context_lengths=model_inputs['context_lengths'], - host_context_lengths=model_inputs['host_context_lengths'], - max_context_length=max_input_len, - host_request_types=model_inputs['host_request_types']), - model_inputs['prompt_embedding_table'], - model_inputs['tasks'], - model_inputs['prompt_vocab_size'], - LoraParams( - model_inputs['lora_ranks'], - model_inputs['lora_weights_pointers'], - host_context_lengths=model_inputs['host_context_lengths'], - max_context_length=max_input_len, - host_request_types=model_inputs['host_request_types']), - ) + if config.share_embedding_table: + share_weight = transformer.vocab_embedding.weight + + lm_head = ColumnLinear(config.hidden_size, + vocab_size_padded, + bias=False, + dtype=config.dtype, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + gather_output=True, + share_weight=share_weight) + super().__init__(config, transformer, lm_head) + + def check_config(self, config: PretrainedConfig): + config.set_if_not_exist('bias', True) + config.set_if_not_exist('apply_query_key_layer_scaling', False) + config.set_if_not_exist('rotary_pct', 1.0) + config.set_if_not_exist('rotary_base', 10000.0) + config.set_if_not_exist('rotary_scaling', None) + 
config.set_if_not_exist('moe_num_experts', 0) + config.set_if_not_exist('moe_top_k', 0) + config.set_if_not_exist('moe_tp_mode', + MoeConfig.ParallelismMode.TENSOR_PARALLEL) + config.set_if_not_exist( + 'moe_normalization_mode', + MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE) + + def use_lora(self, lora_config: LoraBuildConfig): + use_lora(self, lora_config) diff --git a/tensorrt_llm/models/llama/convert.py b/tensorrt_llm/models/llama/convert.py index dca7baa7e..2ff012c15 100644 --- a/tensorrt_llm/models/llama/convert.py +++ b/tensorrt_llm/models/llama/convert.py @@ -5,6 +5,7 @@ import sys import time from collections import defaultdict +from pathlib import Path from typing import Optional import numpy as np @@ -17,8 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.pytorch_utils import Conv1D -from tensorrt_llm._utils import pad_vocab_size - +from ..._utils import pad_vocab_size, release_gc from ...layers import MoeConfig from ...lora_manager import LoraConfig from ...mapping import Mapping @@ -236,8 +236,10 @@ def smooth_llama_model(model, scales, alpha, llama_qkv_para, llama_smoother): # Smooth the activation and weights with smoother = $\diag{s}$ for name, module in model.named_modules(): if not isinstance( - module, LlamaDecoderLayer - ) and not module.__class__.__name__ == "InternLMDecoderLayer": + module, + LlamaDecoderLayer) and not module.__class__.__name__ in [ + "InternLMDecoderLayer", "MistralDecoderLayer" + ]: continue # qkv_proj layer_name_q = name + ".self_attn.q_proj" @@ -497,7 +499,8 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): if is_qkv: hidden_dim = cur_weights.shape[0] cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + 'weight'] = torch.from_numpy(cur_weights).t().clone() + results[prefix + 'weight'] = torch.from_numpy( + cur_weights).t().clone().contiguous() if smoother_value is None: results[last_prefix] = torch.from_numpy( np.array([1.0], dtype=np.float32)) @@ -524,7 +527,8 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): if is_qkv: hidden_dim = cur_weights.shape[0] cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + 'weight'] = torch.from_numpy(cur_weights).t().clone() + results[prefix + 'weight'] = torch.from_numpy( + cur_weights).t().clone().contiguous() cur_per_channel_value = vals["scale_y_accum_quant"] @@ -578,13 +582,14 @@ def convert_hf_llama(hf_model, dtype = getattr(torch, dtype) num_attention_heads = hf_model.config.num_attention_heads hidden_size = hf_model.config.hidden_size + head_size = hidden_size // num_attention_heads intermediate_size = hf_model.config.intermediate_size num_key_value_heads = getattr(hf_model.config, 'num_key_value_heads', num_attention_heads) mha_mode = (num_key_value_heads == num_attention_heads) layers_range = mapping.pp_layers(hf_model.config.num_hidden_layers) - for l in layers_range: + def convert_layer(l): prefix = f'model.layers.{l}.' tllm_prex = f'transformer.layers.{l - layers_range[0]}.' 
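# Illustrative sketch, not part of the patch: the prefix remapping above shifts the global HF
# layer index l into this pipeline rank's local index. A simplified stand-in for
# mapping.pp_layers(), assuming an even split of the layers across pp_size ranks:
def pp_layers_sketch(num_layers: int, pp_size: int, pp_rank: int) -> range:
    per_rank = num_layers // pp_size
    return range(pp_rank * per_rank, (pp_rank + 1) * per_rank)

# e.g. 32 layers on 2 pipeline ranks: rank 1 owns HF layers 16..31, so HF 'model.layers.16.'
# becomes local 'transformer.layers.0.' because layers_range[0] == 16.
assert list(pp_layers_sketch(32, 2, 1))[0] == 16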
q_weight = get_weight(model_params, prefix + 'self_attn.q_proj', dtype) @@ -592,7 +597,6 @@ def convert_hf_llama(hf_model, v_weight = get_weight(model_params, prefix + 'self_attn.v_proj', dtype) if not mha_mode: - head_size = hidden_size // num_attention_heads if num_key_value_heads < tensor_parallel: # duplicate the KV heads up to tensor_parallel k_weight = dup_kv_weight(k_weight, num_key_value_heads, @@ -630,11 +634,10 @@ def convert_hf_llama(hf_model, qkv_weight = qkv_para[prefix + 'self_attn.qkv_proj'] if not mha_mode: - hidden_size = qkv_weight.shape[0] - local_dim = hidden_size - head_size = (qkv_weight.shape[-1] - local_dim) // 2 - qkv_weight = qkv_weight.reshape(hidden_size, - local_dim + 2 * head_size) + local_dim = qkv_weight.shape[0] + kv_hidden_size = (qkv_weight.shape[-1] - local_dim) // 2 + qkv_weight = qkv_weight.reshape(local_dim, + local_dim + 2 * kv_hidden_size) else: qkv_weight = qkv_weight.reshape(hidden_size, 3, hidden_size) @@ -650,8 +653,7 @@ def convert_hf_llama(hf_model, tllm_prex + 'attention.qkv.', [ 1, 3 * hidden_size // tensor_parallel if mha_mode else hidden_size // tensor_parallel + - (hidden_size // num_key_value_heads) // - tensor_parallel * 2 + (head_size * num_key_value_heads) // tensor_parallel * 2 ], tensor_parallel, is_qkv=True, @@ -791,11 +793,6 @@ def convert_hf_llama(hf_model, moe_experts_gate_weights.to(torch.float32), tllm_prex + 'mlp.router.', None, use_weight_only, plugin_weight_only_quant_type, dtype, use_gemm_woq_plugin)) - del w1, w2, w3, moe_experts_w2_weights, moe_experts_w3w1_weights - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() else: mlp_gate_weight = get_weight(model_params, prefix + 'mlp.up_proj', dtype) @@ -910,9 +907,9 @@ def convert_hf_llama(hf_model, for weight_name in cur_block_weights: model_params[weight_name] = None - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + for l in layers_range: + convert_layer(l) + release_gc() v = get_weight(model_params, 'model.embed_tokens', dtype) if lora_config.is_valid and lora_config.embedding_weight is not None: @@ -1034,10 +1031,13 @@ def smooth_quant(model, return act_range, llama_qkv_para, llama_smoother -def create_config_from_hugging_face(hf_model, dtype, mapping, - override_fields: dict, **kwargs): +def create_config_from_hugging_face(hf_model, + dtype, + mapping, + quantization: 'QuantizationConfig' = None, + override_fields: dict = {}, + **kwargs): config = {} - assert isinstance(hf_model, str) hf_config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True) if hf_config.model_type == "llava": # LLaVA = Vision model + Llama LLM @@ -1095,7 +1095,7 @@ def create_config_from_hugging_face(hf_model, dtype, mapping, 'mapping': { 'world_size': mapping.tp_size * mapping.pp_size, 'tp_size': mapping.tp_size, - 'pp_size': mapping.pp_size, + 'pp_size': mapping.pp_size }, 'attn_bias': attn_bias, }) @@ -1121,7 +1121,7 @@ def create_config_from_hugging_face(hf_model, dtype, mapping, # the lora checkpoint might finetune the embedding if lora_config.vocab_size != 0: config['vocab_size'] = lora_config.vocab_size - + config['quantization'] = quantization.asdict() config.update(override_fields) moe_config = MoeConfig(config['moe_num_experts'], config['moe_top_k'], @@ -1139,15 +1139,20 @@ def from_hugging_face(cls, dtype, *, mapping, + quantization: 'QuantizationConfig' = None, load_by_shard=False, load_model_on_cpu=False, override_fields={}, hf_lora_dir=None, lora_target_modules=None, - max_lora_rank=None): + 
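# Illustrative sketch, not part of the patch: the GQA branch above duplicates KV heads when
# tensor_parallel exceeds num_key_value_heads, so every rank still receives a whole head.
# Hypothetical stand-in for dup_kv_weight (the real helper may differ in layout details):
import torch

def dup_kv_weight_sketch(kv: torch.Tensor, num_kv_heads: int, tp_size: int) -> torch.Tensor:
    # kv: [num_kv_heads * head_size, hidden_size]; repeat each head tp_size // num_kv_heads times
    assert tp_size % num_kv_heads == 0
    reps = tp_size // num_kv_heads
    head_size = kv.shape[0] // num_kv_heads
    kv = kv.reshape(num_kv_heads, 1, head_size, -1).expand(num_kv_heads, reps, head_size,
                                                           kv.shape[-1])
    return kv.reshape(num_kv_heads * reps * head_size, -1).contiguous()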
max_lora_rank=None, + skip_loading_weights=False, + preloaded_model=None): ''' Create a LLaMAForCausalLM object from give parameters ''' assert model_dir is not None + if isinstance(model_dir, Path): # some code relies on this as string + model_dir = str(model_dir) kwargs = { 'hf_lora_dir': hf_lora_dir, 'lora_target_modules': lora_target_modules, @@ -1164,11 +1169,18 @@ def from_hugging_face(cls, config = create_config_from_hugging_face(model_dir, dtype, mapping, + quantization, override_fields=override_fields, **kwargs) - model = None - # TODO: accept one model from outside of the world - if not load_by_shard: # when load by shard, no need to create complete hf model + + pretrained_config = PretrainedConfig.from_dict(config) + pretrained_config.set_rank(mapping.rank) #TODO: remove this hack + llama = cls.from_config(pretrained_config) + if skip_loading_weights == True: + return llama + + model = preloaded_model + if model is None and not load_by_shard: # when load by shard, no need to create complete hf model hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) if hf_config.model_type == "llava": @@ -1182,10 +1194,11 @@ def from_hugging_face(cls, torch_dtype='auto', trust_remote_code=True, ) + if load_by_shard: lora_config = create_lora_config(hf_lora_dir) - weights = load_from_hf_checkpoint(model_dir, mapping, - PretrainedConfig.from_dict(config), + + weights = load_from_hf_checkpoint(model_dir, mapping, pretrained_config, lora_config) else: weights = load_weights_from_hf(config=config, @@ -1193,9 +1206,6 @@ def from_hugging_face(cls, model=model, hf_lora_dir=hf_lora_dir) - pretrained_config = PretrainedConfig.from_dict(config) - pretrained_config.set_rank(mapping.rank) #TODO: remove this hack - llama = cls.from_config(pretrained_config) llama.load(weights) return llama @@ -1204,11 +1214,11 @@ def quantize(dtype, model_dir, output_dir, mapping, + quantization: 'QuantizationConfig', *, override_fields, dataset_cache_dir: Optional[str] = None, smoothquant_val: Optional[float] = None, - int8_kv_cache=False, hf_lora_dir=None, lora_target_modules=None, max_lora_rank=None): @@ -1224,6 +1234,7 @@ def quantize(dtype, config = create_config_from_hugging_face(model_dir, dtype, mapping, + quantization, override_fields=override_fields, **kwargs) @@ -1235,12 +1246,16 @@ def quantize(dtype, # smoother for inputs of self_attn.o_proj and mlp.down_proj llama_smoother = {} model = None - assert smoothquant_val is not None or int8_kv_cache - assert model_dir is not None - quant_algo = config['quantization']['quant_algo'] - use_smooth_quant = quant_algo is not None and quant_algo.startswith( + assert config['quantization']['quant_algo'] == quantization.quant_algo + int8_kv_cache = quantization.kv_cache_quant_algo == "INT8" + use_smooth_quant = quantization.quant_algo is not None and quantization.quant_algo.startswith( 'W8A8_SQ') + assert use_smooth_quant or int8_kv_cache, "Call from_hugging_face when there is no quantization" + if use_smooth_quant: + assert smoothquant_val is not None, "A smooth value must be specified when using smooth quant" + + assert model_dir is not None ## only load and call smooth quant routine once for all ranks hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) assert "llava" not in hf_config.model_type, "Smooth quant llava/vila is not supported yet" @@ -1270,6 +1285,8 @@ def quantize(dtype, ) safetensors.torch.save_file( weights, os.path.join(output_dir, f'rank{rank}.safetensors')) + del weights + release_gc() def load_weights_from_hf(*, 
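# Usage sketch, not part of the patch (model path is a placeholder; import paths assumed from
# this branch's package layout): the from_hugging_face flow above now supports building the
# TRT-LLM graph without weights (skip_loading_weights=True) and reusing an HF model already
# held in memory (preloaded_model=...).
from tensorrt_llm import Mapping
from tensorrt_llm.models import LLaMAForCausalLM

def build_graph_only_sketch(hf_dir: str = "/path/to/llama-hf"):
    # returns an unpopulated LLaMAForCausalLM, handy for inspecting the network definition
    return LLaMAForCausalLM.from_hugging_face(
        hf_dir,
        dtype="float16",
        mapping=Mapping(world_size=1, rank=0, tp_size=1, pp_size=1),
        skip_loading_weights=True)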
@@ -1307,11 +1324,12 @@ def load_weights_from_hf(*, vocab_size=config['vocab_size'], dtype=config['dtype'], use_weight_only=use_weight_only, - use_gemm_woq_plugin=not config['disable_weight_only_quant_plugin'], + use_gemm_woq_plugin=not config.get('disable_weight_only_quant_plugin', + False), plugin_weight_only_quant_type=plugin_weight_only_quant_type, - use_parallel_embedding=config['use_parallel_embedding'], - sharding_dim=config['embedding_sharding_dim'], - share_embedding_table=config['share_embedding_table'], + use_parallel_embedding=config.get('use_parallel_embedding', False), + sharding_dim=config.get('embedding_sharding_dim', 0), + share_embedding_table=config.get('share_embedding_table', False), use_smooth_quant=use_smooth_quant, per_channel=per_channel_sq, per_token=per_token_sq, diff --git a/tensorrt_llm/models/llama/model.py b/tensorrt_llm/models/llama/model.py index 092152ddc..48c760b99 100644 --- a/tensorrt_llm/models/llama/model.py +++ b/tensorrt_llm/models/llama/model.py @@ -17,26 +17,21 @@ from pathlib import Path from typing import Optional -from transformers import AutoConfig, AutoModelForCausalLM - -from tensorrt_llm.models.llama.weight import (load_from_awq_llama, - load_from_fp8_llama) - -from ... import profiler from ..._utils import pad_vocab_size -from ...functional import RotaryScalingType, Tensor, recv, send +from ...functional import Tensor, recv, send from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear, Embedding, GatedMLP, MoeConfig, PositionEmbeddingType, - PromptTuningEmbedding, RmsNorm) -from ...lora_manager import LoraConfig + RmsNorm) +from ...lora_manager import LoraBuildConfig, use_lora from ...mapping import Mapping from ...module import Module from ...plugin import init_all_reduce_helper +# this is to use to module global algo string with a quant_algo prefix from ...quantization import QuantMode +from ...quantization import mode as quant_algo from ...top_model_mixin import TopModelMixin from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, - PretrainedConfig) -from .weight import load_from_hf_llama + PretrainedConfig, QuantizationConfig) class LLaMADecoderLayer(Module): @@ -67,8 +62,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): tp_group=config.mapping.tp_group, tp_size=config.mapping.tp_size, tp_rank=config.mapping.tp_rank, - quant_mode=config.quant_mode, - max_lora_rank=config.max_lora_rank) + quant_mode=config.quant_mode) mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size @@ -96,7 +90,6 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): tp_group=config.mapping.tp_group, tp_size=config.mapping.tp_size, quant_mode=config.quant_mode, - max_lora_rank=config.max_lora_rank, **mlp_kwargs) self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, @@ -149,10 +142,8 @@ def __init__(self, config: PretrainedConfig) -> None: init_all_reduce_helper() self.mapping = config.mapping - self.use_prompt_tuning = config.use_prompt_tuning - EmbeddingCls = PromptTuningEmbedding if config.use_prompt_tuning else Embedding if self.mapping.is_first_pp_rank(): - self.vocab_embedding = EmbeddingCls( + self.vocab_embedding = Embedding( num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, dtype=config.dtype, @@ -194,7 +185,7 @@ def forward( ptuning_args = [ prompt_embedding_table, prompt_tasks, prompt_vocab_size - ] if self.use_prompt_tuning else [] + ] if prompt_embedding_table is not None else [] if 
self.mapping.is_first_pp_rank(): hidden_states = self.vocab_embedding(input_ids, *ptuning_args) @@ -265,197 +256,49 @@ def from_hugging_face(cls, mapping: Optional[Mapping] = None, quant_mode: Optional[QuantMode] = None, **kwargs): - cfg = AutoConfig.from_pretrained(hf_model_dir) - - num_kv_heads = cfg.num_key_value_heads if hasattr(cfg, "num_key_value_heads") \ - else cfg.num_attention_heads - use_fused_mlp = kwargs.get("use_fused_mlp", False) - mapping = mapping or Mapping() - quant_mode = quant_mode or QuantMode(0) - - cfg.mapping = mapping - - cfg.dtype = dtype - cfg.quant_mode = quant_mode - - cfg.norm_epsilon = cfg.rms_norm_eps - - if cfg.model_type == 'mixtral': - moe_config = MoeConfig( - num_experts=cfg.num_local_experts, - top_k=cfg.num_experts_per_tok, - tp_mode=kwargs.get("moe_tp_mode", - MoeConfig.ParallelismMode.TENSOR_PARALLEL), - normalization_mode=kwargs.get( - "moe_normalization_mode", - MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE), - ).validate() - # HF LLaMA-type models are implicitly using gated activation. - # With our MoE implementation, we must make it explicit - cfg.hidden_act = 'swiglu' - cfg.rotary_base = cfg.rope_theta - else: - moe_config = MoeConfig() - config = { - 'architecture': cfg.architectures[0], - 'dtype': cfg.dtype, - 'logits_dtype': 'float32', - 'num_hidden_layers': cfg.num_hidden_layers, - 'num_attention_heads': cfg.num_attention_heads, - 'hidden_size': cfg.hidden_size, - 'intermediate_size': cfg.intermediate_size, - 'num_key_value_heads': num_kv_heads, - 'vocab_size': cfg.vocab_size, - 'position_embedding_type': 'rope_gpt_neox', - 'max_position_embeddings': cfg.max_position_embeddings, - 'hidden_act': cfg.hidden_act, - 'rotary_base': getattr(cfg, 'rotary_base', 10000.0), - 'rotary_scaling': getattr(cfg, 'rotary_scaling', None), - 'norm_epsilon': cfg.rms_norm_eps, - 'quantization': { - 'group_size': 128, - }, - "moe_config": { - "num_experts": moe_config.num_experts, - "top_k": moe_config.top_k, - "tp_mode": moe_config.tp_mode, - "normalization_mode": moe_config.normalization_mode, - }, - 'use_parallel_embedding': kwargs.get("use_parallel_embedding", - False), - 'embedding_sharding_dim': kwargs.get("embedding_sharding_dim", 0), - 'use_prompt_tuning': kwargs.get("use_prompt_tuning", False), - 'moe_num_experts': moe_config.num_experts, - 'moe_top_k': moe_config.top_k, - 'moe_tp_mode': moe_config.tp_mode, - 'moe_normalization_mode': moe_config.normalization_mode, - 'use_fused_mlp': use_fused_mlp, - } - if quant_mode.is_int4_weight_only_per_group(): - config['quantization'].update({ - 'quant_algo': 'W4A16_AWQ', - 'has_zero_point': False, - 'pre_quant_scale': True, - 'exclude_modules': [], - }) - elif quant_mode.has_fp8_qdq() and quant_mode.has_fp8_kv_cache(): - config['quantization'].update({ - 'quant_algo': 'FP8', - 'kv_cache_quant_algo': 'FP8' - }) - else: - if quant_mode != QuantMode(0): - raise ValueError(f"Unsupported quantization mode: {quant_mode}") - - model_config = PretrainedConfig.from_dict(config) - model_config.mapping = mapping - tllm_llama = LLaMAForCausalLM(model_config) - q_weights = {} - if quant_mode.has_any_quant(): - q_weights = tllm_llama._quantize(hf_model_dir, dtype, cfg, **kwargs) - - # For debug purpose, skip weights loading to be faster - if kwargs.get("skip_loading_weights", False): + from . 
import convert + if quant_mode is not None and quant_mode.has_any_quant(): + #TODO: TRTLLM-208 delete this after LLM class calls .quantize directly + quantized_temp_dir = tempfile.TemporaryDirectory("llama-quantized") + quantized_checkpoint_path = kwargs.get("quantization_cache_dir", + quantized_temp_dir.name) + quant_config = QuantizationConfig() + if quant_mode.has_fp8_qdq(): + quant_config.quant_algo = "FP8" + if quant_mode.has_fp8_kv_cache(): + quant_config.kv_cache_quant_algo = "FP8" + elif quant_mode.is_int4_weight_only_per_group(): + quant_config.quant_algo = 'W4A16_AWQ' + cls.quantize(hf_model_dir, + quantized_checkpoint_path, + quant_config, + dtype=dtype, + mapping=mapping) + tllm_llama = LLaMAForCausalLM.from_checkpoint( + quantized_checkpoint_path, rank=mapping.rank) return tllm_llama - - # weights already loaded in _quantize for int4 weight only - if not quant_mode.is_int4_weight_only_per_group(): - profiler.start("Loading weights from HF") - hf_llama = AutoModelForCausalLM.from_pretrained( + else: + # TODO: TRTLLM-180 the original convert_checkpoint use QuantizationConfig + # while the high level api uses the QuantMode, needs to be unified + # here it's a hacky before the conversion is done, we assume the convert_checkpoint.py + # always passes the quant_mode==None for now. + if mapping is None: + mapping = Mapping() + llama = convert.from_hugging_face( + cls, hf_model_dir, - device_map={ - "model": "cpu", - "lm_head": "cpu", - "embed_tokens": "cpu", - "layers": "cpu", - "norm": "cpu", - }, # Load to CPU memory - torch_dtype='auto', - ) - - weights = load_from_hf_llama( - tllm_llama, - hf_llama, + dtype, mapping=mapping, - dtype=dtype, - # TODO: these shall be outside from_hugging_face too. - use_gemm_woq_plugin=kwargs.get("use_gemm_woq_plugin", False), - lora_config=kwargs.get("lora_config", LoraConfig()), - ) - profiler.stop("Loading weights from HF") - del hf_llama - weights.update(q_weights) - tllm_llama.load(weights) - else: - tllm_llama.load(q_weights) - return tllm_llama - - def _quantize(self, hf_model_dir, dtype, cfg, **kwargs): - '''Given the quant_mode set in the Module object, read from given hf model - call AMMO to generate quantization scales, and set the scales back the module parameters. - ''' - # use self destructed temporary path if kwargs[quantization_cache_dir] is not specified - # sometimes the quantization checkpoint path needs to be saved for debug purpose - quantized_temp_dir = tempfile.TemporaryDirectory("llama-quantized") - quantized_checkpoint_path = kwargs.get("quantization_cache_dir", - quantized_temp_dir.name) - quantize_lm_head = kwargs.get("quantize_lm_head", False) - quant_mode = cfg.quant_mode - ammo_qformat = None - calib_size = None - if quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache(): - ammo_qformat = 'fp8' - calib_size = 512 - # TODO: how to distinguish from quant_mode about int4_awq or int4_gptq? 
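# Usage sketch, not part of the patch (paths are placeholders; requires a GPU plus the AMMO
# toolkit): the replacement flow above quantizes once into a TRT-LLM checkpoint directory and
# then rebuilds the model from that checkpoint, instead of loading scales in-process.
#
#   from tensorrt_llm import Mapping
#   from tensorrt_llm.models import LLaMAForCausalLM
#   from tensorrt_llm.models.modeling_utils import QuantizationConfig
#
#   quant_config = QuantizationConfig()
#   quant_config.quant_algo = "FP8"
#   quant_config.kv_cache_quant_algo = "FP8"
#   LLaMAForCausalLM.quantize("/path/to/llama-hf", "/tmp/llama-fp8-ckpt",
#                             quant_config, dtype="float16", mapping=Mapping())
#   llama = LLaMAForCausalLM.from_checkpoint("/tmp/llama-fp8-ckpt", rank=0)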
- elif quant_mode.is_int4_weight_only_per_group(): - ammo_qformat = 'int4_awq' - calib_size = 32 - assert ammo_qformat is not None - - # local import to avoid pytest issue when importing AMMO and transformers lib - from .quantize import quantize_llama_and_export - quantize_llama_and_export(hf_model_dir, - quantized_checkpoint_path, - ammo_qformat, - dtype, - calib_size=calib_size, - quantize_lm_head=quantize_lm_head) - - ckpt = Path(quantized_checkpoint_path) / "llama_tp1_rank0.npz" - assert ckpt.exists(), f"The expecting checkpoint path {ckpt} does not exist" \ - "it's likely quantization failed, pls check error logs" - hf_config = AutoConfig.from_pretrained(hf_model_dir, - trust_remote_code=True) - if ammo_qformat == 'fp8': - return load_from_fp8_llama( - str(ckpt), - hf_config.num_hidden_layers, - cfg.mapping, - fp8_kv_cache=quant_mode.has_fp8_kv_cache()) - else: - return load_from_awq_llama(str(ckpt), - hf_config.num_hidden_layers, - hf_config.vocab_size, - cfg.mapping, - dtype=dtype) - - # llama specific setters, user shall has the chance to change the module attributes after - # from_hugging_face factory method created the model when these attributes is not included in the huggingface checkpoint - - def rotary_base(self, val): - for decoder in self.layers: - decoder.attention.rotary_embedding_base = val - return self - - def rotary_scaling(self, scaling_type, factor): - # TODO: what if there are some other behaviors triggered by the these changes? - # should implement these assignment as setters of the Attention Module - assert scaling_type in ("linear", "dynamic"), f"Got {scaling_type}" - assert factor > 1.0, f"Got {factor}" - for decoder in self.layers: - decoder.attention.rotary_embedding_scale_type = RotaryScalingType.linear if scaling_type == "linear" else RotaryScalingType.dynamic - decoder.attention.rotary_embedding_scale = factor - return self + quantization=kwargs.get('quantization', QuantizationConfig()), + load_by_shard=kwargs.get('load_by_shard', False), + load_model_on_cpu=kwargs.get('load_model_on_cpu', False), + override_fields=kwargs.get('override_fields', {}), + hf_lora_dir=kwargs.get('hf_lora_dir', None), + lora_target_modules=kwargs.get('lora_target_modules', None), + max_lora_rank=kwargs.get('max_lora_rank', None), + skip_loading_weights=kwargs.get('skip_loading_weights', False), + preloaded_model=kwargs.get('preloaded_model', None)) + return llama def default_plugin_config(self, **kwargs): plugin_config = super().default_plugin_config(**kwargs) @@ -468,8 +311,8 @@ def from_meta_ckpt(cls, meta_ckpt_dir, dtype, mapping, - override_fileds=None, - **kwargs): + use_parallel_embedding: Optional[bool] = False, + embedding_sharding_dim: Optional[int] = 0): meta_config = None with open(Path(meta_ckpt_dir, "params.json")) as fp: meta_config: dict = json.load(fp) @@ -477,12 +320,7 @@ def from_meta_ckpt(cls, config = {} n_embd = meta_config["dim"] n_head = meta_config["n_heads"] - n_layer = meta_config["n_layers"] n_kv_head = meta_config.get("n_kv_heads", n_head) - # meta checkpoint don't have vocab_size|hidden_act|rotary_base specified, need to read from user input - vocab_size = 32000 - hidden_act = 'silu' - rotary_base = 10000.0 if "hidden_dim" in meta_config: inter_size = meta_config["hidden_dim"] else: @@ -492,45 +330,94 @@ def from_meta_ckpt(cls, inter_size = multiple_of * ( (int(n_embd_ * ffn_dim_multiplier) + multiple_of - 1) // multiple_of) - rms_norm_eps = meta_config["norm_eps"] - moe_num_experts = meta_config.get("moe", {}).get("num_experts", 0) - moe_top_k = 
meta_config.get("moe", {}).get("num_experts_per_tok", 0) - moe_tp_mode = None - n_positions = 2048 - config['moe_normalization_mode'] = None # meta checkpoint has no moe - architecture = "LlamaForCausalLM" - # config values from reading meta + # meta checkpoint don't have vocab_size|hidden_act|rotary_base specified, use same default value as HF config.update({ - 'architecture': architecture, + 'architecture': "LlamaForCausalLM", 'dtype': dtype, 'logits_dtype': 'float32', - 'num_hidden_layers': n_layer, + 'num_hidden_layers': meta_config["n_layers"], 'num_attention_heads': n_head, 'hidden_size': n_embd, 'intermediate_size': inter_size, 'num_key_value_heads': n_kv_head, - 'vocab_size': vocab_size, + 'vocab_size': 32000, 'position_embedding_type': 'rope_gpt_neox', - 'max_position_embeddings': n_positions, - 'hidden_act': hidden_act, - 'rotary_base': rotary_base, - 'norm_epsilon': rms_norm_eps, - 'moe_num_experts': moe_num_experts, - 'moe_top_k': moe_top_k, - 'moe_tp_mode': moe_tp_mode, + 'max_position_embeddings': 2048, + 'hidden_act': 'silu', + 'rotary_base': 10000.0, + 'norm_epsilon': meta_config["norm_eps"], 'mapping': { 'world_size': mapping.tp_size * mapping.pp_size, 'tp_size': mapping.tp_size, 'pp_size': mapping.pp_size, }, }) - config.update(override_fileds) - pretained_config = PretrainedConfig.from_dict(config) - pretained_config.set_rank( - mapping.rank - ) #TODO: remove the need of calling this, it's hacky design - llama = cls(pretained_config) + pretrained_config = PretrainedConfig.from_dict(config) + pretrained_config.use_parallel_embedding = use_parallel_embedding + pretrained_config.embedding_sharding_dim = embedding_sharding_dim + pretrained_config.set_rank(mapping.rank) + + llama = cls(pretrained_config) from .weight import load_from_meta_llama - weights = load_from_meta_llama(meta_ckpt_dir, mapping, pretained_config) + weights = load_from_meta_llama(meta_ckpt_dir, mapping, + pretrained_config) llama.load(weights) return llama + + @classmethod + def quantize( + cls, + hf_model_dir, + output_dir, + quant_config: QuantizationConfig, + *, + dtype='float16', + mapping: Optional[Mapping] = None, + calib_batches=512, + calib_batch_size=1, + random_seed=1234, + tokenizer_max_seq_length=2048, + **kwargs, + ): + DEFAULT_AMMO_FLOW = [ + quant_algo.W4A16_AWQ, quant_algo.FP8, + quant_algo.W8A8_SQ_PER_CHANNEL, quant_algo.W4A8_AWQ + ] + use_ammo_quantization = quant_config.quant_algo in DEFAULT_AMMO_FLOW + if use_ammo_quantization: + super().quantize(hf_model_dir, + output_dir, + quant_config, + dtype=dtype, + mapping=mapping, + calib_batches=calib_batches, + calib_batch_size=calib_batch_size, + random_seed=random_seed, + tokenizer_max_seq_length=tokenizer_max_seq_length) + else: + # non-ammo, the legacy TRT-LLM native quantization algorithm: + # sq, int4/int8 weights only, int8 kv cache + NATIVE_QUANT_FLOW = [quant_algo.W4A16, quant_algo.W8A16, None + ] + quant_algo.W8A8_SQ_PLUGIN_LIST + is_valid_native_quant = (quant_config.quant_algo in NATIVE_QUANT_FLOW) and \ + (quant_config.kv_cache_quant_algo in [quant_algo.INT8, None]) + assert quant_config.quant_algo is not None or quant_config.kv_cache_quant_algo is not None, \ + "There is no point to call the quantize function if both quant_algo and kv_cache_quant_algo is None" + assert is_valid_native_quant, f"Internal error: shall call AMMO for this quantization {quant_config}" + + from . 
import convert + convert.quantize( + dtype, + hf_model_dir, + output_dir, + mapping, + quant_config, + override_fields=kwargs.get('override_fields', {}), + dataset_cache_dir=kwargs.get('dataset_cache_dir', None), + smoothquant_val=kwargs.get('smoothquant_val', None), + hf_lora_dir=kwargs.get('hf_lora_dir', None), + lora_target_modules=kwargs.get('lora_target_modules', None), + max_lora_rank=kwargs.get('max_lora_rank', None)) + + def use_lora(self, lora_config: LoraBuildConfig): + use_lora(self, lora_config) diff --git a/tensorrt_llm/models/llama/quantize.py b/tensorrt_llm/models/llama/quantize.py deleted file mode 100644 index 86fc037fe..000000000 --- a/tensorrt_llm/models/llama/quantize.py +++ /dev/null @@ -1,140 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Adapted from examples/llama/quantize.py -""" - -import random -from typing import Optional - -import numpy as np -import torch -from datasets import load_dataset -from torch.utils.data import DataLoader -from transformers import AutoModelForCausalLM, AutoTokenizer - -from ..._utils import str_dtype_to_torch -from ...logger import logger -from ...models.quantized.ammo import quantize_and_export - - -def get_calib_dataloader(data="cnn_dailymail", - tokenizer=None, - batch_size=1, - calib_size=512, - block_size=512, - cache_dir=None): - print("Loading calibration dataset") - if data == "pileval": - dataset = load_dataset( - "json", - data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", - split="train", - cache_dir=cache_dir) - dataset = dataset["text"][:calib_size] - elif data == "cnn_dailymail": - dataset = load_dataset("cnn_dailymail", - name="3.0.0", - split="train", - cache_dir=cache_dir) - dataset = dataset["article"][:calib_size] - else: - raise NotImplementedError - - dataset_input_ids = tokenizer(dataset, - return_tensors="pt", - padding=True, - truncation=True, - max_length=block_size).input_ids.cuda() - - calib_dataloader = DataLoader(dataset_input_ids, - batch_size=batch_size, - shuffle=False) - - return calib_dataloader - - -def get_tokenizer(ckpt_path, **kwargs): - logger.info(f"Loading tokenizer from {ckpt_path}") - tokenizer = AutoTokenizer.from_pretrained(ckpt_path, - padding_side="left", - **kwargs) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - return tokenizer - - -def get_model(ckpt_path, dtype="float16", cache_dir=None): - logger.info(f"Loading model from {ckpt_path}") - torch_dtype = str_dtype_to_torch(dtype) - model = AutoModelForCausalLM.from_pretrained( - ckpt_path, - device_map="cuda", - cache_dir=cache_dir, - trust_remote_code=True, - torch_dtype=torch_dtype, - ) - model.eval() - model = model.to(memory_format=torch.channels_last) - return model - - -def quantize_llama_and_export(hf_model_dir, - export_path, - qformat: str = 'fp8', - dtype: Optional[str] = 'float16', - calib_size: Optional[int] = 512, - hf_cache_dir: 
Optional[str] = None, - seed: Optional[int] = None, - quantize_lm_head=False): - ''' - Quantize a llama model from HF model dir and save it as export_path. - Parameters: - hf_model_dir: huggingface model directory - export_path: a path to save the quantized weights and scales tensors - qformat: quantization format, currently 'int4_awq' and 'fp8' are supported - dtype: the datatype to run the HF/pytorch model forward during quantization - calib_size: Number of samples for calibration. - seed: the seed to be used in the random and np.random package during quantization - - Return: None, raises exception if the quantization failed due to any reason. - ''' - assert qformat in ['int4_awq', 'fp8' - ], "More quantization format supported in future release" - if not torch.cuda.is_available(): - raise EnvironmentError("GPU is required for inference.") - if seed is not None: - random.seed(seed) - np.random.seed(seed) - - tokenizer = get_tokenizer(hf_model_dir, cache_dir=hf_cache_dir) - model = get_model(hf_model_dir, dtype, cache_dir=hf_cache_dir) - - calib_dataloader = get_calib_dataloader(tokenizer=tokenizer, - calib_size=calib_size, - cache_dir=hf_cache_dir) - quant_cfg_dict = {} - if quantize_lm_head: - quant_cfg_dict.update({ - "*lm_head*": { - "enable": True - }, - }) - - model = quantize_and_export(model, - qformat=qformat, - calib_dataloader=calib_dataloader, - export_path=export_path, - quant_cfg_dict=quant_cfg_dict) diff --git a/tensorrt_llm/models/llama/weight.py b/tensorrt_llm/models/llama/weight.py index 0fa04d5de..ca166f93c 100644 --- a/tensorrt_llm/models/llama/weight.py +++ b/tensorrt_llm/models/llama/weight.py @@ -733,111 +733,16 @@ def load_from_hf_llama(tensorrt_llm_llama: 'LLaMAForCausalLM', return weights -def load_from_fp8_llama(quant_ckpt_path: str, num_hidden_layers: int, - mapping: Mapping, fp8_kv_cache: bool): - """ - Get the fp8 scaling factors for Falcon model. - """ - fake_fp8_sf_dt = torch.float32 - fp8_llama = np.load(quant_ckpt_path) - weights = {} - - layers_range = mapping.pp_layers(num_hidden_layers) - for l in layers_range: - prefix = f'_np:layers:{l}' - tllm_prex = f'transformer.layers.{l-layers_range[0]}' - - weights[f'{tllm_prex}.attention.qkv.activation_scaling_factor'] = torch.tensor( - [ - max( - fp8_llama[ - f'{prefix}:attention:qkv:q:activation_scaling_factor']. - item(), fp8_llama[ - f'{prefix}:attention:qkv:k:activation_scaling_factor']. - item(), fp8_llama[ - f'{prefix}:attention:qkv:v:activation_scaling_factor']. - item()) - ], - dtype=fake_fp8_sf_dt) - weights[ - f'{tllm_prex}.attention.qkv.weights_scaling_factor'] = torch.tensor( - [ - max( - fp8_llama[ - f'{prefix}:attention:qkv:q:weights_scaling_factor']. - item(), fp8_llama[ - f'{prefix}:attention:qkv:k:weights_scaling_factor']. - item(), fp8_llama[ - f'{prefix}:attention:qkv:v:weights_scaling_factor']. - item()) - ], - dtype=fake_fp8_sf_dt) - weights[ - f'{tllm_prex}.attention.dense.activation_scaling_factor'] = torch.tensor( - [ - fp8_llama[ - f'{prefix}:attention:dense:activation_scaling_factor']. - item() - ], - dtype=fake_fp8_sf_dt) - weights[ - f'{tllm_prex}.attention.dense.weights_scaling_factor'] = torch.tensor( - [ - fp8_llama[ - f'{prefix}:attention:dense:weights_scaling_factor']. 
- item() - ], - dtype=fake_fp8_sf_dt) - - weights[f'{tllm_prex}.mlp.fc.activation_scaling_factor'] = torch.tensor( - [fp8_llama[f'{prefix}:mlp:fc:activation_scaling_factor'].item()], - dtype=fake_fp8_sf_dt) - weights[f'{tllm_prex}.mlp.fc.weights_scaling_factor'] = torch.tensor( - [fp8_llama[f'{prefix}:mlp:fc:weights_scaling_factor'].item()], - dtype=fake_fp8_sf_dt) - - weights[ - f'{tllm_prex}.mlp.gate.activation_scaling_factor'] = torch.tensor( - [ - fp8_llama[f'{prefix}:mlp:gate:activation_scaling_factor']. - item() - ], - dtype=fake_fp8_sf_dt) - weights[f'{tllm_prex}.mlp.gate.weights_scaling_factor'] = torch.tensor( - [fp8_llama[f'{prefix}:mlp:gate:weights_scaling_factor'].item()], - dtype=fake_fp8_sf_dt) - - weights[ - f'{tllm_prex}.mlp.proj.activation_scaling_factor'] = torch.tensor( - [ - fp8_llama[f'{prefix}:mlp:proj:activation_scaling_factor']. - item() - ], - dtype=fake_fp8_sf_dt) - weights[f'{tllm_prex}.mlp.proj.weights_scaling_factor'] = torch.tensor( - [fp8_llama[f'{prefix}:mlp:proj:weights_scaling_factor'].item()], - dtype=fake_fp8_sf_dt) - - if fp8_kv_cache: - # Not calibrating KV cache. - scaling_factor = 1.0 - weights[ - f'{tllm_prex}.attention.kv_cache_scaling_factor'] = torch.tensor( - [scaling_factor], dtype=fake_fp8_sf_dt) - - return weights - - -def load_from_gptq_llama(quant_ckpt_path, - num_hidden_layers=None, - vocab_size=32000, - mapping=Mapping(), - dtype="float16", - bin_model_dir=None): +def load_from_gptq_llama(config: PretrainedConfig, quant_ckpt_path): logger.info('Loading weights from groupwise GPTQ LLaMA safetensors...') weights = {} tik = time.time() + num_hidden_layers = config.num_hidden_layers + vocab_size = config.vocab_size + dtype = config.dtype + mapping = config.mapping + gptq_llama = safe_open(quant_ckpt_path, framework="pt", device=0) gptq_prefix = "model." 
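# Illustrative sketch, not part of the patch: safe_open() above memory-maps the GPTQ
# safetensors checkpoint, so tensors can be fetched lazily by key instead of loading the whole
# file at once. The key layout assumed here follows the gptq_prefix / gptq_suffix_list scheme.
from safetensors import safe_open

def peek_gptq_layer_sketch(quant_ckpt_path: str, layer: int = 0):
    with safe_open(quant_ckpt_path, framework="pt", device="cpu") as f:
        base = f"model.layers.{layer}.self_attn.q_proj"
        # GPTQ stores packed integer weights plus zero-points and per-group scales
        return {s: f.get_tensor(base + s).shape for s in (".qweight", ".qzeros", ".scales")}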
gptq_suffix_list = [".qweight", ".qzeros", ".scales"] @@ -1011,7 +916,7 @@ def process_and_assign_weight(v: List[torch.Tensor], return weights -def load_from_meta_llama(meta_ckpt_dir, mapping=Mapping(), config=None): +def load_from_meta_llama(meta_ckpt_dir, mapping, config): torch_dtype = str_dtype_to_torch(config.dtype) weights = {} @@ -1034,10 +939,6 @@ def split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank): if any(n in k for n in ["wo", "feed_forward.w2", "tok", "feed_forward.gate"]): d = 1 - if "feed_forward.experts" in k and ("w2" in k) == ( - not quant_mode.is_weight_only()): - d = 1 - if "norm" in k or "rope" in k: # no TP split_ckpt[k] = v.clone() elif config.num_key_value_heads < mapping.tp_size and any( @@ -1120,12 +1021,6 @@ def gather_embedding(cur_embed, name: str, num_ckpts): logger.info('Loading weights from Meta LLaMA checkpoints ...') tik = time.time() - quant_mode = config.quant_mode - if quant_mode.is_int8_weight_only(): - torch.int8 - elif quant_mode.is_int4_weight_only(): - torch.quint4x2 - quant_mode.is_weight_only() num_kv_heads = config.num_key_value_heads mha_mode = (num_kv_heads == config.num_attention_heads) @@ -1158,25 +1053,6 @@ def gather_embedding(cur_embed, name: str, num_ckpts): qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) ckpt[prefix + 'qkv.weight'] = qkv_weight - moe_config = MoeConfig(config.moe_num_experts, config.moe_top_k, - config.moe_tp_mode, config.moe_normalization_mode) - for l in layers_range: - if not moe_config.has_moe(): - continue - - rank_experts = list(range(moe_config.num_experts)) - if moe_config.tp_mode == moe_config.ParallelismMode.EXPERT_PARALLEL: - rank_experts = mapping.ep_experts(moe_config.num_experts) - for suffix in ["w1", "w2", "w3"]: - ckpt[f'layers.{l}.feed_forward.experts.{suffix}.weight'] = \ - torch.stack(list(ckpt[f'layers.{l}.feed_forward.experts.{expert}.{suffix}.weight'] - for expert in rank_experts)) - - # concat w3 and w1 for gated expert - ckpt[f'layers.{l}.feed_forward.experts.w3w1.weight'] = \ - torch.concat([ckpt[f'layers.{l}.feed_forward.experts.w3.weight'], - ckpt[f'layers.{l}.feed_forward.experts.w1.weight']], dim=-2) - for k, v in ckpt.items(): dtype = torch_dtype if 'feed_forward.gate' not in k else torch.float32 @@ -1229,10 +1105,6 @@ def gather_embedding(cur_embed, name: str, num_ckpts): weights[tllm_prex + 'attention.dense.weight'] = v elif 'attention.qkv.weight' in k: weights[tllm_prex + 'attention.qkv.weight'] = v - elif 'experts.w2.weight' in k: - weights[tllm_prex + 'mlp.experts_weight_2'] = v - elif 'experts.w3w1.weight' in k: - weights[tllm_prex + 'mlp.experts_weight_1'] = v elif 'feed_forward.gate' in k: weights[tllm_prex + 'mlp.router.weight'] = v @@ -1240,290 +1112,3 @@ def gather_embedding(cur_embed, name: str, num_ckpts): t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) logger.info(f'Weights loaded. Total time: {t}') return weights - - -def load_from_awq_llama(quant_ckpt_path, - num_hidden_layers, - vocab_size, - quantize_lm_head=False, - mapping=Mapping(), - dtype="float16", - bin_model_dir=None): - - weights = {} - - if quant_ckpt_path.endswith(".pt"): - awq_llama = torch.load(quant_ckpt_path) - awq_prefix = "model." 
- awq_suffix_list = [ - ".weight", - ".weight_quantizer._amax", - ".input_quantizer._pre_quant_scale", - ] - awq_key_list = [ - "embed_tokens.weight", # vocab_embedding - "lm_head", # lm_head - "norm.weight", # ln_f - "self_attn.", # attention.qkv - "_proj", # qkv suffix - "self_attn.o_proj", # attention.dense - "mlp.up_proj", # mlp.gate - "mlp.down_proj", # mlp.proj - "mlp.gate_proj", # mlp.fc - "input_layernorm.weight", # input_layernorm - "post_attention_layernorm.weight", # post_layernorm - ] - split_sym = "." - - def load(key): - if "lm_head" in key: - v = awq_llama[key] - else: - v = awq_llama[awq_prefix + key] - return v - - group_size = load("layers.0.self_attn.o_proj.weight").numel() // load( - "layers.0.self_attn.o_proj.weight_quantizer._amax").numel() - elif quant_ckpt_path.endswith(".npz"): - awq_llama = np.load(quant_ckpt_path) - awq_prefix = "_np:" - awq_suffix_list = [ - ":weight", - ":weights_scaling_factor", - ":prequant_scaling_factor", - ] - awq_key_list = [ - "vocab_embedding:weight", # vocab_embedding - "lm_head", # lm_head - "final_layernorm:weight", # ln_f - "attention:qkv:", # attention.qkv - "", # qkv suffix - "attention:dense", # attention.dense - "mlp:gate", # mlp.gate - "mlp:proj", # mlp.proj - "mlp:fc", # mlp.fc - "input_layernorm:weight", # input_layernorm - "post_layernorm:weight", # post_layernorm - ] - split_sym = ":" - - def load(key): - v = torch.from_numpy(awq_llama[awq_prefix + key]) - if "weights_scaling_factor" in key: - v *= 7 # For AMMO *.npz checkpoints - return v - - group_size = load("layers:0:attention:dense:weight").numel() // load( - "layers:0:attention:dense:weights_scaling_factor").numel() - else: - assert False, "Unsupported AWQ quantized checkpoint format" - - # quant_mode = getattr(tensorrt_llm_llama, 'quant_mode', QuantMode(0)) - # Int8 KV cache - # use_int8_kv_cache = quant_mode.has_int8_kv_cache() - - packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 - preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm - torch_dtype = str_dtype_to_torch(dtype) - - # def fromfile(dir_path, name, shape=None, dtype=None): - # p = dir_path + '/' + name - # if Path(p).exists(): - # t = np.fromfile(p, dtype=dtype) - # if shape is not None: - # t = t.reshape(shape) - # return t - # return None - - def torch_split(v, dim): - if v.shape[dim] % mapping.tp_size != 0: - logger.error( - "Current weight shape is invalid for mapping.tp_size=" + - str(mapping.tp_size)) - assert False, "Invalid TP size" - return v.split(v.shape[dim] // mapping.tp_size, - dim=dim)[mapping.tp_rank] - - def AWQ_quantize_pack_preprocess(weight, scale): - weight /= scale.repeat_interleave(group_size, dim=0) - qweight_int8 = torch.clamp(torch.round(weight.cuda()).char(), -8, 7) - int4_weight = preprocessor(packer(qweight_int8.cpu()), torch.quint4x2) - return int4_weight.view(torch.float16) - - def get_tllm_weight_from_awq(v: List[torch.Tensor], - tllm_prex: str, - tp_dim: int = 0): - weight = v[0].T.contiguous() - [k, n] = weight.shape - weight = torch_split(weight, tp_dim) - amax = v[1].reshape((n, k // group_size)).T.contiguous() - amax = torch_split(amax, tp_dim) - pre_quant_scale = v[2].reshape((1, k)) - if tp_dim == 0: - pre_quant_scale = torch_split(pre_quant_scale, 1) - scale = amax / 8.0 - results = { - f'{tllm_prex}.weight': AWQ_quantize_pack_preprocess(weight, scale), - f'{tllm_prex}.weights_scaling_factor': scale.to(torch_dtype), - f'{tllm_prex}.prequant_scaling_factor': - pre_quant_scale.to(torch_dtype), - } - return results - - def 
reSmooth_and_get_scale(weight, pre_quant_scale, avg_pre_quant_scale): - # deSmooth and reSmooth - [k, n] = weight.shape - - if quant_ckpt_path.endswith("pt"): - # NPZ files are already re-smoothed - weight *= pre_quant_scale.repeat((n, 1)).transpose(1, - 0).contiguous() - weight /= avg_pre_quant_scale.repeat( - (n, 1)).transpose(1, 0).contiguous() - - # Get scale - weight_t = weight.T.contiguous() - weight_t = weight_t.reshape(n, k // group_size, group_size) - weight_t = torch.abs(weight_t.reshape(-1, group_size)) - amax, idx = weight_t.max(1) - amax = amax.reshape(n, k // group_size).T.contiguous() - scale = amax / 8 - return weight, scale - - def get_tllm_qkv_weight_from_awq(prefix, tllm_prex): - q_weight = load(prefix + "q" + awq_key_list[4] + - awq_suffix_list[0]).T.contiguous() - k_weight = load(prefix + "k" + awq_key_list[4] + - awq_suffix_list[0]).T.contiguous() - v_weight = load(prefix + "v" + awq_key_list[4] + - awq_suffix_list[0]).T.contiguous() - dim_k = q_weight.shape[0] - q_weight = torch_split(q_weight, 1) - k_weight = torch_split(k_weight, 1) - v_weight = torch_split(v_weight, 1) - q_pre_quant_scale = load(prefix + "q" + awq_key_list[4] + - awq_suffix_list[2]).reshape((1, dim_k)) - k_pre_quant_scale = load(prefix + "k" + awq_key_list[4] + - awq_suffix_list[2]).reshape((1, dim_k)) - v_pre_quant_scale = load(prefix + "v" + awq_key_list[4] + - awq_suffix_list[2]).reshape((1, dim_k)) - qkv_pre_quant_scale = (q_pre_quant_scale + k_pre_quant_scale + - v_pre_quant_scale) / 3.0 - q_weight, q_scale = reSmooth_and_get_scale(q_weight, q_pre_quant_scale, - qkv_pre_quant_scale) - k_weight, k_scale = reSmooth_and_get_scale(k_weight, k_pre_quant_scale, - qkv_pre_quant_scale) - v_weight, v_scale = reSmooth_and_get_scale(v_weight, v_pre_quant_scale, - qkv_pre_quant_scale) - qkv_weights = torch.cat((q_weight, k_weight, v_weight), dim=1) - qkv_scale = torch.cat((q_scale, k_scale, v_scale), dim=1) - - results = { - f'{tllm_prex}.weight': - AWQ_quantize_pack_preprocess(qkv_weights, qkv_scale), - f'{tllm_prex}.weights_scaling_factor': - qkv_scale.to(torch_dtype), - f'{tllm_prex}.prequant_scaling_factor': - qkv_pre_quant_scale.to(torch_dtype), - } - return results - - # Load weights from AWQ checkpoint into TRT-LLM module - # 1. vocab_embedding - v = load(awq_key_list[0]) - # TRT-LLM requires vocab_size to be multiple of 64 for successful GEMM - if v.shape[0] % 64 != 0: - v = torch.nn.functional.pad(v, [0, 0, 0, 64 - v.shape[0] % 64]) - if mapping.is_first_pp_rank(): - weights['transformer.vocab_embedding.weight'] = v.to(torch_dtype) - - # 2. lm_head - if quantize_lm_head: - v = [load(awq_key_list[1] + suf) for suf in awq_suffix_list] - if v[0].shape[0] % 64 != 0: - v[0] = torch.nn.functional.pad(v[0], - [0, 0, 0, 64 - v[0].shape[0] % 64]) - scale_align = 64 * (v[0].shape[1] // group_size) - v[1] = v[1].reshape(-1) - v[1] = torch.nn.functional.pad( - v[1], [0, scale_align - v[1].shape[0] % scale_align], value=1) - if mapping.is_last_pp_rank(): - weights.update(get_tllm_weight_from_awq(v, 'lm_head', 1)) - else: - v = load(awq_key_list[1] + awq_suffix_list[0]) - if mapping.is_last_pp_rank(): - if vocab_size % mapping.tp_size != 0: - # padding - vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) - pad_width = vocab_size_padded - vocab_size - v = torch.from_numpy( - np.pad(v.detach().cpu().numpy(), ((0, pad_width), (0, 0)), - 'constant', - constant_values=0)) - weights['lm_head.weight'] = torch_split(v, 0).to(torch_dtype) - - # 3. 
ln_f - v = load(awq_key_list[2]) - if mapping.is_last_pp_rank(): - # tensorrt_llm_llama.ln_f.weight.value = v.to(torch_dtype).cpu().numpy() - weights['transformer.ln_f.weight'] = v.to(torch_dtype) - - # 4. Weights inside each layer - layers_range = mapping.pp_layers(num_hidden_layers) - - for l in layers_range: - layer_idx = l - layers_range[0] - prefix = "layers" + split_sym + str(layer_idx) + split_sym - tllm_prex = f'transformer.layers.{l-layers_range[0]}' - - logger.info(f'Process weights in layer: {layer_idx}') - # layer = tensorrt_llm_llama.layers[layer_idx] - - # 4.1 attention.qkv - weights.update( - get_tllm_qkv_weight_from_awq(prefix + awq_key_list[3], - f'{tllm_prex}.attention.qkv')) - - # 4.2 attention.dense - v = [load(prefix + awq_key_list[5] + suf) for suf in awq_suffix_list] - # process_and_assign_weight(layer.attention.dense, v, 0) - weights.update( - get_tllm_weight_from_awq(v, - f'{tllm_prex}.attention.dense', - tp_dim=0)) - # 4.3 mlp.gate - v = [load(prefix + awq_key_list[6] + suf) for suf in awq_suffix_list] - - weights.update( - get_tllm_weight_from_awq(v, f'{tllm_prex}.mlp.gate', tp_dim=1)) - - # 4.4 mlp.proj - v = [load(prefix + awq_key_list[7] + suf) for suf in awq_suffix_list] - weights.update( - get_tllm_weight_from_awq(v, f'{tllm_prex}.mlp.proj', tp_dim=0)) - # 4.5 mlp.fc - v = [load(prefix + awq_key_list[8] + suf) for suf in awq_suffix_list] - weights.update( - get_tllm_weight_from_awq(v, f'{tllm_prex}.mlp.fc', tp_dim=1)) - # 4.6 input_layernorm - v = load(prefix + awq_key_list[9]) - # layer.input_layernorm.weight.value = v.to(torch_dtype).cpu().numpy() - - weights[f'{tllm_prex}.input_layernorm.weight'] = v.to(torch_dtype) - # 4.7 post_layernorm - v = load(prefix + awq_key_list[10]) - # layer.post_layernorm.weight.value = v.to(torch_dtype).cpu().numpy() - weights[f'{tllm_prex}.post_layernorm.weight'] = v.to(torch_dtype) - - # 4.8 attention.kv_quant_orig_scale / kv_quant_orig_scale - # if use_int8_kv_cache: - # assert bin_model_dir, "You must pass --bin_model_dir to tell TRT-LLM where to look for scales of INT8 kv cache." - # t = fromfile( - # bin_model_dir, 'model.layers.' 
+ str(layer_idx) + - # '.attention.query_key_value.scale_y_quant_orig.bin', [1], - # np.float32) - # assert t is not None, f"{bin_model_dir} does not contain model.layers.{layer_idx}.attention.query_key_value.scale_y_quant_orig.bin" - # layer.attention.kv_orig_quant_scale.value = 1.0 / t - # layer.attention.kv_quant_orig_scale.value = t - - return weights diff --git a/tensorrt_llm/models/mamba/model.py b/tensorrt_llm/models/mamba/model.py index dbc77f099..8c748f413 100644 --- a/tensorrt_llm/models/mamba/model.py +++ b/tensorrt_llm/models/mamba/model.py @@ -18,7 +18,8 @@ import tensorrt as trt from ..._utils import str_dtype_to_trt -from ...functional import Tensor, cast, gather_last_token_logits +from ...functional import (Tensor, arange, cast, concat, expand, + gather_last_token_logits, shape, unsqueeze) from ...layers import (Embedding, LayerNorm, Linear, Mamba, MambaParameters, RmsNorm) from ...module import Module, ModuleList @@ -49,7 +50,8 @@ def __init__(self, config: PretrainedConfig, last_layer=False): def forward(self, hidden_states: Tensor, residual: Tensor, conv_state: Tensor, ssm_state: Tensor, - host_request_types: Tensor): + host_request_types: Tensor, conv_indices: Tensor, + last_token_ids: Tensor): hidden_states = self.input_layernorm(hidden_states) @@ -57,7 +59,9 @@ def forward(self, hidden_states: Tensor, residual: Tensor, hidden_states, conv_state=conv_state, ssm_state=ssm_state, - host_request_types=host_request_types) + host_request_types=host_request_types, + conv_indices=conv_indices, + last_token_ids=last_token_ids) if self.residual_in_fp32: residual = residual + cast(ssm_out, 'float32') hidden_states = cast(residual, self.dtype) @@ -75,6 +79,8 @@ class MambaModel(Module): def __init__(self, config: PretrainedConfig): super().__init__() + self.d_conv = config.ssm_cfg['d_conv'] + self.d_inner = int(config.ssm_cfg['expand'] * config.hidden_size) n_layer = config.num_hidden_layers self.residual_in_fp32 = config.residual_in_fp32 if config.vocab_size % config.pad_vocab_size_multiple != 0: @@ -96,8 +102,21 @@ def __init__(self, config: PretrainedConfig): eps=config.norm_epsilon, dtype=config.dtype) - def forward(self, input_ids, conv_states, ssm_states, host_request_types): + def forward(self, input_ids, conv_states, ssm_states, host_request_types, + last_token_ids): hidden_states = self.vocab_embedding(input_ids) + + # Get conv state indices + batch_size = shape(input_ids, 0) + indices = expand( + unsqueeze(arange(0, self.d_conv - 1, dtype='int32'), 0), + concat([batch_size, self.d_conv - 1])) + offsets = expand(unsqueeze(last_token_ids, 1), + concat([batch_size, self.d_conv - 1])) + indices = unsqueeze(indices + offsets, 1) + indices = expand(indices, + concat([batch_size, self.d_inner, self.d_conv - 1])) + residual = cast(hidden_states, 'float32') if self.residual_in_fp32 else hidden_states hidden_values = [hidden_states, residual] @@ -105,7 +124,8 @@ def forward(self, input_ids, conv_states, ssm_states, host_request_types): for layer, past_conv, past_ssm in zip(self.layers, conv_states, ssm_states): hidden_values = layer(hidden_values[0], hidden_values[1], past_conv, - past_ssm, host_request_types) + past_ssm, host_request_types, indices, + last_token_ids) present_convs.append(hidden_values[2]) present_ssms.append(hidden_values[3]) hidden_states = hidden_values[0] @@ -130,6 +150,7 @@ def __init__(self, config: PretrainedConfig): self.d_conv = self.ssm_cfg.d_conv self.d_state = self.ssm_cfg.d_state self.config = config + self.gather_context_logits = False if 
isinstance(logits_dtype, str): self._logits_dtype = str_dtype_to_trt(logits_dtype) @@ -150,10 +171,12 @@ def __post_init__(self): def forward(self, input_ids, conv_states, ssm_states, host_request_types, last_token_ids): hidden_states, present_convs, present_ssms = self.backbone( - input_ids, conv_states, ssm_states, host_request_types) + input_ids, conv_states, ssm_states, host_request_types, + last_token_ids) - hidden_states = gather_last_token_logits(hidden_states, last_token_ids, - False) + if not self.gather_context_logits: + hidden_states = gather_last_token_logits(hidden_states, + last_token_ids, False) lm_logits = self.lm_head(hidden_states) lm_logits.mark_output('logits', self._logits_dtype) @@ -183,9 +206,7 @@ def prepare_inputs(self, @return: a list contains values which can be fed into the self.forward() ''' batch_range = [GenerationMixin.default_range(max_batch_size)] - conv_state_range = [ - GenerationMixin.default_range(self.d_conv - 1 + max_input_len) - ] + self.gather_context_logits = gather_context_logits input_ids = Tensor(name='input_ids', dtype=trt.int32, shape=[-1, -1], @@ -198,7 +219,7 @@ def prepare_inputs(self, conv_state_dim_range = OrderedDict([ ('batch_size', batch_range), ('dim_size', [self.d_inner]), - ('kernel_size', conv_state_range), + ('kernel_size', [self.d_conv - 1]), ]) ssm_state_dim_range = OrderedDict([ @@ -210,7 +231,7 @@ def prepare_inputs(self, for i in range(self.config.num_hidden_layers): conv_state = Tensor(name=f'past_conv_state_{i}', dtype=self.dtype, - shape=[-1, self.d_inner, -1], + shape=[-1, self.d_inner, self.d_conv - 1], dim_range=conv_state_dim_range) ssm_state = Tensor(name=f'past_ssm_state_{i}', @@ -228,16 +249,14 @@ def prepare_inputs(self, dim_range=OrderedDict([('batch_size', batch_range)]), ) - last_token_ids = None - if not gather_context_logits: - last_token_ids = Tensor( - name='last_token_ids', - dtype=trt.int32, - shape=[-1], - dim_range=OrderedDict([ - ('batch_size_last_token_ids', batch_range), - ]), - ) + last_token_ids = Tensor( + name='last_token_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('batch_size', batch_range), + ]), + ) return { 'input_ids': input_ids, diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 416045b98..012c1174c 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -2,7 +2,7 @@ import dataclasses import json import os -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np import safetensors @@ -12,16 +12,18 @@ from .._utils import (numpy_to_torch, str_dtype_to_torch, str_dtype_to_trt, trt_dtype_to_torch) from ..functional import PositionEmbeddingType, Tensor, gather_last_token_logits -from ..layers import (AttentionParams, FusedGatedMLP, GatedMLP, - KeyValueCacheParams, LoraParams) -from ..layers.attention import Attention -from ..layers.linear import ColumnLinear +from ..layers import (AttentionParams, Embedding, FusedGatedMLP, GatedMLP, + KeyValueCacheParams, LoraParams, PromptTuningEmbedding) +from ..layers.attention import Attention, BertAttention +from ..layers.linear import ColumnLinear, Linear, RowLinear +from ..layers.lora import Lora from ..logger import logger from ..mapping import Mapping from ..module import Module, ModuleList from ..quantization import QuantMode from ..quantization.layers import FP8Linear -from ..quantization.mode import W8A8_SQ_PLUGIN_LIST +from ..quantization.mode import (FP8, W4A8_AWQ, W4A16, W4A16_AWQ, + 
W8A8_SQ_PLUGIN_LIST, W8A16) from ..quantization.quantize import quantize from .generation_mixin import GenerationMixin @@ -44,6 +46,32 @@ class QuantizationConfig: def use_plugin_sq(self): return self.quant_algo in W8A8_SQ_PLUGIN_LIST + def quant_algo_to_ammo_qformat(self): + from ..quantization import mode as quant_algo + + #"fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo", + algo_to_ammo_map = { + quant_algo.W8A16: "int8_wo", + quant_algo.W4A16: "int4_wo", + quant_algo.W4A16_AWQ: "int4_awq", + quant_algo.W4A8_AWQ: 'w4a8_awq', + quant_algo.FP8: 'fp8', + quant_algo.W4A16_GPTQ: None, + quant_algo.W8A8_SQ_PER_CHANNEL: 'int8_sq', + quant_algo.W8A8_SQ_PER_TENSOR_PLUGIN: None, + quant_algo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN: None, + quant_algo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN: None, + quant_algo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN: None, + None: 'full_prec' + } + assert self.quant_algo in algo_to_ammo_map + qformat = algo_to_ammo_map[self.quant_algo] + assert qformat is not None, "None means we don't use AMMO for this kind of quantization algorithm, you probably shall not call this" + return qformat + + def asdict(self): + return dataclasses.asdict(self) + def default_weight_loader(mapping: Mapping, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: @@ -71,11 +99,9 @@ def __init__(self, tp_size: int, pp_size: int, quantization: Union[QuantizationConfig, dict], - use_prompt_tuning: bool = False, use_parallel_embedding: bool = False, embedding_sharding_dim: int = 0, share_embedding_table: bool = False, - max_lora_rank: int = 64, head_size: int = None, **kwargs): self.architecture = architecture @@ -94,7 +120,6 @@ def __init__(self, self.norm_epsilon = norm_epsilon self.position_embedding_type = PositionEmbeddingType.from_string( position_embedding_type) - self.use_prompt_tuning = use_prompt_tuning self.use_parallel_embedding = use_parallel_embedding self.embedding_sharding_dim = embedding_sharding_dim self.share_embedding_table = share_embedding_table @@ -114,7 +139,6 @@ def __init__(self, ), f"Expecting type of QuantizationConfig, found {type(quantization)}" self.quantization = quantization self.kv_dtype = self.dtype - self.max_lora_rank = max_lora_rank if self.quant_mode.has_int8_kv_cache(): self.kv_dtype = 'int8' elif self.quant_mode.has_fp8_kv_cache(): @@ -150,7 +174,6 @@ def from_dict(cls, config): num_attention_heads) intermediate_size = config.pop('intermediate_size', None) max_position_embeddings = config.pop('max_position_embeddings', None) - use_prompt_tuning = config.pop('use_prompt_tuning', False) use_parallel_embedding = config.pop('use_parallel_embedding', False) embedding_sharding_dim = config.pop('embedding_sharding_dim', 0) share_embedding_table = config.pop('share_embedding_table', False) @@ -185,16 +208,13 @@ def from_dict(cls, config): assert isinstance(quant_config_from_user, QuantizationConfig) quant_config = quant_config_from_user - max_lora_rank = config.pop('max_lora_rank', 64) - return cls(architecture, dtype, logits_dtype, vocab_size, max_position_embeddings, hidden_size, num_hidden_layers, num_attention_heads, num_key_value_heads, hidden_act, intermediate_size, norm_epsilon, position_embedding_type, world_size, tp_size, pp_size, quant_config, - use_prompt_tuning, use_parallel_embedding, - embedding_sharding_dim, share_embedding_table, max_lora_rank, - **config) + use_parallel_embedding, embedding_sharding_dim, + share_embedding_table, **config) @classmethod def from_json_file(cls, config_file: str): @@ -332,7 +352,7 @@ def 
from_checkpoint(cls, device='cpu') as f: for key in f.keys(): weights[key] = f.get_tensor(key) - + preprocess_weights(weights, config) model.load(weights) return model @@ -413,6 +433,7 @@ def prepare_inputs(self, use_custom_all_reduce = default_net( ).plugin_config.use_custom_all_reduce use_lora_plugin = default_net().plugin_config.lora_plugin + multiple_profiles = default_net().plugin_config.multiple_profiles model_inputs = self.prepare_basic_inputs( max_batch_size=max_batch_size, @@ -439,7 +460,8 @@ def prepare_inputs(self, use_custom_all_reduce=use_custom_all_reduce, use_lora_plugin=use_lora_plugin, max_draft_len=max_draft_len, - lora_target_modules=lora_target_modules) + lora_target_modules=lora_target_modules, + multiple_profiles=multiple_profiles) result = { 'input_ids': @@ -491,6 +513,44 @@ def prepare_inputs(self, return result + @classmethod + def quantize( + cls, + hf_model_dir, + output_dir, + quant_config: QuantizationConfig, + *, + dtype='float16', + mapping: Optional[Mapping] = None, + calib_batches=512, + calib_batch_size=1, + random_seed=1234, + tokenizer_max_seq_length=2048, + ): + if mapping is None: # single gpu + mapping = Mapping() + ammo_qformat = quant_config.quant_algo_to_ammo_qformat() + kv_cache_dtype = quant_config.kv_cache_quant_algo + assert ammo_qformat is not None + from ..quantization import quantize_and_export + hf_model_dir = str( + hf_model_dir) # quantize_and_export has some code can not take Path + quantize_and_export( + model_dir=hf_model_dir, + dtype=dtype, + device='cuda', + qformat=ammo_qformat, + kv_cache_dtype=kv_cache_dtype, + calib_size=calib_batches, + batch_size=calib_batch_size, + output_dir=output_dir, + tp_size=mapping.tp_size, + pp_size=mapping.pp_size, + seed=random_seed, + max_seq_length=tokenizer_max_seq_length, + awq_block_size=quant_config.group_size, + ) + class DecoderModelForCausalLM(PretrainedModel): @@ -584,8 +644,7 @@ def fuse_gate_mlp(model): dtype=layer.mlp.dtype, tp_group=layer.mlp.tp_group, tp_size=layer.mlp.tp_size, - quant_mode=layer.mlp.quant_mode, - max_lora_rank=layer.mlp.max_lora_rank) + quant_mode=layer.mlp.quant_mode) if quant_algo == 'FP8': if isinstance(layer.mlp.dtype, str): @@ -694,9 +753,196 @@ def unfuse_qkv_gemm(model): return model -def optimize_model(model, use_fused_mlp=False, use_unfused_qkv_gemm=False): +def set_prompt_tuning(model): + if isinstance(model.transformer.vocab_embedding, Embedding): + embedding = model.transformer.vocab_embedding + model.transformer.vocab_embedding = PromptTuningEmbedding( + num_embeddings=embedding.num_embeddings, + embedding_dim=embedding.embedding_dim, + dtype=embedding.dtype, + tp_size=embedding.tp_size, + tp_group=embedding.tp_group, + sharding_dim=embedding.sharding_dim, + tp_rank=embedding.tp_rank) + + model.transformer.vocab_embedding.weight.value = embedding.weight.raw_value + return model + + +def add_lora(model, max_lora_rank: Optional[int]): + for name, layer in model.named_modules(remove_duplicate=True): + max_rank = max_lora_rank + if isinstance(layer, (Attention, BertAttention)): + if max_rank is None: + max_rank = min( + layer.hidden_size, + layer.num_attention_heads * layer.attention_head_size, + layer.num_attention_kv_heads * layer.attention_head_size) + layer.qkv_lora = Lora( + in_hidden_size=layer.hidden_size, + out_hidden_sizes=[ + layer.num_attention_heads * layer.attention_head_size, + layer.num_attention_kv_heads * layer.attention_head_size, + layer.num_attention_kv_heads * layer.attention_head_size + ], + max_low_rank=max_rank, + ) + if 
isinstance(layer, (Linear, RowLinear)): + if max_rank is None: + max_rank = min(layer.in_features, layer.out_features) + layer.lora = Lora( + in_hidden_size=layer.in_features, + out_hidden_sizes=[layer.out_features], + max_low_rank=max_rank, + ) + if isinstance(layer, FusedGatedMLP): + if max_rank is None: + max_rank = min(layer.hidden_size, + layer.ffn_hidden_size // layer.tp_size) + layer.mlp_in_lora = Lora( + in_hidden_size=layer.hidden_size, + out_hidden_sizes=[ + layer.ffn_hidden_size // layer.tp_size, + layer.ffn_hidden_size // layer.tp_size + ], + max_low_rank=max_rank, + ) + return model + + +def optimize_model( + model, + use_fused_mlp=False, + use_unfused_qkv_gemm=False, + use_prompt_tuning=False, + use_lora=False, + max_lora_rank=None, +): if use_fused_mlp: model = fuse_gate_mlp(model) if use_unfused_qkv_gemm: model = unfuse_qkv_gemm(model) + if use_prompt_tuning: + model = set_prompt_tuning(model) + if use_lora: + model = add_lora(model, max_lora_rank) return model + + +def preprocess_weights( + weights: Dict[str, torch.Tensor], + model_config: PretrainedConfig) -> Dict[str, torch.Tensor]: + quant_algo = model_config.quantization.quant_algo + kv_cache_quant_algo = model_config.quantization.kv_cache_quant_algo + + # INT4_AWQ + if quant_algo == W4A8_AWQ or quant_algo == W4A16_AWQ: + preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm + for name, param in weights.items(): + if name.endswith('weight') and param.dtype == torch.int8: + dtype = torch.float16 + if model_config.dtype == "bfloat16": + dtype = torch.bfloat16 + weights[name] = preprocessor(param.T.contiguous(), + torch.quint4x2).view(dtype) + if name.endswith('weights_scaling_factor'): + weights[name] = param.T.contiguous().to( + str_dtype_to_torch(model_config.dtype)) + if name.endswith('prequant_scaling_factor'): + weights[name] = param.reshape(1, -1) + if model_config.mapping.tp_rank > 0: + if name.endswith('attention.dense.bias') or name.endswith( + 'mlp.proj.bias'): + weights[name] = torch.zeros_like(param) + + if quant_algo == W4A8_AWQ: + for name in list(weights): + if name.endswith('weights_scaling_factor'): + activation_scaling_factor = weights.pop( + name.replace('weights_scaling_factor', + 'activation_scaling_factor')) + weights_scaling_factor_2 = weights.pop( + name.replace('weights_scaling_factor', + 'weights_scaling_factor_2')) + weights[name] /= weights_scaling_factor_2 + weights[name.replace( + 'weights_scaling_factor', + 'prequant_scaling_factor')] /= activation_scaling_factor + weights[name.replace( + 'weights_scaling_factor', 'alpha' + )] = activation_scaling_factor * weights_scaling_factor_2 + + # FP8 + elif quant_algo == FP8: + for name, param in weights.items(): + if name.endswith('weight') and param.dtype == torch.int8: + weights[name] = param.view(torch.float8_e4m3fn) + # lm_head is not quantized to FP8 + if "lm_head.weight" in weights: + assert weights['lm_head.weight'].dtype == str_dtype_to_torch( + model_config.dtype) + weights.pop('lm_head.weights_scaling_factor', None) + weights.pop('lm_head.activation_scaling_factor', None) + + # Weight only 4bit + elif quant_algo == W4A16: + for name in list(weights): + if any([ + _name in name for _name in [ + 'qkv.weight', 'dense.weight', 'fc.weight', + 'proj.weight', 'gate.weight' + ] + ]) and weights[name].dtype != torch.int8: + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + weights[name].t().contiguous(), torch.quint4x2) + weights[name] = processed_torch_weights + 
weights[name.replace( + '.weight', '.per_channel_scale')] = torch_weight_scales + + # Weight only 8bit + elif quant_algo == W8A16: + for name in list(weights): + if any([ + _name in name for _name in [ + 'qkv.weight', 'dense.weight', 'fc.weight', + 'proj.weight', 'gate.weight' + ] + ]) and weights[name].dtype != torch.int8: + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + weights[name].t().contiguous(), torch.int8) + weights[name] = processed_torch_weights + weights[name.replace( + '.weight', '.per_channel_scale')] = torch_weight_scales + + # FP8 kv_cache_scaling_factor is always 1.0 + if kv_cache_quant_algo == FP8: + for name, param in weights.items(): + if name.endswith('kv_cache_scaling_factor'): + weights[name] = torch.tensor([1.0], dtype=torch.float32) + + # If layer_norm bias is None. (For MPT) + if model_config.architecture == 'MPTForCausalLM': + update_dict = {} + for name, param in weights.items(): + if 'input_layernorm.weight' in name and name.replace( + 'weight', 'bias') not in weights: + update_dict[name.replace('weight', + 'bias')] = torch.zeros_like(param) + if 'post_layernorm.weight' in name and name.replace( + 'weight', 'bias') not in weights: + update_dict[name.replace('weight', + 'bias')] = torch.zeros_like(param) + if 'ln_f.weight' in name and name.replace('weight', + 'bias') not in weights: + update_dict[name.replace('weight', + 'bias')] = torch.zeros_like(param) + weights.update(update_dict) + + # Parallel block rowlinear should not have duplicate bias. + elif model_config.architecture == 'GPTJForCausalLM': + if model_config.mapping.tp_rank > 0: + for name, param in weights.items(): + if 'attention.dense.bias' in name or 'mlp.proj.bias' in name: + weights[name] = torch.zeros_like(param) diff --git a/tensorrt_llm/models/opt/model.py b/tensorrt_llm/models/opt/model.py index 95e315446..7d2199818 100644 --- a/tensorrt_llm/models/opt/model.py +++ b/tensorrt_llm/models/opt/model.py @@ -16,7 +16,7 @@ from ..._utils import pad_vocab_size from ...functional import Tensor from ...layers import (MLP, Attention, AttentionMaskType, ColumnLinear, - Embedding, LayerNorm, PromptTuningEmbedding) + Embedding, LayerNorm) from ...module import Module from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, PretrainedConfig) @@ -109,11 +109,9 @@ class OPTModel(Module): def __init__(self, config: PretrainedConfig): super().__init__() self.do_layer_norm_before = config.do_layer_norm_before - self.use_prompt_tuning = config.use_prompt_tuning mapping = config.mapping - EmbeddingCls = PromptTuningEmbedding if config.use_prompt_tuning else Embedding - self.vocab_embedding = EmbeddingCls( + self.vocab_embedding = Embedding( config.vocab_size, config.hidden_size, dtype=config.dtype, @@ -144,7 +142,7 @@ def forward(self, prompt_vocab_size=None): args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size - ] if self.use_prompt_tuning else [] + ] if prompt_embedding_table is not None else [] hidden_states = self.vocab_embedding(input_ids, *args) hidden_states = hidden_states + self.position_embedding(position_ids) diff --git a/tensorrt_llm/models/phi/model.py b/tensorrt_llm/models/phi/model.py index 32901bedf..1212de593 100644 --- a/tensorrt_llm/models/phi/model.py +++ b/tensorrt_llm/models/phi/model.py @@ -96,7 +96,6 @@ def __init__(self, config: PretrainedConfig): mapping = config.mapping use_parallel_embedding = False embedding_sharding_dim = 0 - self.use_prompt_tuning = config.use_prompt_tuning 
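The modeling_utils changes above add a `quantize()` classmethod on `PretrainedModel` (a thin wrapper over AMMO's `quantize_and_export`), a `quant_algo_to_ammo_qformat()` helper on `QuantizationConfig`, and new `use_prompt_tuning` / `use_lora` switches on `optimize_model()`; the OPT/Phi hunks likewise make prompt tuning depend on whether a prompt table is actually passed at runtime rather than on a config flag. A minimal sketch of how these entry points might be combined, assuming `LLaMAForCausalLM` as a concrete subclass, default values for the remaining `QuantizationConfig` fields, and placeholder paths; this is not part of the patch:

```
# Sketch only, not part of the patch. Import paths, the LLaMAForCausalLM
# subclass, and placeholder paths are assumptions; the signatures follow the
# definitions added in modeling_utils.py above.
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.models.modeling_utils import QuantizationConfig, optimize_model
from tensorrt_llm.quantization.mode import FP8

quant_config = QuantizationConfig(quant_algo=FP8)
# quant_algo_to_ammo_qformat() maps a TRT-LLM algorithm to an AMMO qformat
# string, e.g. FP8 -> 'fp8', W4A16_AWQ -> 'int4_awq', None -> 'full_prec'.
assert quant_config.quant_algo_to_ammo_qformat() == "fp8"

# Calibrate with AMMO and export a TRT-LLM checkpoint (single GPU when the
# mapping argument is omitted).
LLaMAForCausalLM.quantize(
    "/path/to/hf_model",        # placeholder HF checkpoint directory
    "/path/to/trtllm_ckpt",     # placeholder output directory
    quant_config,
    dtype="float16",
    calib_batches=512,
)

# from_checkpoint() is assumed to take the checkpoint directory as its first
# argument. Prompt tuning and LoRA are now enabled through optimize_model()
# instead of PretrainedConfig flags.
model = LLaMAForCausalLM.from_checkpoint("/path/to/trtllm_ckpt")
model = optimize_model(model,
                       use_fused_mlp=True,
                       use_prompt_tuning=True,
                       use_lora=True,
                       max_lora_rank=8)   # None lets each layer pick a default max rank
```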
self.vocab_embedding = Embedding( num_embeddings=config.vocab_size, @@ -124,7 +123,7 @@ def forward( prompt_vocab_size=None, ): args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size - ] if self.use_prompt_tuning else [] + ] if prompt_embedding_table is not None else [] hidden_states = self.vocab_embedding(input_ids, *args) hidden_states = self.layers( diff --git a/tensorrt_llm/models/quantized/quant.py b/tensorrt_llm/models/quantized/quant.py index 028341384..85a3d6e0d 100644 --- a/tensorrt_llm/models/quantized/quant.py +++ b/tensorrt_llm/models/quantized/quant.py @@ -357,13 +357,13 @@ def _smooth_quantize_chatglm(model, quant_mode): def _smooth_quantize(model, quant_mode): from ...models import (BaichuanForCausalLM, BloomForCausalLM, - ChatGLMForCausalLM, GPTLMHeadModel, LLaMAForCausalLM, + ChatGLMForCausalLM, GPTForCausalLM, LLaMAForCausalLM, QWenForCausalLM) - assert isinstance(model, GPTLMHeadModel) or isinstance(model, LLaMAForCausalLM) \ + assert isinstance(model, GPTForCausalLM) or isinstance(model, LLaMAForCausalLM) \ or isinstance(model, BloomForCausalLM) or isinstance(model, BaichuanForCausalLM) \ or isinstance(model, QWenForCausalLM) or isinstance(model, ChatGLMForCausalLM), \ - "Only GPTLMHeadModel, LLaMAForCausalLM BloomForCausalLM and BaichuanForCausalLM are well tested now" - if isinstance(model, GPTLMHeadModel): + "Only GPTForCausalLM, LLaMAForCausalLM BloomForCausalLM and BaichuanForCausalLM are well tested now" + if isinstance(model, GPTForCausalLM): return _smooth_quantize_gpt(model, quant_mode) elif isinstance(model, LLaMAForCausalLM): return _smooth_quantize_llama(model, quant_mode) @@ -536,8 +536,8 @@ def _default_fp8_quantize(model, def _fp8_quantize(model, quant_mode: QuantMode, quant_scales: dict = None): from ...models import (BaichuanForCausalLM, FalconForCausalLM, - GPTJForCausalLM, GPTLMHeadModel, LLaMAForCausalLM) - if isinstance(model, (FalconForCausalLM, GPTJForCausalLM, GPTLMHeadModel, + GPTForCausalLM, GPTJForCausalLM, LLaMAForCausalLM) + if isinstance(model, (FalconForCausalLM, GPTJForCausalLM, GPTForCausalLM, LLaMAForCausalLM, BaichuanForCausalLM)): return _default_fp8_quantize(model, quant_mode, quant_scales) raise NotImplementedError( diff --git a/tensorrt_llm/models/qwen/model.py b/tensorrt_llm/models/qwen/model.py index a1ef24790..4178381a2 100644 --- a/tensorrt_llm/models/qwen/model.py +++ b/tensorrt_llm/models/qwen/model.py @@ -15,161 +15,73 @@ from typing import Optional -import tensorrt as trt - -from ..._common import default_net -from ..._utils import pad_vocab_size, str_dtype_to_trt -from ...functional import (Tensor, gather_last_token_logits, partial, recv, - send, unary) -from ...layers import (Attention, AttentionMaskType, AttentionParams, - ColumnLinear, Embedding, FusedGatedMLP, GatedMLP, - KeyValueCacheParams, PositionEmbeddingType, - PromptTuningEmbedding, RmsNorm) -from ...mapping import Mapping -from ...module import Module, ModuleList -from ...quantization import QuantMode -from ..generation_mixin import GenerationMixin - -log = partial(unary, op=trt.UnaryOperation.LOG) -ceil = partial(unary, op=trt.UnaryOperation.CEIL) - - -class GPTEmbedding(Module): - - def __init__(self, - vocab_size, - hidden_size, - max_position_embeddings, - position_embedding_type=PositionEmbeddingType.learned_absolute, - dtype=None, - use_prompt_tuning=False, - tensor_parallel=1, - tensor_parallel_group=None, - sharding_dim=0, - tp_rank=None): - super().__init__() - self.max_position_embeddings = max_position_embeddings - 
self.position_embedding_type = position_embedding_type - self.use_prompt_tuning = use_prompt_tuning - - EmbeddingCls = PromptTuningEmbedding if use_prompt_tuning else Embedding - self.vocab_embedding = EmbeddingCls(vocab_size, - hidden_size, - dtype=dtype, - tp_size=tensor_parallel, - tp_group=tensor_parallel_group, - sharding_dim=sharding_dim, - tp_rank=tp_rank) - - if self.position_embedding_type == PositionEmbeddingType.learned_absolute: - self.position_embedding = Embedding(max_position_embeddings, - hidden_size, - dtype=dtype) +from ..._utils import pad_vocab_size +from ...functional import Tensor, recv, send +from ...layers import (Attention, AttentionMaskType, ColumnLinear, Embedding, + GatedMLP, RmsNorm) +from ...module import Module +from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, + PretrainedConfig) - def forward(self, - input_ids, - position_ids, - prompt_embedding_table=None, - prompt_tasks=None, - prompt_vocab_size=None): - args = [] - if self.use_prompt_tuning: - args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size] - x = self.vocab_embedding(input_ids, *args) - if self.position_embedding_type == PositionEmbeddingType.learned_absolute: - x = x + self.position_embedding(position_ids) - - return x - - -class QWenBlock(Module): - - def __init__(self, - local_layer_idx, - hidden_size, - seq_length, - num_attention_heads, - max_position_embeddings, - num_layers, - dtype=None, - attention_mask_type=AttentionMaskType.causal, - apply_query_key_layer_scaling=False, - hidden_act='silu', - position_embedding_type=PositionEmbeddingType.rope_gpt_neox, - rotary_base=10000.0, - rotary_scaling=None, - quant_mode=QuantMode(0), - mlp_hidden_size=None, - bias=False, - tp_group=None, - tp_size=1, - tp_rank=0, - rms_norm_eps=1e-06, - use_fused_mlp=False): + +class QWenDecoderLayer(Module): + + def __init__(self, config: PretrainedConfig, layer_idx: int): super().__init__() - self.layer_idx = local_layer_idx - self.hidden_size = hidden_size - self.seq_length = seq_length - self.mlp_hidden_size = mlp_hidden_size - self.bias = bias - self.hidden_act = hidden_act - self.dtype = dtype - self.attention_mask_type = attention_mask_type - self.apply_query_key_layer_scaling = apply_query_key_layer_scaling - self.tp_group = tp_group - self.tp_size = tp_size - self.num_attention_heads = num_attention_heads - self.max_position_embeddings = max_position_embeddings - self.num_layers = num_layers - self.position_embedding_type = position_embedding_type - - self.ln_1 = RmsNorm(normalized_shape=hidden_size, - eps=rms_norm_eps, - dtype=dtype) + self.layer_idx = layer_idx + self.config = config + + dtype = config.dtype + tp_group = config.mapping.tp_group + tp_size = config.mapping.tp_size + self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=dtype) + + layers_range = config.mapping.pp_layers(config.num_hidden_layers) + local_layer_idx = layer_idx - layers_range[0] self.attention = Attention( local_layer_idx=local_layer_idx, - hidden_size=self.hidden_size, - num_attention_heads=self.num_attention_heads, - max_position_embeddings=self.max_position_embeddings, - num_layers=self.num_layers, - dtype=self.dtype, - attention_mask_type=self.attention_mask_type, - position_embedding_type=self.position_embedding_type, - rotary_embedding_base=rotary_base, - rotary_embedding_scaling=rotary_scaling, - tp_group=self.tp_group, - tp_size=self.tp_size, - quant_mode=quant_mode, - dense_bias=bias) - if not mlp_hidden_size: - mlp_hidden_size = 
hidden_size * 4 - - ClsMLP = FusedGatedMLP if use_fused_mlp else GatedMLP - - self.mlp = ClsMLP(hidden_size=hidden_size, - ffn_hidden_size=mlp_hidden_size // 2, - hidden_act=hidden_act, - dtype=dtype, - bias=False, - tp_group=tp_group, - tp_size=tp_size, - quant_mode=quant_mode) - self.ln_2 = RmsNorm(normalized_shape=hidden_size, - eps=rms_norm_eps, - dtype=dtype) + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + max_position_embeddings=config.max_position_embeddings, + dtype=dtype, + attention_mask_type=AttentionMaskType.causal, + position_embedding_type=config.position_embedding_type, + rotary_embedding_base=config.rotary_base, + rotary_embedding_scaling=config.rotary_scaling, + tp_group=tp_group, + tp_size=tp_size, + quant_mode=config.quant_mode, + dense_bias=False) + + self.mlp = GatedMLP(hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size // 2, + hidden_act=config.hidden_act, + dtype=dtype, + bias=False, + tp_group=tp_group, + tp_size=tp_size, + quant_mode=config.quant_mode) + self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=dtype) def forward( self, hidden_states: Tensor, + attention_mask=None, use_cache=False, kv_cache_params=None, attention_params=None, ): residual = hidden_states - hidden_states = self.ln_1(hidden_states) + hidden_states = self.input_layernorm(hidden_states) attention_output = self.attention( hidden_states, + attention_mask=attention_mask, use_cache=use_cache, kv_cache_params=kv_cache_params, attention_params=attention_params, @@ -181,7 +93,7 @@ def forward( residual = hidden_states - hidden_states = self.ln_2(hidden_states) + hidden_states = self.post_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) @@ -193,121 +105,62 @@ def forward( class QWenModel(Module): - def __init__(self, - num_layers, - num_heads, - hidden_size, - seq_length, - vocab_size, - hidden_act, - max_position_embeddings, - dtype, - mlp_hidden_size=None, - position_embedding_type=PositionEmbeddingType.rope_gpt_neox, - bias=False, - rotary_base=10000.0, - rotary_scaling=None, - mapping=Mapping(), - quant_mode=QuantMode(0), - use_parallel_embedding=False, - embedding_sharding_dim=0, - rms_norm_eps=1e-06, - use_prompt_tuning=False, - use_fused_mlp=False): + def __init__(self, config: PretrainedConfig): super().__init__() - self.mapping = mapping + self.mapping = config.mapping if self.mapping.is_first_pp_rank(): - self.embedding = GPTEmbedding( - vocab_size, - hidden_size, - max_position_embeddings, - position_embedding_type=PositionEmbeddingType.relative, - dtype=dtype, - use_prompt_tuning=use_prompt_tuning, - tensor_parallel=mapping.tp_size - if use_parallel_embedding else 1, - tensor_parallel_group=mapping.tp_group - if use_parallel_embedding else None, - sharding_dim=embedding_sharding_dim, - tp_rank=mapping.tp_rank) - - layers_range = mapping.pp_layers(self.num_layers) - self.layers = ModuleList([ - QWenBlock(local_layer_idx=layer_idx - layers_range[0], - hidden_size=hidden_size, - seq_length=seq_length, - num_attention_heads=num_heads, - num_layers=num_layers, - max_position_embeddings=max_position_embeddings, - dtype=dtype, - hidden_act=hidden_act, - quant_mode=quant_mode, - mlp_hidden_size=mlp_hidden_size, - position_embedding_type=position_embedding_type, - rotary_base=rotary_base, - rotary_scaling=rotary_scaling, - bias=bias, - tp_group=mapping.tp_group, - tp_size=mapping.tp_size, - tp_rank=mapping.tp_rank, - 
rms_norm_eps=rms_norm_eps, - use_fused_mlp=use_fused_mlp) for layer_idx in layers_range - ]) + self.vocab_embedding = Embedding( + config.vocab_size, + config.hidden_size, + dtype=config.dtype, + tp_size=self.mapping.tp_size + if config.use_parallel_embedding else 1, + tp_group=self.mapping.tp_group + if config.use_parallel_embedding else None, + sharding_dim=config.embedding_sharding_dim, + tp_rank=self.mapping.tp_rank) + + self.layers = DecoderLayerList(QWenDecoderLayer, config) if self.mapping.is_last_pp_rank(): - self.ln_f = RmsNorm(normalized_shape=hidden_size, - eps=rms_norm_eps, - dtype=dtype) + self.ln_f = RmsNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) def forward(self, - input_ids, + input_ids: Tensor, position_ids=None, use_cache=False, + attention_mask=None, kv_cache_params=None, attention_params=None, hidden_states=None, - prompt_embedding_table=None, - prompt_tasks=None, - prompt_vocab_size=None): - - if kv_cache_params.past_key_value is None: - tuple([None] * len(self.layers)) + prompt_embedding_table: Optional[Tensor] = None, + prompt_tasks: Optional[Tensor] = None, + prompt_vocab_size: Optional[Tensor] = None): kv_cache_params.fill_none_tensor_list(len(self.layers)) if use_cache: presents = [] + ptuning_args = [ + prompt_embedding_table, prompt_tasks, prompt_vocab_size + ] if prompt_embedding_table is not None else [] + if self.mapping.is_first_pp_rank(): - hidden_states = self.embedding(input_ids, position_ids, - prompt_embedding_table, prompt_tasks, - prompt_vocab_size) + hidden_states = self.vocab_embedding(input_ids, *ptuning_args) else: hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) - self.register_network_output(f"embd", hidden_states) - - for layer, past in zip(self.layers, kv_cache_params.past_key_value): - hidden_states = layer( - hidden_states, - use_cache=use_cache, - kv_cache_params=KeyValueCacheParams( - past_key_value=[past], - host_past_key_value_lengths=kv_cache_params. - host_past_key_value_lengths, - host_max_attention_window_sizes=kv_cache_params. - host_max_attention_window_sizes, - host_sink_token_length=kv_cache_params. - host_sink_token_length, - kv_cache_block_pointers=kv_cache_params. - kv_cache_block_pointers, - host_kv_cache_block_pointers=kv_cache_params. 
- host_kv_cache_block_pointers, - cache_indirection=kv_cache_params.cache_indirection), - attention_params=attention_params) - - if use_cache: - presents.append(hidden_states[1]) - hidden_states = hidden_states[0] + + hidden_states = self.layers.forward(hidden_states, + use_cache=use_cache, + attention_mask=attention_mask, + kv_cache_params=kv_cache_params, + attention_params=attention_params) + + if use_cache: + hidden_states, presents = hidden_states if self.mapping.is_last_pp_rank(): hidden_states = self.ln_f(hidden_states) @@ -319,203 +172,28 @@ def forward(self, return hidden_states -class QWenForCausalLM(QWenModel, GenerationMixin): - - def __init__(self, - num_layers, - num_heads, - num_kv_heads, - hidden_size, - seq_length, - vocab_size, - hidden_act, - max_position_embeddings, - dtype, - logits_dtype="float32", - mlp_hidden_size=None, - position_embedding_type=PositionEmbeddingType.rope_gpt_neox, - rotary_base=10000.0, - rotary_scaling=None, - mapping=Mapping(), - quant_mode=QuantMode(0), - use_parallel_embedding=False, - embedding_sharding_dim=0, - rms_norm_eps=1e-06, - use_prompt_tuning=False, - use_fused_mlp=False): - self.mapping = mapping - if isinstance(dtype, str): - self.dtype = str_dtype_to_trt(dtype) - else: - assert isinstance(dtype, trt.DataType) - self.dtype = dtype - if isinstance(logits_dtype, str): - self.logits_dtype = str_dtype_to_trt(logits_dtype) - else: - assert isinstance(logits_dtype, trt.DataType) - self.logits_dtype = logits_dtype - self.num_layers = num_layers - self.num_heads = num_heads - if num_kv_heads is None or num_kv_heads <= 0: - num_kv_heads = num_heads - self.num_kv_heads = num_kv_heads - self.hidden_size = hidden_size - self.vocab_size = vocab_size - self.tp_size = mapping.tp_size - - self.kv_dtype = self.dtype - if quant_mode.has_int8_kv_cache(): - self.kv_dtype = str_dtype_to_trt('int8') - elif quant_mode.has_fp8_kv_cache(): - self.kv_dtype = str_dtype_to_trt('fp8') - self.quant_mode = quant_mode - self.use_parallel_embedding = use_parallel_embedding - self.embedding_sharding_dim = embedding_sharding_dim - self.use_fused_mlp = use_fused_mlp - - super().__init__(num_layers=num_layers, - num_heads=num_heads, - hidden_size=hidden_size, - seq_length=seq_length, - vocab_size=vocab_size, - hidden_act=hidden_act, - max_position_embeddings=max_position_embeddings, - dtype=dtype, - mlp_hidden_size=mlp_hidden_size, - position_embedding_type=position_embedding_type, - rotary_base=rotary_base, - rotary_scaling=rotary_scaling, - mapping=mapping, - quant_mode=quant_mode, - use_parallel_embedding=use_parallel_embedding, - embedding_sharding_dim=embedding_sharding_dim, - rms_norm_eps=rms_norm_eps, - use_prompt_tuning=use_prompt_tuning, - use_fused_mlp=use_fused_mlp) - vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) - if self.mapping.is_last_pp_rank(): - self.lm_head = ColumnLinear(hidden_size, - vocab_size_padded, - bias=False, - dtype=dtype, - tp_group=mapping.tp_group, - tp_size=mapping.tp_size, - gather_output=True) - - def forward(self, - input_ids, - position_ids=None, - use_cache=False, - last_token_ids=None, - kv_cache_params=None, - attention_params=None, - hidden_states=None, - prompt_embedding_table: Optional[Tensor] = None, - prompt_tasks: Optional[Tensor] = None, - prompt_vocab_size: Optional[Tensor] = None): - hidden_states = super().forward(input_ids, position_ids, use_cache, - kv_cache_params, attention_params, - hidden_states, prompt_embedding_table, - prompt_tasks, prompt_vocab_size) - if use_cache: - hidden_states, presents 
= hidden_states +class QWenForCausalLM(DecoderModelForCausalLM): - if self.mapping.is_last_pp_rank(): - hidden_states = gather_last_token_logits( - hidden_states, last_token_ids, - default_net().plugin_config.remove_input_padding) + def __init__(self, config: PretrainedConfig): + self.check_config(config) + transformer = QWenModel(config) + vocab_size_padded = pad_vocab_size(config.vocab_size, + config.mapping.tp_size) - # [batch_size, hidden_size] -> [batch_size, vocab_size] - lm_logits = self.lm_head(hidden_states) - lm_logits.mark_output('logits', self.logits_dtype) + if config.mapping.is_last_pp_rank(): + lm_head = ColumnLinear(config.hidden_size, + vocab_size_padded, + bias=False, + dtype=config.dtype, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + gather_output=True) else: - hidden_states.mark_output('hidden_states_output', self.dtype) - - if use_cache and default_net().plugin_config.paged_kv_cache == False: - for i, present in zip(self.mapping.pp_layers(self.num_layers), - presents): - present.mark_output(f'present_key_value_{i}', self.kv_dtype) - if self.mapping.is_last_pp_rank(): - return (lm_logits, presents) - return (hidden_states, presents) - else: - if self.mapping.is_last_pp_rank(): - return lm_logits - return hidden_states - - def prepare_inputs( - self, - max_batch_size, - max_input_len, - max_seq_len, - use_cache, - max_beam_width: int = 1, - max_num_tokens: int = None, - prompt_embedding_table_size=256, - gather_context_logits: bool = False, - gather_generation_logits: bool = False, - ): - '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the - ranges of the dimensions of when using TRT dynamic shapes. - - @return: a list contains values which can be fed into the self.forward() - ''' - - # Prepare inputs - head_size = self.hidden_size // self.num_heads - remove_input_padding = default_net().plugin_config.remove_input_padding - use_gpt_attention_plugin = default_net( - ).plugin_config.gpt_attention_plugin - use_gemm_plugin = default_net().plugin_config.gemm_plugin - paged_kv_cache = default_net().plugin_config.paged_kv_cache - tokens_per_block = default_net().plugin_config.tokens_per_block - use_custom_all_reduce = default_net( - ).plugin_config.use_custom_all_reduce - - model_inputs = self.prepare_basic_inputs( - max_batch_size=max_batch_size, - max_beam_width=max_beam_width, - max_input_len=max_input_len, - max_seq_len=max_seq_len, - num_kv_heads=self.num_kv_heads, - head_size=head_size, - num_layers=self.num_layers, - kv_dtype=self.kv_dtype, - remove_input_padding=remove_input_padding, - use_gpt_attention_plugin=use_gpt_attention_plugin, - use_gemm_plugin=use_gemm_plugin, - use_custom_all_reduce=use_custom_all_reduce, - paged_kv_cache=paged_kv_cache, - tokens_per_block=tokens_per_block, - dtype=self.dtype, - num_heads=self.num_heads, - mapping=self.mapping, - max_num_tokens=max_num_tokens, - prompt_embedding_table_size=prompt_embedding_table_size, - gather_context_logits=gather_context_logits, - gather_generation_logits=gather_generation_logits) - - return ( - model_inputs['input_ids'], model_inputs['position_ids'], True, - model_inputs['last_token_ids'], - KeyValueCacheParams( - past_key_value=model_inputs['past_key_value'], - host_past_key_value_lengths=model_inputs[ - 'host_past_key_value_lengths'], - host_max_attention_window_sizes=model_inputs[ - 'host_max_attention_window_sizes'], - host_sink_token_length=model_inputs['host_sink_token_length'], - 
kv_cache_block_pointers=model_inputs['kv_cache_block_pointers'], - host_kv_cache_block_pointers=model_inputs[ - 'host_kv_cache_block_pointers'], - cache_indirection=model_inputs['cache_indirection'], - ), - AttentionParams( - sequence_length=model_inputs['sequence_length'], - context_lengths=model_inputs['context_lengths'], - host_context_lengths=model_inputs['host_context_lengths'], - max_context_length=max_input_len, - host_request_types=model_inputs['host_request_types']), - model_inputs['hidden_states_input'], - model_inputs['prompt_embedding_table'], model_inputs['tasks'], - model_inputs['prompt_vocab_size']) + lm_head = None + self.quant_mode = config.quant_mode + self.mapping = config.mapping + super().__init__(config, transformer, lm_head) + + def check_config(self, config): + config.set_if_not_exist('rotary_base', 10000.0) + config.set_if_not_exist('rotary_scaling', None) diff --git a/tensorrt_llm/plugin/plugin.py b/tensorrt_llm/plugin/plugin.py index ba6c6ae6f..f856a4223 100644 --- a/tensorrt_llm/plugin/plugin.py +++ b/tensorrt_llm/plugin/plugin.py @@ -95,6 +95,7 @@ class PluginConfig: use_context_fmha_for_generation: bool = False dense_context_fmha: bool = False pos_shift: bool = False + multiple_profiles: bool = False def set_plugin(self, name: str, value: Union[str, bool, int]): assert hasattr(self, name), f"Plugin name doesn't exist: {name}" @@ -305,7 +306,8 @@ def enable_pos_shift(self): "use_paged_context_fmha", "use_context_fmha_for_generation", "dense_context_fmha", - "pos_shift" + "pos_shift", + "multiple_profiles", ] plugin_options = ["float16", "float32", "bfloat16", "disable"] @@ -370,16 +372,14 @@ def gen_id(self) -> int: def set_workspace_tensor(self, mapping: Mapping, - two_opt_profiles: Optional[bool] = None): + num_profiles: Optional[int] = None): from ..functional import Tensor workspace_size = self.POINTERS_PER_RANK * mapping.tp_size dim_range = None - if two_opt_profiles is not None: - dim_range = OrderedDict([ - ('all_reduce_size', [workspace_size, workspace_size] - if two_opt_profiles else [workspace_size]) - ]) + if num_profiles is not None: + dim_range = OrderedDict([('all_reduce_size', + [workspace_size] * num_profiles)]) self.workspace = Tensor( name='all_reduce_workspace', diff --git a/tensorrt_llm/quantization/layers.py b/tensorrt_llm/quantization/layers.py index f7514bc35..a4bb31262 100644 --- a/tensorrt_llm/quantization/layers.py +++ b/tensorrt_llm/quantization/layers.py @@ -93,8 +93,7 @@ def __init__(self, tp_group=None, tp_size=1, gather_output=True, - quant_mode=QuantMode(0), - max_lora_rank=None): + quant_mode=QuantMode(0)): super().__init__() self.in_features = in_features self.out_features = out_features // tp_size @@ -167,7 +166,6 @@ def __init__( tp_group=None, tp_size=1, quant_mode=QuantMode(0), - max_lora_rank=None, ): super().__init__() self.in_features = in_features // tp_size @@ -230,7 +228,6 @@ def __init__( elementwise_affine=True, dtype=None, quant_mode=QuantMode(0), - max_lora_rank=None, ): super().__init__() if isinstance(normalized_shape, int): @@ -279,7 +276,6 @@ def __init__( dtype=None, quant_mode=QuantMode(0), bias=False, - max_lora_rank=None, ): super().__init__() if isinstance(normalized_shape, int): @@ -333,7 +329,6 @@ def __init__( tp_size=1, gather_output=True, quant_mode=QuantMode.use_weight_only(), - max_lora_rank=None, ): super().__init__() if quant_mode.is_int8_weight_only(): @@ -399,7 +394,6 @@ def __init__( tp_group=None, tp_size=1, quant_mode=QuantMode.use_weight_only(), - max_lora_rank=None, ): super().__init__() 
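The plugin changes above register a new `multiple_profiles` switch in `PluginConfig` and make the custom all-reduce workspace shape follow the number of optimization profiles instead of a two-profile flag. A minimal sketch of flipping the switch, assuming `PluginConfig()` is constructible with its defaults, that `set_plugin()` assigns the value after validating the name, and that the import path mirrors `tensorrt_llm/plugin/plugin.py`; not part of the patch:

```
# Sketch only, not part of the patch. The import path and the default
# constructibility of PluginConfig are assumptions.
from tensorrt_llm.plugin.plugin import PluginConfig

plugin_config = PluginConfig()
plugin_config.set_plugin("multiple_profiles", True)  # name is checked by set_plugin()
assert plugin_config.multiple_profiles is True
```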
if quant_mode.is_int8_weight_only(): @@ -454,7 +448,6 @@ def __init__( tp_group=None, tp_size=1, gather_output=True, - max_lora_rank=None, use_w4a8_awq=False, ): @@ -539,7 +532,6 @@ def __init__( dtype=None, tp_group=None, tp_size=1, - max_lora_rank=None, use_w4a8_awq=False, ): super().__init__() @@ -616,7 +608,6 @@ def __init__( tp_group=None, tp_size=1, quant_mode=QuantMode(0), - max_lora_rank=None, ): super().__init__() if hidden_act not in ACT2FN: @@ -677,8 +668,7 @@ def __init__(self, bias=True, dtype=None, tp_group=None, - tp_size=1, - max_lora_rank=None): + tp_size=1): super().__init__(in_features, out_features, bias=bias, @@ -733,7 +723,6 @@ def __init__( tp_group=None, tp_size=1, gather_output=True, - max_lora_rank=None, ): super().__init__(in_features, out_features, @@ -789,7 +778,6 @@ def __init__( tp_group=None, tp_size=1, gather_output=True, - max_lora_rank=None, ): super().__init__(in_features, out_features, @@ -840,7 +828,6 @@ def __init__( dtype=None, tp_group=None, tp_size=1, - max_lora_rank=None, ): super().__init__(in_features, out_features, @@ -895,7 +882,6 @@ def __init__( tp_group=None, tp_size=1, quant_mode=QuantMode(0), - max_lora_rank=None, ): super().__init__(hidden_size, ffn_hidden_size, @@ -970,7 +956,6 @@ def __init__( scale_alibi_bias=False, paged_kv_cache=False, quant_mode=QuantMode(0), - max_lora_rank=None, ): super().__init__() self.layer_idx = layer_idx diff --git a/tensorrt_llm/quantization/quantize.py b/tensorrt_llm/quantization/quantize.py index 6a30ab559..c4210c982 100644 --- a/tensorrt_llm/quantization/quantize.py +++ b/tensorrt_llm/quantization/quantize.py @@ -53,7 +53,6 @@ def weight_only_quantize(model, current_key_name.pop(-1) setattr(model, 'quant_mode', quant_mode) - return model @@ -62,7 +61,7 @@ def weight_only_groupwise_quantize(model, quant_algo=W4A16_AWQ, group_size=128, pre_quant_scale=False, - zero=False, + has_zero_point=False, exclude_modules=None, current_key_name=None): assert quant_mode.is_weight_only() @@ -77,8 +76,9 @@ def weight_only_groupwise_quantize(model, if len(list(module.children())) > 0: weight_only_groupwise_quantize(module, quant_mode, quant_algo, - group_size, pre_quant_scale, zero, - exclude_modules, current_key_name) + group_size, pre_quant_scale, + has_zero_point, exclude_modules, + current_key_name) if isinstance(module, ColumnLinear) and name not in exclude_modules: if not any(key in '.'.join(current_key_name) @@ -88,7 +88,7 @@ def weight_only_groupwise_quantize(model, out_features=module.out_features * module.tp_size, group_size=group_size, pre_quant_scale=pre_quant_scale, - zero=zero, + zero=has_zero_point, bias=module.bias is not None, use_w4a8_awq=quant_algo == W4A8_AWQ, dtype=module.dtype, @@ -103,7 +103,7 @@ def weight_only_groupwise_quantize(model, out_features=module.out_features, group_size=group_size, pre_quant_scale=pre_quant_scale, - zero=zero, + zero=has_zero_point, bias=module.bias is not None, use_w4a8_awq=quant_algo == W4A8_AWQ, dtype=module.dtype, @@ -112,6 +112,7 @@ def weight_only_groupwise_quantize(model, current_key_name.pop(-1) + setattr(model, 'quant_mode', quant_mode) return model @@ -147,6 +148,7 @@ def smooth_quantize_ootb(model, current_key_name.pop(-1) + setattr(model, 'quant_mode', quant_mode) return model @@ -198,8 +200,9 @@ def smooth_quantize_plugin(model, quant_mode): elif isinstance(layer.mlp, MLP): mlp_norm_cls = SmoothQuantMLP + mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size layer.mlp = 
mlp_norm_cls(hidden_size=config.hidden_size, - ffn_hidden_size=config.intermediate_size, + ffn_hidden_size=mlp_hidden_size, hidden_act=config.hidden_act, dtype=config.dtype, tp_group=config.mapping.tp_group, @@ -222,6 +225,7 @@ def smooth_quantize_plugin(model, quant_mode): dtype=config.dtype, quant_mode=quant_mode) + setattr(model, 'quant_mode', quant_mode) return model @@ -236,20 +240,18 @@ def smooth_quantize(model, quant_mode, use_plugin=False): def quantize(model, quant_mode, **kwargs): if quant_mode.has_act_and_weight_quant(): use_plugin = kwargs.get('quant_algo', None) in W8A8_SQ_PLUGIN_LIST - smooth_quantize(model, quant_mode, use_plugin=use_plugin) + return smooth_quantize(model, quant_mode, use_plugin=use_plugin) elif quant_mode.is_weight_only(): if quant_mode.has_per_group_scaling(): quant_kwargs = { k: kwargs[k] for k in [ 'quant_algo', 'group_size', 'pre_quant_scale', - 'exclude_modules' - ] + 'has_zero_point', 'exclude_modules' + ] if k in kwargs } - # due to legacy reason, the weight_only_groupwise_quantize function take 'zero' as arg - # while the checkpoint uses 'has_zero_point' - quant_kwargs['zero'] = kwargs['has_zero_point'] - weight_only_groupwise_quantize(model, quant_mode, **quant_kwargs) + return weight_only_groupwise_quantize(model, quant_mode, + **quant_kwargs) else: - kwargs = {k: kwargs[k] for k in ['exclude_modules']} - weight_only_quantize(model, quant_mode, **kwargs) + kwargs = {k: kwargs[k] for k in ['exclude_modules'] if k in kwargs} + return weight_only_quantize(model, quant_mode, **kwargs) diff --git a/tensorrt_llm/quantization/quantize_by_ammo.py b/tensorrt_llm/quantization/quantize_by_ammo.py index a9d05102f..ff0b57409 100644 --- a/tensorrt_llm/quantization/quantize_by_ammo.py +++ b/tensorrt_llm/quantization/quantize_by_ammo.py @@ -25,10 +25,13 @@ import numpy as np import safetensors import torch +from ammo.torch.export.tensorrt_llm_utils import MODEL_NAME_TO_HF_ARCH_MAP from datasets import load_dataset from torch.utils.data import DataLoader from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +MODEL_NAME_TO_HF_ARCH_MAP.update({"gpt2": "GPTForCausalLM"}) + EMPTY_CFG = { "quant_cfg": { "*weight_quantizer": { @@ -129,7 +132,7 @@ def get_tokenizer(ckpt_path, max_seq_length, model_type=None): tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643) # can't set attribute 'pad_token' for "" - if tokenizer.pad_token != "": + if tokenizer.pad_token != "": # nosec B105 tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token @@ -223,6 +226,8 @@ def calibrate_loop(): """Adjusts weights and scaling factors based on selected algorithms.""" for idx, data in enumerate(calib_dataloader): print(f"Calibrating batch {idx}") + # model might be mapped to different device because the device_map is auto + data = data.to(model.device) model(data) print("Starting quantization...") @@ -311,47 +316,64 @@ def quantize_and_export(*, model_dir, dtype, device, qformat, kv_cache_dtype, export_path = output_dir start_time = time.time() - if qformat == "int4_awq" and model_type == "qwen": - torch.save(model.state_dict(), export_path) - else: - export_npz = (model_type not in [ - 'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan', 'gemma' - ]) - export_model_config(model, - model_type, - getattr(torch, dtype), - export_dir=export_path, - inference_tensor_parallel=tp_size, - inference_pipeline_parallel=pp_size, - export_tensorrt_llm_config=(not export_npz), - export_npz=export_npz) - - # 
Workaround for wo quantization - if qformat in ["int8_wo", "int4_wo", "full_prec"]: + export_npz = (model_type not in [ + 'gpt2', 'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan', + 'gemma', 'qwen' + ]) + export_model_config(model, + model_type, + getattr(torch, dtype), + export_dir=export_path, + inference_tensor_parallel=tp_size, + inference_pipeline_parallel=pp_size, + export_tensorrt_llm_config=(not export_npz), + export_npz=export_npz) + + # Workaround for wo quantization + if qformat in ["int8_wo", "int4_wo", "full_prec"]: + with open(f"{export_path}/config.json", "r") as f: + tensorrt_llm_config = json.load(f) + if qformat == "int8_wo": + tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16' + elif qformat == "int4_wo": + tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16' + else: + tensorrt_llm_config["quantization"]["quant_algo"] = None + with open(f"{export_path}/config.json", "w") as f: + json.dump(tensorrt_llm_config, f, indent=4) + + # Workaround for share_embedding_table + if pp_size == 1: + with safetensors.safe_open(f"{export_path}/rank0.safetensors", + framework='pt', + device='cpu') as f: + share_embedding_table = 'lm_head.weight' not in f.keys() + if share_embedding_table: with open(f"{export_path}/config.json", "r") as f: tensorrt_llm_config = json.load(f) - if qformat == "int8_wo": - tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16' - elif qformat == "int4_wo": - tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16' - else: - tensorrt_llm_config["quantization"]["quant_algo"] = None + tensorrt_llm_config["share_embedding_table"] = True with open(f"{export_path}/config.json", "w") as f: json.dump(tensorrt_llm_config, f, indent=4) - # Workaround for share_embedding_table - if pp_size == 1: - with safetensors.safe_open(f"{export_path}/rank0.safetensors", - framework='pt', - device='cpu') as f: - share_embedding_table = 'lm_head.weight' not in f.keys() - if share_embedding_table: - with open(f"{export_path}/config.json", "r") as f: - tensorrt_llm_config = json.load(f) - tensorrt_llm_config["share_embedding_table"] = True - with open(f"{export_path}/config.json", "w") as f: - json.dump(tensorrt_llm_config, f, indent=4) - + # Workaround for gpt2 position embedding + if model_type == 'gpt2': + for rank in range(tp_size): + weights = {} + with safetensors.safe_open( + f"{export_path}/rank{rank}.safetensors", + framework='pt', + device='cpu') as f: + for key in f.keys(): + weights[key] = f.get_tensor(key) + if 'transformer.positional_embedding.weight' in weights: + weights[ + 'transformer.position_embedding.weight'] = weights.pop( + 'transformer.positional_embedding.weight') + safetensors.torch.save_file( + weights, f"{export_path}/rank{rank}.safetensors") + + torch.cuda.empty_cache( + ) # otherwise torch is keeping using GPU, other routine like build engine has less free GPU to use end_time = time.time() print( "Quantized model exported to {} \nTotal time used {:.2f} s.".format( diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index 3f44aca40..97996c902 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -318,7 +318,6 @@ class ModelConfig: lora_plugin: bool = False lora_target_modules: List[str] = field(default_factory=list) use_context_fmha_for_generation: bool = False - hf_modules_to_trtllm_modules: dict = None trtllm_modules_to_hf_modules: dict = None skip_cross_qkv: bool = False num_medusa_heads: int = 0 @@ -345,9 +344,9 @@ class SamplingConfig: temperature: 
Union[float, torch.Tensor] = field(default=1.0) top_k: Union[int, torch.Tensor] = field(default=1) top_p: Union[float, torch.Tensor] = field(default=0.0) - top_p_decay: Optional[float] = field(default=None) - top_p_min: Optional[float] = field(default=None) - top_p_reset_ids: Optional[int] = field(default=None) + top_p_decay: Optional[torch.Tensor] = field(default=None) # float + top_p_min: Optional[torch.Tensor] = field(default=None) # float + top_p_reset_ids: Optional[torch.Tensor] = field(default=None) # int length_penalty: Union[float, torch.Tensor] = field(default=1.0) early_stopping: Union[int, torch.Tensor] = field(default=1) @@ -691,7 +690,7 @@ def __init__(self, logger.error(f"Found tensor names: {found_tensor_names}") raise RuntimeError( "Tensor names in engine are not the same as expected, to use this GenerationSession, " - "you need to use GPTLMHeadModel.prepare_inputs to create TRT Network inputs." + "you need to use PretrainedModel.prepare_inputs to create TRT Network inputs." ) if self.debug_mode: self.debug_tensors = list( @@ -891,14 +890,28 @@ def __setup_decoder(self, input_ids: torch.Tensor, scfg.repetition_penalty, dtype=torch.float32) - self.host_length_penalty = torch.full([batch_size], - scfg.length_penalty, - dtype=torch.float32) + if isinstance(scfg.length_penalty, torch.Tensor): + assert scfg.length_penalty.dtype == torch.float32, f"scfg.length_penalty.dtype ({scfg.length_penalty.dtype}) must be torch.float32" + assert scfg.length_penalty.shape[ + 0] == batch_size, f"scfg.length_penalty.shape[0] ({scfg.length_penalty.shape[0]}) must equal to batch_size ({batch_size})" + self.host_length_penalty = scfg.length_penalty + else: + self.host_length_penalty = torch.full([batch_size], + scfg.length_penalty, + dtype=torch.float32) + self.length_penalty = self.host_length_penalty.to(self.device) - self.host_early_stopping = torch.full([batch_size], - scfg.early_stopping, - dtype=torch.int32) + if isinstance(scfg.early_stopping, torch.Tensor): + assert scfg.early_stopping.dtype == torch.int32, f"scfg.early_stopping.dtype ({scfg.early_stopping.dtype}) must be torch.int32" + assert scfg.early_stopping.shape[ + 0] == batch_size, f"scfg.early_stopping.shape[0] ({scfg.early_stopping.shape[0]}) must equal to batch_size ({batch_size})" + self.host_early_stopping = scfg.early_stopping + else: + self.host_early_stopping = torch.full([batch_size], + scfg.early_stopping, + dtype=torch.int32) + self.early_stopping = self.host_early_stopping.to(self.device) if isinstance(scfg.presence_penalty, torch.Tensor): @@ -1444,9 +1457,8 @@ def add_tensor_with_shape(x, name, shape): if self.skip_cross_qkv: if self.cross_qkv_reuse is None: # see Attention's self.qkv output dim - cross_qkv_out_dim = self.mapping.tp_size * self.num_heads * self.head_size + ( - 2 * self.mapping.tp_size * self.num_heads_kv * - self.head_size) + cross_qkv_out_dim = self.num_heads * self.head_size + ( + 2 * self.num_heads_kv * self.head_size) cross_qkv_shape = encoder_output.shape[:-1] + ( cross_qkv_out_dim, ) cross_qkv_reuse = torch.empty(cross_qkv_shape, @@ -3226,8 +3238,7 @@ def __init__( expected_tensor_names += ['input_ids'] expected_tensor_names += ['logits'] expected_tensor_names += ['host_request_types'] - if not model_config.gather_context_logits: - expected_tensor_names += ['last_token_ids'] + expected_tensor_names += ['last_token_ids'] expected_tensor_names += [ f'past_conv_state_{i}' @@ -3313,16 +3324,10 @@ def setup(self, dtype=self._tensor_dtype('logits'), device=self.device) - ctx_conv_state_shape = ( - 
batch_size, - self.mamba_d_inner, - self.mamba_d_conv - 1 + self.max_context_length, - ) - - gen_conv_state_shape = ( + conv_state_shape = ( batch_size, self.mamba_d_inner, - self.mamba_d_conv, + self.mamba_d_conv - 1, ) ssm_state_shape = ( @@ -3336,9 +3341,9 @@ def setup(self, # They will take turns to act as input and output buffers. dtype = self._tensor_dtype(f'present_conv_state_{i}') self.buffer[f'present_conv_state_{i}'] = torch.empty( - ctx_conv_state_shape, dtype=dtype, device=self.device) + conv_state_shape, dtype=dtype, device=self.device) self.buffer[f'1_present_conv_state_{i}'] = torch.empty( - gen_conv_state_shape, dtype=dtype, device=self.device) + conv_state_shape, dtype=dtype, device=self.device) self.buffer[f'present_ssm_state_{i}'] = torch.empty( ssm_state_shape, dtype=dtype, device=self.device) @@ -3372,15 +3377,14 @@ def add_tensor(x, name): add_tensor(input_ids, 'input_ids') add_tensor(self.buffer['logits'], 'logits') - if not self.gather_context_logits: - add_tensor(last_token_ids, 'last_token_ids') + add_tensor(last_token_ids, 'last_token_ids') batch_size = context_lengths.shape[0] + conv_state_shape = (batch_size, self.mamba_d_inner, + self.mamba_d_conv - 1) for idx in range(self.first_layer, self.last_layer): # conv state dtype = self._tensor_dtype(f'present_conv_state_{idx}') - conv_state_shape = (batch_size, self.mamba_d_inner, - self.mamba_d_conv - 1) conv_state = torch.zeros(conv_state_shape, dtype=dtype, device=self.device) @@ -3439,35 +3443,22 @@ def add_tensor_with_shape(x, name, shape): input_ids_shape = (batch_size * beam_width, 1) add_tensor_with_shape(self.new_tokens, 'input_ids', input_ids_shape) add_tensor(self.buffer['logits'], 'logits') - if not self.gather_context_logits: - add_tensor(last_token_ids, 'last_token_ids') + add_tensor(last_token_ids, 'last_token_ids') for idx in range(self.first_layer, self.last_layer): # conv state - if step == 0: - next_shape_in = (batch_size, self.mamba_d_inner, - self.mamba_d_conv - 1 + - self.max_context_length) - next_shape_out = (batch_size, self.mamba_d_inner, - self.mamba_d_conv) - else: - next_shape_in = (batch_size, self.mamba_d_inner, - self.mamba_d_conv) - next_shape_out = (batch_size, self.mamba_d_inner, - self.mamba_d_conv) if step % 2: - add_tensor_with_shape( - self.buffer[f'1_present_conv_state_{idx}'], - f'past_conv_state_{idx}', next_shape_in) - add_tensor_with_shape(self.buffer[f'present_conv_state_{idx}'], - f'present_conv_state_{idx}', - next_shape_out) + add_tensor(self.buffer[f'1_present_conv_state_{idx}'], + f'past_conv_state_{idx}') + add_tensor( + self.buffer[f'present_conv_state_{idx}'], + f'present_conv_state_{idx}', + ) else: - add_tensor_with_shape(self.buffer[f'present_conv_state_{idx}'], - f'past_conv_state_{idx}', next_shape_in) - add_tensor_with_shape( - self.buffer[f'1_present_conv_state_{idx}'], - f'present_conv_state_{idx}', next_shape_out) + add_tensor(self.buffer[f'present_conv_state_{idx}'], + f'past_conv_state_{idx}') + add_tensor(self.buffer[f'1_present_conv_state_{idx}'], + f'present_conv_state_{idx}') # ssm state ssm_state = self.buffer[f'present_ssm_state_{idx}'] add_tensor(ssm_state, f'past_ssm_state_{idx}') diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py index c9aaab796..6a258855d 100644 --- a/tensorrt_llm/runtime/model_runner.py +++ b/tensorrt_llm/runtime/model_runner.py @@ -122,8 +122,6 @@ def _builder_to_model_config(config: dict) -> Tuple[ModelConfig, dict]: 'max_prompt_embedding_table_size', 0) quant_mode = 
QuantMode(builder_config.get('quant_mode', 0)) lora_target_modules = builder_config.get('lora_target_modules') - lora_hf_modules_to_trtllm_modules = builder_config.get( - 'hf_modules_to_trtllm_modules') lora_trtllm_modules_to_hf_modules = builder_config.get( 'trtllm_modules_to_hf_modules') max_medusa_token_len = builder_config.get('max_draft_len', 0) @@ -165,7 +163,6 @@ def _builder_to_model_config(config: dict) -> Tuple[ModelConfig, dict]: lora_plugin=lora_plugin, lora_target_modules=lora_target_modules, use_context_fmha_for_generation=use_context_fmha_for_generation, - hf_modules_to_trtllm_modules=lora_hf_modules_to_trtllm_modules, trtllm_modules_to_hf_modules=lora_trtllm_modules_to_hf_modules, num_medusa_heads=num_medusa_heads, max_medusa_tokens=max_medusa_token_len, @@ -339,7 +336,7 @@ def __init__(self, @classmethod def from_engine(cls, engine: Engine, - lora_dir: Optional[str] = None, + lora_dir: Optional[List[str]] = None, rank: int = 0, debug_mode: bool = False, lora_ckpt_source: str = "hf", @@ -387,14 +384,9 @@ def from_engine(cls, mamba_expand=mamba_expand, mamba_d_conv=mamba_d_conv, lora_plugin=build_config.plugin_config.lora_plugin, - lora_target_modules=pretrained_config.lora_target_modules - if hasattr(pretrained_config, 'lora_target_modules') else [], - hf_modules_to_trtllm_modules=pretrained_config. - hf_modules_to_trtllm_modules if hasattr( - pretrained_config, 'hf_modules_to_trtllm_modules') else [], - trtllm_modules_to_hf_modules=pretrained_config. - trtllm_modules_to_hf_modules if hasattr( - pretrained_config, 'trtllm_modules_to_hf_modules') else [], + lora_target_modules=build_config.lora_config.lora_target_modules, + trtllm_modules_to_hf_modules=build_config.lora_config. + trtllm_modules_to_hf_modules, max_medusa_tokens=pretrained_config.max_draft_len if hasattr( pretrained_config, 'max_draft_len') else 0, num_medusa_heads=pretrained_config.num_medusa_heads if hasattr( @@ -450,7 +442,7 @@ def from_engine(cls, @classmethod def from_dir(cls, engine_dir: str, - lora_dir: Optional[str] = None, + lora_dir: Optional[List[str]] = None, rank: int = 0, debug_mode: bool = False, lora_ckpt_source: str = "hf", @@ -462,8 +454,8 @@ def from_dir(cls, Args: engine_dir (str): The directory that contains the serialized engine files and config files. - lora_dir (str): - The directory that contains LoRA weights. + lora_dir (Optional[List[str]]): + The directories that contain LoRA weights. rank (int): The runtime rank id. 
debug_mode (bool): @@ -545,6 +537,13 @@ def from_dir(cls, else: # the new engine format engine = Engine.from_dir(engine_dir, rank) + if lora_dir is None: + config_lora_dir = engine.config.build_config.lora_config.lora_dir + if len(config_lora_dir) > 0: + lora_dir = [ + f"{engine_dir}/{dir}" for dir in config_lora_dir + ] + lora_ckpt_source = engine.config.build_config.lora_config.lora_ckpt_source runner = ModelRunner.from_engine(engine, lora_dir, rank, debug_mode, lora_ckpt_source, medusa_choices, stream) diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index be072e352..3866d888e 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -286,7 +286,9 @@ def generate(self, sampling_config = copy.deepcopy(sampling_config) sampling_config.update(**kwargs) self._check_inputs(batch_input_ids, sampling_config) - gpt_sampling_config = _populate_sampling_config(sampling_config) + batch_size = len(batch_input_ids) + gpt_sampling_config = _populate_sampling_config(sampling_config, + batch_size) if lora_uids is not None: raise RuntimeError("LoRA is not supported in C++ session.") if streaming: @@ -298,7 +300,6 @@ def generate(self, raise RuntimeError( "Logits processor is not supported in C++ session.") - batch_size = len(batch_input_ids) batch_input_ids, input_lengths = self._prepare_inputs( batch_input_ids, sampling_config.pad_id) @@ -356,30 +357,129 @@ def generate(self, return outputs -def _populate_sampling_config( - sampling_config: SamplingConfig) -> GptSamplingConfig: +def _populate_sampling_config(sampling_config: SamplingConfig, + batch_size: int) -> GptSamplingConfig: gpt_sampling_config = GptSamplingConfig(sampling_config.num_beams) - gpt_sampling_config.beam_search_diversity_rate = [ - sampling_config.beam_search_diversity_rate - ] - gpt_sampling_config.length_penalty = [sampling_config.length_penalty] - gpt_sampling_config.early_stopping = [sampling_config.early_stopping] - gpt_sampling_config.min_length = [sampling_config.min_length] - # TODO: cannot set presence_penalty and frequency_penalty? 
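With the runner changes above, `lora_dir` becomes a list of LoRA directories, and when it is omitted `ModelRunner.from_dir` falls back to the `lora_dir` and `lora_ckpt_source` recorded in the engine's `build_config.lora_config`. A minimal usage sketch with placeholder paths, assuming `ModelRunner` is importable from `tensorrt_llm.runtime`; not part of the patch:

```
# Sketch only, not part of the patch. Paths are placeholders and the import
# location of ModelRunner is an assumption.
from tensorrt_llm.runtime import ModelRunner

# Explicit LoRA directories (one per adapter) ...
runner = ModelRunner.from_dir(
    engine_dir="/path/to/engine",
    lora_dir=["/path/to/lora_task_0", "/path/to/lora_task_1"],
    lora_ckpt_source="hf",
)

# ... or rely on the lora_config stored in the engine's build config by
# leaving lora_dir at its default of None.
runner = ModelRunner.from_dir(engine_dir="/path/to/engine")
```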
- # gpt_sampling_config.presence_penalty = [sampling_config.presence_penalty] - # gpt_sampling_config.frequency_penalty = [sampling_config.frequency_penalty] - if sampling_config.random_seed is not None: + + if isinstance(sampling_config.beam_search_diversity_rate, torch.Tensor): + assert sampling_config.beam_search_diversity_rate.dtype == torch.float32, f"sampling_config.beam_search_diversity_rate.dtype ({sampling_config.beam_search_diversity_rate.dtype}) must be torch.float32" + assert sampling_config.beam_search_diversity_rate.shape[ + 0] == batch_size, f"sampling_config.beam_search_diversity_rate.shape[0] ({sampling_config.beam_search_diversity_rate.shape[0]}) must equal to batch_size ({batch_size})" + gpt_sampling_config.beam_search_diversity_rate = sampling_config.beam_search_diversity_rate.tolist( + ) + elif sampling_config.beam_search_diversity_rate is not None: + gpt_sampling_config.beam_search_diversity_rate = [ + sampling_config.beam_search_diversity_rate + ] + else: + gpt_sampling_config.beam_search_diversity_rate = None + + if isinstance(sampling_config.length_penalty, torch.Tensor): + assert sampling_config.length_penalty.dtype == torch.float32, f"sampling_config.length_penalty.dtype ({sampling_config.length_penalty.dtype}) must be torch.float32" + assert sampling_config.length_penalty.shape[ + 0] == batch_size, f"sampling_config.length_penalty.shape[0] ({sampling_config.length_penalty.shape[0]}) must equal to batch_size ({batch_size})" + gpt_sampling_config.length_penalty = sampling_config.length_penalty.tolist( + ) + elif sampling_config.length_penalty == 1.0: + gpt_sampling_config.length_penalty = None + else: + gpt_sampling_config.length_penalty = [sampling_config.length_penalty] + + if isinstance(sampling_config.early_stopping, torch.Tensor): + assert sampling_config.early_stopping.dtype == torch.int32, f"sampling_config.early_stopping.dtype ({sampling_config.early_stopping.dtype}) must be torch.int32" + assert sampling_config.early_stopping.shape[ + 0] == batch_size, f"sampling_config.early_stopping.shape[0] ({sampling_config.early_stopping.shape[0]}) must equal to batch_size ({batch_size})" + gpt_sampling_config.early_stopping = sampling_config.early_stopping.tolist( + ) + else: + gpt_sampling_config.early_stopping = [sampling_config.early_stopping] + + if isinstance(sampling_config.min_length, torch.Tensor): + assert sampling_config.min_length.dtype == torch.int32, f"sampling_config.min_length.dtype ({sampling_config.min_length.dtype}) must be torch.int32" + assert sampling_config.min_length.shape[ + 0] == batch_size, f"sampling_config.min_length.shape[0] ({sampling_config.min_length.shape[0]}) must equal to batch_size ({batch_size})" + gpt_sampling_config.min_length = sampling_config.min_length.tolist() + else: + gpt_sampling_config.min_length = [sampling_config.min_length] + + if isinstance(sampling_config.presence_penalty, torch.Tensor): + assert sampling_config.presence_penalty.dtype == torch.float32, f"sampling_config.presence_penalty.dtype ({sampling_config.presence_penalty.dtype}) must be torch.float32" + assert sampling_config.presence_penalty.shape[ + 0] == batch_size, f"sampling_config.presence_penalty.shape[0] ({sampling_config.presence_penalty.shape[0]}) must equal to batch_size ({batch_size})" + gpt_sampling_config.presence_penalty = sampling_config.presence_penalty.tolist( + ) + elif sampling_config.presence_penalty == 0.0: + gpt_sampling_config.presence_penalty = None + else: + gpt_sampling_config.presence_penalty = [ + sampling_config.presence_penalty 
+ ] + + if isinstance(sampling_config.frequency_penalty, torch.Tensor): + assert sampling_config.frequency_penalty.dtype == torch.float32, f"sampling_config.frequency_penalty.dtype ({sampling_config.frequency_penalty.dtype}) must be torch.float32" + assert sampling_config.frequency_penalty.shape[ + 0] == batch_size, f"sampling_config.frequency_penalty.shape[0] ({sampling_config.frequency_penalty.shape[0]}) must equal to batch_size ({batch_size})" + gpt_sampling_config.frequency_penalty = sampling_config.frequency_penalty.tolist( + ) + elif sampling_config.frequency_penalty == 0.0: + gpt_sampling_config.frequency_penalty = None + else: + gpt_sampling_config.frequency_penalty = [ + sampling_config.frequency_penalty + ] + + if isinstance(sampling_config.random_seed, torch.Tensor): + assert sampling_config.random_seed.dtype == torch.int64, f"sampling_config.random_seed.dtype ({sampling_config.random_seed.dtype}) must be torch.int64" + assert sampling_config.random_seed.shape[ + 0] == batch_size, f"sampling_config.random_seed.shape[0] ({sampling_config.random_seed.shape[0]}) must equal to batch_size ({batch_size})" + gpt_sampling_config.random_seed = sampling_config.random_seed + elif sampling_config.random_seed is not None: gpt_sampling_config.random_seed = [sampling_config.random_seed] - gpt_sampling_config.repetition_penalty = [ - sampling_config.repetition_penalty - ] - gpt_sampling_config.temperature = [sampling_config.temperature] - gpt_sampling_config.top_k = [sampling_config.top_k] - gpt_sampling_config.top_p = [sampling_config.top_p] + else: + gpt_sampling_config.random_seed = None + + if isinstance(sampling_config.repetition_penalty, torch.Tensor): + assert sampling_config.repetition_penalty.dtype == torch.float32, f"sampling_config.repetition_penalty.dtype ({sampling_config.repetition_penalty.dtype}) must be torch.float32" + assert sampling_config.repetition_penalty.shape[ + 0] == batch_size, f"sampling_config.repetition_penalty.shape[0] ({sampling_config.repetition_penalty.shape[0]}) must equal to batch_size ({batch_size})" + gpt_sampling_config.repetition_penalty = sampling_config.repetition_penalty.tolist( + ) + elif sampling_config.repetition_penalty == 1.0: + gpt_sampling_config.repetition_penalty = None + else: + gpt_sampling_config.repetition_penalty = [ + sampling_config.repetition_penalty + ] + + if isinstance(sampling_config.temperature, torch.Tensor): + assert sampling_config.temperature.dtype == torch.float32, f"sampling_config.temperature.dtype ({sampling_config.temperature.dtype}) must be torch.float32" + assert sampling_config.temperature.shape[ + 0] == batch_size, f"sampling_config.temperature.shape[0] ({sampling_config.temperature.shape[0]}) must equal to batch_size ({batch_size})" + gpt_sampling_config.temperature = sampling_config.temperature.tolist() + else: + gpt_sampling_config.temperature = [sampling_config.temperature] + + if isinstance(sampling_config.top_k, torch.Tensor): + assert sampling_config.top_k.dtype == torch.int32, f"sampling_config.top_k.dtype ({sampling_config.top_k.dtype}) must be torch.int32" + assert sampling_config.top_k.shape[ + 0] == batch_size, f"sampling_config.top_k.shape[0] ({sampling_config.top_k.shape[0]}) must equal to batch_size ({batch_size})" + gpt_sampling_config.top_k = sampling_config.top_k.tolist() + else: + gpt_sampling_config.top_k = [sampling_config.top_k] + + if isinstance(sampling_config.top_p, torch.Tensor): + assert sampling_config.top_p.dtype == torch.float32, f"sampling_config.top_p.dtype 
({sampling_config.top_p.dtype}) must be torch.float32" + assert sampling_config.top_p.shape[ + 0] == batch_size, f"sampling_config.top_p.shape[0] ({sampling_config.top_p.shape[0]}) must equal to batch_size ({batch_size})" + gpt_sampling_config.top_p = sampling_config.top_p.tolist() + else: + gpt_sampling_config.top_p = [sampling_config.top_p] + if sampling_config.top_p_decay is not None: - gpt_sampling_config.top_p_decay = [sampling_config.top_p_decay] + gpt_sampling_config.top_p_decay = sampling_config.top_p_decay.tolist() if sampling_config.top_p_min is not None: - gpt_sampling_config.top_p_min = [sampling_config.top_p_min] + gpt_sampling_config.top_p_min = sampling_config.top_p_min.tolist() if sampling_config.top_p_reset_ids is not None: - gpt_sampling_config.top_p_reset_ids = [sampling_config.top_p_reset_ids] + gpt_sampling_config.top_p_reset_ids = sampling_config.top_p_reset_ids.tolist( + ) return gpt_sampling_config diff --git a/tensorrt_llm/top_model_mixin.py b/tensorrt_llm/top_model_mixin.py index 594c9ede9..935a06201 100644 --- a/tensorrt_llm/top_model_mixin.py +++ b/tensorrt_llm/top_model_mixin.py @@ -15,6 +15,7 @@ from typing import Optional +from .lora_manager import LoraBuildConfig from .mapping import Mapping from .plugin.plugin import PluginConfig from .quantization.mode import QuantMode @@ -49,24 +50,13 @@ def from_hugging_face(cls, ''' raise NotImplementedError("Subclass shall override this") - @classmethod - def from_faster_transformer(cls, ft_model_dir: str): + def use_lora(self, lora_config: LoraBuildConfig): ''' - create and object and load weights from FasterTransformer''' - raise NotImplementedError("Subclass shall override this") - - @classmethod - def from_checkpoint(cls, checkpoint_dir: str): - raise NotImplementedError("Will implement in the future release") - - def use_lora(self, lora_dir: str, lora_ckpt_source: str): - '''Load lora weights and config from the give dir to the module. lora_format should be one of 'hf' or 'nemo'. - lora_dir: the directory contains the lora weights + Load lora weights from the give config to the module + Parameters: + lora_config: the lora config ''' - # TODO: this is build time API, so pack the lora data together as engine - self.lora_dir = lora_dir - self.lora_ckpt_source = lora_ckpt_source - raise NotImplementedError # Fill more details later + raise NotImplementedError("Subclass shall override this") def use_prompt_tuning(self, max_prompt_embedding_table_size: str, prompt_table_path: str): @@ -79,16 +69,6 @@ def use_prompt_tuning(self, max_prompt_embedding_table_size: str, self.max_prompt_embedding_table_size = max_prompt_embedding_table_size raise NotImplementedError # Fill more details later - def use_streaming_llm(self, sink_token_length: int): - '''Enable Streaming-LLM feature - ''' - raise NotImplementedError - - def config_moe(self, moe_top_k: int, moe_tp_mode, moe_renorm_mode): - '''Configure the moe tuning parameters, the model must a MoE model, otherwise, this fails. - ''' - raise NotImplementedError - def default_plugin_config(self, **kwargs) -> 'PluginConfig': '''Return the default plugin config for this model, when the plugin_config value is not given in to_trt() call. If users need to set different plugin configs, they can start from the return object and change it. diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py index a9de4b3e0..d9f4c43c4 100644 --- a/tensorrt_llm/version.py +++ b/tensorrt_llm/version.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.9.0.dev2024031200" +__version__ = "0.9.0.dev2024031900" diff --git a/tests/attention/test_gpt_attention.py b/tests/attention/test_gpt_attention.py index 01e77b470..6d155fee6 100644 --- a/tests/attention/test_gpt_attention.py +++ b/tests/attention/test_gpt_attention.py @@ -620,6 +620,7 @@ def _construct_execution( builder_config = builder.create_builder_config( name=attention_type, precision=dtype, + opt_level=0, int8=int8_trt_flag, quant_mode=quant_mode) diff --git a/tests/attention/test_gpt_attention_IFB.py b/tests/attention/test_gpt_attention_IFB.py index 14da91688..3edc24379 100644 --- a/tests/attention/test_gpt_attention_IFB.py +++ b/tests/attention/test_gpt_attention_IFB.py @@ -57,7 +57,7 @@ def setUp(self): def _build_trt_engine(self, trt_network, trt_builder, dtype, shape_dict, use_int8): - config = trt_builder.create_builder_config() + config = trt_builder.create_builder_config(opt_level=0) if dtype == 'float16': config.flags = 1 << (int)(trt.BuilderFlag.FP16) @@ -380,6 +380,7 @@ def _construct_execution( int8_trt_flag = False builder_config = builder.create_builder_config(name=attention_type, precision=dtype, + opt_level=0, int8=int8_trt_flag) if session is None: engine = builder.build_engine(net, builder_config) diff --git a/tests/attention/test_gpt_attention_no_cache.py b/tests/attention/test_gpt_attention_no_cache.py index 16e14c50a..99104d451 100644 --- a/tests/attention/test_gpt_attention_no_cache.py +++ b/tests/attention/test_gpt_attention_no_cache.py @@ -79,8 +79,7 @@ def build_engine(qkv_shape, kv_dtype=kv_dtype, remove_input_padding=remove_input_padding, use_gpt_attention_plugin=True, - use_gemm_plugin= - True, # because we don't want two optimization profiles + enable_ctx_gen_opt_profiles=False, use_cache=use_cache, ) diff --git a/tests/bindings/test_bindings.py b/tests/bindings/test_bindings.py index 5ce9e5a1d..3b9a3a590 100644 --- a/tests/bindings/test_bindings.py +++ b/tests/bindings/test_bindings.py @@ -438,7 +438,7 @@ def test_llm_request(): assert llm_request.is_streaming assert llm_request.pad_id == 99 assert llm_request.end_id == 100 - assert llm_request.seq_slot == -1 # seq_slot is still uninitialized + assert llm_request.seq_slot == None assert torch.equal(llm_request.prompt_embedding_table, kwargs["prompt_embedding_table"]) assert llm_request.prompt_vocab_size == 2 @@ -642,13 +642,16 @@ def test_KvCacheConfig_pickle(): def test_TrtGptModelOptionalParams_pickle(): cache = _tb.KvCacheConfig(free_gpu_memory_fraction=0.4) - params = _tb.TrtGptModelOptionalParams( + params1 = _tb.TrtGptModelOptionalParams( kv_cache_config=cache, enable_trt_overlap=True, ) - params.enable_chunked_context = True + params1.enable_chunked_context = True + params2 = pickle.loads(pickle.dumps(params1)) - params1 = pickle.dumps(params) - params2: _tb.TrtGptModelOptionalParams = pickle.loads(params1) + assert params2 == params1 - assert params2 == params + params1 = _tb.TrtGptModelOptionalParams() + params2 = pickle.loads(pickle.dumps(params1)) + + assert params2 == params1 diff --git a/tests/bindings/test_executor_bindings.py b/tests/bindings/test_executor_bindings.py new file mode 100644 index 000000000..03e37144b --- /dev/null +++ b/tests/bindings/test_executor_bindings.py @@ -0,0 +1,783 @@ +import datetime +import os +import random +import sys +import time +from pathlib import Path + +import numpy as np +import pytest +import torch +from binding_test_utils import * + 
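+# tensorrt_llm.bindings.executor is the module under test below: the cases cover
+# Executor construction (from an engine path or from memory), request enqueueing,
+# response polling, and the plain config classes (SamplingConfig, OutputConfig,
+# KvCacheConfig, ExecutorConfig, PeftCacheConfig).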
+import tensorrt_llm.bindings.executor as trtllm + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import skip_pre_ampere + + +@pytest.fixture +def model_path(engine_path): + return engine_path / "gpt2/fp16-plugin-packed-paged/tp1-pp1-gpu" + + +@pytest.fixture +def model_path_return_logits(engine_path): + return engine_path / "gpt2/fp16-plugin-packed-paged-gather/tp1-pp1-gpu" + + +@pytest.fixture +def input_data_path(data_path): + return data_path / "input_tokens.npy" + + +@pytest.fixture(scope="module") +def results_data_path(data_path: Path) -> Path: + return data_path / "gpt2/sampling/output_tokens_fp16_plugin_packed_paged_tp1_pp1.npy" + + +@pytest.fixture(scope="module") +def results_data_path_beam_width_2(data_path: Path) -> Path: + return data_path / "gpt2/beam_search_2/output_tokens_fp16_plugin_packed_paged_tp1_pp1.npy" + + +@pytest.fixture +def model_files(llm_root: Path, resource_path: Path, llm_model_root, + results_data_path): + # Model engines and expected outputs need to be generated. + if not results_data_path.exists(): + model_cache_arg = ["--model_cache", + str(llm_model_root) + ] if llm_model_root is not None else [] + prepare_model_tests(llm_root, resource_path, "gpt", model_cache_arg) + + +def get_expected_num_tokens(prompt_len, max_new_tokens, streaming, + exclude_input_from_output): + if not streaming and not exclude_input_from_output: + return prompt_len + max_new_tokens + return max_new_tokens + + +@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture +def test_executor_valid_ctor(model_files, model_path): + executor_config = trtllm.ExecutorConfig(1) + executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, + executor_config) + + +@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture +def test_executor_from_memory(model_files, model_path): + executor_config = trtllm.ExecutorConfig(1) + engine_buffer = open(model_path / "rank0.engine", mode="rb").read() + json_config_str = open(model_path / "config.json", 'r').read() + executor = trtllm.Executor(engine_buffer, json_config_str, + trtllm.ModelType.DECODER_ONLY, executor_config) + + +def test_executor_invalid_ctor(): + executor_config = trtllm.ExecutorConfig(1) + invalid_path = "Bla" + try: + executor = trtllm.Executor(invalid_path, trtllm.ModelType.DECODER_ONLY, + executor_config) + assert False, "Expected an error" + except Exception as e: + assert "File does not exist" in str(e) + + +@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture +def test_embedding_bias(model_files, model_path): + streaming = False + exclude_input_from_output = False + output_config = trtllm.OutputConfig() + output_config.exclude_input_from_output = exclude_input_from_output + + # Create executor + beam_width = 1 + executor_config = trtllm.ExecutorConfig(beam_width) + executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, + executor_config) + + # Create the request + max_new_tokens = 5 + input_tokens = [1, 2, 3, 4] + # Set embedding bias so "biased_output" is always picked + biased_output = 10 + vocab_size_padded = 50257 + embedding_bias = torch.zeros(vocab_size_padded) + embedding_bias[biased_output] = torch.finfo(torch.float32).max + request = trtllm.Request(input_tokens, + max_new_tokens, + streaming, + trtllm.SamplingConfig(), + output_config, + embedding_bias=embedding_bias) + + # Enqueue the request + request_id = executor.enqueue_request(request) + + # Get the 
new tokens + tokens = [] + done = False + i = 0 + max_wait_ms = 10000 + while not done and i < max_wait_ms: + wait_time = datetime.timedelta(milliseconds=1) + responses = executor.await_responses(request_id, wait_time) + for response in responses: + assert not response.has_error( + ), f"Request id {request_id} failed with err {response.error_msg}" + result = response.result + done = result.is_final + new_tokens = result.output_token_ids[beam_width - 1] + tokens.extend(new_tokens) + i += 1 + assert i < max_wait_ms + assert len(tokens) == get_expected_num_tokens( + len(input_tokens), max_new_tokens, streaming, + exclude_input_from_output), f"{request_id}" + # All generated tokens should equal biased_output + assert tokens[-max_new_tokens:] == [biased_output] * max_new_tokens + + +@pytest.mark.parametrize("streaming", [False, True]) +@pytest.mark.parametrize("exclude_input_from_output", [False]) +@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture +def test_single_request(streaming: bool, exclude_input_from_output: bool, + model_files, model_path): + output_config = trtllm.OutputConfig() + output_config.exclude_input_from_output = exclude_input_from_output + + # Create executor + beam_width = 1 + executor_config = trtllm.ExecutorConfig(beam_width) + executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, + executor_config) + + # Create the request + max_new_tokens = 5 + input_tokens = [1, 2, 3, 4] + request = trtllm.Request(input_tokens, max_new_tokens, streaming, + trtllm.SamplingConfig(), output_config) + + # Enqueue the request + request_id = executor.enqueue_request(request) + + # Get the new tokens + tokens = [] + done = False + i = 0 + max_wait_ms = 10000 + while not done and i < max_wait_ms: + wait_time = datetime.timedelta(milliseconds=1) + responses = executor.await_responses(request_id, wait_time) + for response in responses: + assert not response.has_error( + ), f"Request id {request_id} failed with err {response.error_msg}" + result = response.result + done = result.is_final + new_tokens = result.output_token_ids[beam_width - 1] + tokens.extend(new_tokens) + i += 1 + assert i < max_wait_ms + assert len(tokens) == get_expected_num_tokens( + len(input_tokens), max_new_tokens, streaming, + exclude_input_from_output), f"{request_id}" + + executor.get_latest_iteration_stats() + executor.get_latest_request_stats() + + +@pytest.mark.parametrize("streaming", [False, True]) +@pytest.mark.parametrize("exclude_input_from_output", [False]) +@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture +def test_multi_request(streaming: bool, exclude_input_from_output: bool, + model_files, model_path): + output_config = trtllm.OutputConfig() + output_config.exclude_input_from_output = exclude_input_from_output + + # Create executor + beam_width = 1 + executor_config = trtllm.ExecutorConfig(beam_width) + executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, + executor_config) + + num_requests = 20 + max_prompt_len = 20 + max_max_new_tokens = 20 + end_id = -1 + + # Enqueue the requests + tokens = {} + expected_num_tokens = {} + for i in range(num_requests): + prompt_len = random.randint(1, max_prompt_len) + max_new_tokens = random.randint(1, max_max_new_tokens) + input_tokens = [1] * prompt_len + request = trtllm.Request(input_tokens, max_new_tokens, streaming, + trtllm.SamplingConfig(), output_config, end_id) + request_id = executor.enqueue_request(request) + tokens[request_id] = [] + 
expected_num_tokens[request_id] = get_expected_num_tokens( + prompt_len, max_new_tokens, streaming, exclude_input_from_output) + + # Get the new tokens for each request + num_finished = 0 + i = 0 + num_responses = 0 + max_wait_ms = 10000 + while num_finished < num_requests and i < max_wait_ms: + wait_time = datetime.timedelta(milliseconds=1) + responses = executor.await_responses(None, wait_time) + for response in responses: + num_responses += 1 + assert not response.has_error( + ), f"Request id {response.request_id} failed with err {response.error_msg}" + result = response.result + num_finished += result.is_final + new_tokens = result.output_token_ids[beam_width - 1] + tokens[response.request_id].extend(new_tokens) + i += 1 + assert i < max_wait_ms + + for request_id in expected_num_tokens: + assert len(tokens[request_id]) == expected_num_tokens[request_id] + + +@pytest.mark.parametrize("streaming", [False, True]) +@pytest.mark.parametrize("exclude_input_from_output", [False]) +@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture +def test_get_num_responses_ready(streaming: bool, + exclude_input_from_output: bool, model_files, + model_path): + output_config = trtllm.OutputConfig() + output_config.exclude_input_from_output = exclude_input_from_output + + # Create executor + executor_config = trtllm.ExecutorConfig(1) + executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, + executor_config) + + max_prompt_len = 20 + max_max_new_tokens = 20 + + # Enqueue the requests + num_requests = random.randint(1, 50) + num_expected_responses = 0 + req_num_expected_responses = {} + for i in range(num_requests): + prompt_len = random.randint(1, max_prompt_len) + max_new_tokens = random.randint(1, max_max_new_tokens) + + request = trtllm.Request([1] * prompt_len, max_new_tokens, streaming, + trtllm.SamplingConfig(), output_config) + request_id = executor.enqueue_request(request) + req_num_expected_responses[ + request_id] = max_new_tokens if streaming else 1 + num_expected_responses += req_num_expected_responses[request_id] + + i = 0 + num_ready = 0 + max_wait_ms = 10000 + while num_ready < num_expected_responses and i < max_wait_ms: + num_ready = 0 + for request_id in req_num_expected_responses: + num_ready += executor.get_num_responses_ready(request_id) + time.sleep(0.001) + i += 1 + assert i < max_wait_ms + + for request_id in req_num_expected_responses: + num_ready = executor.get_num_responses_ready(request_id) + assert num_ready == req_num_expected_responses[request_id] + assert executor.get_num_responses_ready() == num_expected_responses + + +@pytest.mark.parametrize("batching_type", [trtllm.BatchingType.INFLIGHT]) +@pytest.mark.parametrize("streaming", [False, True]) +@pytest.mark.parametrize("beam_width", [1]) +@pytest.mark.parametrize("compute_log_probs", [False, True]) +@pytest.mark.parametrize("exclude_input_from_output", [False]) +@pytest.mark.parametrize("return_context_logits", [False, True]) +@pytest.mark.parametrize("return_generation_logits", [False, True]) +@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture +def test_token_comparison(batching_type: trtllm.BatchingType, streaming: bool, + beam_width: int, compute_log_probs: bool, + exclude_input_from_output: bool, + return_context_logits: bool, + return_generation_logits: bool, model_files, + model_path, model_path_return_logits, input_data_path, + results_data_path, results_data_path_beam_width_2): + if streaming and beam_width > 1: + 
pytest.skip("Test does not support streaming with beam search") + + vocab_size_padded = 50257 + pad_id = 50256 + remove_input = not exclude_input_from_output and not streaming + + def load_test_data(input_path, results_path): + # Inputs + assert input_path.is_file() + given_input = np.load(input_path).astype("int32") + input_shape = given_input.shape + assert len(input_shape) == 2 + max_input_length = input_shape[1] + given_input_lengths = sequence_lengths(given_input, pad_id) + assert np.all(given_input_lengths <= max_input_length) + # Expected results + assert results_path.is_file() + expected_outputs = np.load(results_path).astype("int32") + output_shape = expected_outputs.shape + assert len(output_shape) == 2 + assert input_shape[0] * beam_width == output_shape[0] + max_seq_length = output_shape[1] + max_new_tokens = max_seq_length - max_input_length + + end_ids = [pad_id for _ in range(len(given_input_lengths))] + expected_lengths = [] + for i in range(len(given_input_lengths)): + expected_lengths.append([ + given_input_lengths[i] + max_new_tokens + for _ in range(beam_width) + ]) + + test_data = { + "expected_output_ids": expected_outputs, + "expected_output_lengths": expected_lengths, + "max_seq_length": max_seq_length, + "end_ids": end_ids + } + return given_input, given_input_lengths, max_input_length, test_data + + def validate_results_shapes(result, input_length, max_output_len, + beam_tokens): + if compute_log_probs: + assert result.cum_log_probs is not None + assert result.log_probs is not None + assert len(result.cum_log_probs) == beam_width + assert len(result.log_probs) == beam_width + for beam in range(beam_width): + expected_len = len( + beam_tokens[beam]) - (input_length if remove_input else 0) + assert len(result.log_probs[beam]) == expected_len + else: + assert result.cum_log_probs is None + assert result.log_probs is None + if return_context_logits: + assert result.context_logits is not None + assert len(result.context_logits.shape) == 2 + assert list(result.context_logits.shape) == [ + input_length, vocab_size_padded + ] + else: + assert result.context_logits is None + if return_generation_logits: + assert len(result.generation_logits.shape) == 3 + assert list(result.generation_logits.shape) == [ + beam_width, max_output_len, vocab_size_padded + ] + + def verify_output(beam_tokens, test_data, given_input_lengths): + for batch_id, tokens in beam_tokens.items(): + input_length = given_input_lengths[batch_id] + end_id = test_data["end_ids"][batch_id] + for beam in range(beam_width): + predicted_tokens = tokens[beam] + if remove_input: + predicted_tokens = predicted_tokens[input_length:] + expected_length = test_data["expected_output_lengths"][ + batch_id][beam] - input_length + assert len(predicted_tokens) == expected_length + expected_tokens = test_data["expected_output_ids"][ + batch_id * beam_width + beam][input_length:] + for i in range(len(predicted_tokens)): + if expected_tokens[i] == end_id: + break + assert predicted_tokens[i] == expected_tokens[i], \ + f"Predicted: {predicted_tokens} vs Expected: {expected_tokens}" + + output_config = trtllm.OutputConfig() + output_config.exclude_input_from_output = exclude_input_from_output + output_config.return_log_probs = compute_log_probs + output_config.return_generation_logits = return_generation_logits + output_config.return_context_logits = return_context_logits + + kv_cache_config = trtllm.KvCacheConfig(False, free_gpu_memory_fraction=0.5) + executor_config = trtllm.ExecutorConfig(beam_width) + 
executor_config.batching_type = batching_type + executor_config.kv_cache_config = kv_cache_config + + if return_context_logits or return_generation_logits: + model_path = model_path_return_logits + executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, + executor_config) + + # Load test data + results_path = results_data_path if beam_width == 1 else results_data_path_beam_width_2 + given_input, given_input_lengths, max_input_length, test_data = load_test_data( + input_data_path, results_path) + + # Create requests from input data + num_requests = len(given_input_lengths) + requests = [] + req_max_new_tokens = [] + + for i in range(num_requests): + input_len = given_input_lengths[i] + max_new_tokens = test_data["max_seq_length"] - max_input_length + req_max_new_tokens.append(max_new_tokens) + req_tokens = given_input[i][:input_len] + requests.append( + trtllm.Request(req_tokens, + max_new_tokens, + streaming, + trtllm.SamplingConfig(beam_width), + output_config, + end_id=-1)) + + req_ids = executor.enqueue_requests(requests) + + req_to_batch_id = {req_ids[i]: i for i in range(len(requests))} + tokens = {i: [[] for _ in range(beam_width)] for i in range(len(requests))} + + num_finished = 0 + i = 0 + num_responses = 0 + max_wait_ms = 10000 + while num_finished < num_requests and i < max_wait_ms: + wait_time = datetime.timedelta(milliseconds=1) + responses = executor.await_responses(None, wait_time) + for response in responses: + num_responses += 1 + assert not response.has_error( + ), f"Request id {response.request_id} failed with err {response.error_msg}" + result = response.result + num_finished += result.is_final + + batch_id = req_to_batch_id[response.request_id] + for beam in range(beam_width): + new_tokens = result.output_token_ids[beam] + tokens[batch_id][beam] += new_tokens + + validate_results_shapes(result, given_input_lengths[batch_id], + req_max_new_tokens[batch_id], + tokens[batch_id]) + i += 1 + assert i < max_wait_ms + verify_output(tokens, test_data, given_input_lengths) + + +@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture +def test_gpt_executor_timed_out(model_files, model_path): + beam_width = 1 + executor_config = trtllm.ExecutorConfig(beam_width) + executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, + executor_config) + + # No requests enqueued, expect no responses + num_responses_ready = executor.get_num_responses_ready() + assert num_responses_ready == 0 + + wait_time = datetime.timedelta(milliseconds=10) + responses = executor.await_responses(None, wait_time) + assert len(responses) == 0 + + +@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture +def test_single_request_invalid_inputs(model_files, model_path): + streaming = True + beam_width = 1 + executor_config = trtllm.ExecutorConfig(beam_width) + executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, + executor_config) + + max_new_tokens = 5 + input_tokens = [1, 2, 3, 4] + request = trtllm.Request(input_tokens, max_new_tokens, streaming) + # Invalid embedding bias shape + embedding_bias = torch.ones(1) + request.embedding_bias = embedding_bias + expected_error_msg = "embedding bias shape is not as expected" + + request_id = executor.enqueue_request(request) + + done = False + i = 0 + max_wait_ms = 10000 + while not done and i < max_wait_ms: + wait_time = datetime.timedelta(milliseconds=1) + responses = executor.await_responses(request_id, wait_time) + for response in responses: + assert 
response.has_error(), "Expected an error" + assert expected_error_msg in response.error_msg + done = True + i += 1 + assert done + + +def test_sampling_config(): + beam_width = 1 + kwargs = { + "top_k": 2, + "top_p": 1.0, + "top_p_min": 1.0, + "top_p_reset_ids": 3, + "top_p_decay": 1.0, + "random_seed": 7, + "temperature": 1.0, + "min_length": 4, + "beam_search_diversity_rate": 1.0, + "repetition_penalty": 1.0, + "presence_penalty": 1.0, + "frequency_penalty": 1.0, + "length_penalty": 1.0, + "early_stopping": 5 + } + config = trtllm.SamplingConfig(beam_width, **kwargs) + for k, v in kwargs.items(): + assert getattr(config, k) == v + del config + + config = trtllm.SamplingConfig(beam_width) + assert config.beam_width == beam_width + for k in kwargs: + assert getattr(config, k) is None + + +def test_output_config(): + config = trtllm.OutputConfig() + assert config.return_log_probs == False + assert config.return_context_logits == False + assert config.return_generation_logits == False + assert config.exclude_input_from_output == False + + config = trtllm.OutputConfig(True, False, True, False) + assert config.return_log_probs == True + assert config.return_context_logits == False + assert config.return_generation_logits == True + assert config.exclude_input_from_output == False + + +def test_speculative_decoding_config(): + tokens = [1, 2, 3] + config = trtllm.SpeculativeDecodingConfig(tokens) + assert config.tokens == tokens + assert config.logits is None + assert config.acceptance_threshold is None + del config + + logits = torch.ones(3, 1) + acceptance_threshold = 1.0 + config = trtllm.SpeculativeDecodingConfig(tokens, logits, + acceptance_threshold) + assert config.tokens == tokens + assert (config.logits == logits).all() + assert config.acceptance_threshold == acceptance_threshold + + +def test_prompt_tuning_config(): + embedding_table = torch.ones(100, 64) + config = trtllm.PromptTuningConfig(embedding_table) + assert (config.embedding_table == embedding_table).all() + + +def test_lora_config(): + task_id = 1 + lora_config = trtllm.LoraConfig(task_id) + assert lora_config.task_id == task_id + assert lora_config.weights is None + assert lora_config.config is None + + task_id = 2 + weights = torch.ones(1, 2) + config = torch.ones(1, 2, dtype=torch.int32) + lora_config = trtllm.LoraConfig(task_id, weights, config) + assert lora_config.task_id == task_id + assert (lora_config.weights == weights).all() + assert (lora_config.config == config).all() + + +def test_request(): + kwargs = { + "input_token_ids": [1, 2, 3], + "max_new_tokens": 1, + "streaming": False, + "sampling_config": trtllm.SamplingConfig(), + "output_config": trtllm.OutputConfig(), + "end_id": -1, + "pad_id": -2, + "bad_words": [[4, 5, 6]], + "stop_words": [[7, 8, 9]], + "embedding_bias": torch.ones(1), + "speculative_decoding_config": + trtllm.SpeculativeDecodingConfig([1, 2, 3]), + "prompt_tuning_config": trtllm.PromptTuningConfig(torch.ones(100, 64)), + "lora_config": trtllm.LoraConfig(1) + } + request = trtllm.Request(**kwargs) + for k, v in kwargs.items(): + if "config" not in k: + assert getattr(request, k) == v + assert isinstance(request.sampling_config, trtllm.SamplingConfig) + assert isinstance(request.output_config, trtllm.OutputConfig) + assert isinstance(request.speculative_decoding_config, + trtllm.SpeculativeDecodingConfig) + assert request.speculative_decoding_config.tokens == [1, 2, 3] + assert isinstance(request.prompt_tuning_config, trtllm.PromptTuningConfig) + assert 
(request.prompt_tuning_config.embedding_table == torch.ones( + 100, 64)).all() + assert isinstance(request.lora_config, trtllm.LoraConfig) + + +def test_result(): + result = trtllm.Result() + result.is_final = True + result.output_token_ids = [[1, 2, 3]] + result.cum_log_probs = [1.0, 2.0, 3.0] + result.log_probs = [[1.0, 2.0, 3.0]] + result.context_logits = torch.ones(3, 100) + result.generation_logits = torch.ones(1, 3, 100) + assert result.is_final == True + assert result.output_token_ids == [[1, 2, 3]] + assert result.cum_log_probs == [1.0, 2.0, 3.0] + assert result.log_probs == [[1.0, 2.0, 3.0]] + assert (result.context_logits == torch.ones(3, 100)).all() + assert (result.generation_logits == torch.ones(1, 3, 100)).all() + + +def test_response(): + request_id = 0 + error_msg = "error" + response = trtllm.Response(request_id, error_msg) + assert response.request_id == request_id + assert response.has_error() + assert response.error_msg == error_msg + + result = trtllm.Result() + result.is_final = True + result.output_token_ids = [[1, 2, 3]] + request_id = 1 + response = trtllm.Response(request_id, result) + assert response.request_id == request_id + assert not response.has_error() + assert response.result.is_final + assert response.result.output_token_ids == [[1, 2, 3]] + + +def test_scheduler_config(): + policy = trtllm.SchedulerPolicy.MAX_UTILIZATION + config = trtllm.SchedulerConfig(policy) + assert config.policy == policy + + policy = trtllm.SchedulerPolicy.GUARANTEED_NO_EVICT + config = trtllm.SchedulerConfig(policy) + assert config.policy == policy + + +def test_kv_cache_config(): + config = trtllm.KvCacheConfig() + assert config.enable_block_reuse == False + assert config.max_tokens is None + assert config.max_attention_window is None + assert config.sink_token_length is None + assert config.free_gpu_memory_fraction is None + + kwargs = { + "enable_block_reuse": True, + "max_tokens": 3, + "max_attention_window": 10, + "sink_token_length": 2, + "free_gpu_memory_fraction": 0.5, + } + config = trtllm.KvCacheConfig(**kwargs) + for k, v in kwargs.items(): + assert getattr(config, k) == v + + +def test_executor_config(): + config = trtllm.ExecutorConfig() + assert config.max_beam_width == 1 + assert isinstance(config.scheduler_config, trtllm.SchedulerConfig) + assert isinstance(config.kv_cache_config, trtllm.KvCacheConfig) + assert config.enable_chunked_context == False + assert config.normalize_log_probs == True + assert config.iter_stats_max_iterations == 1000 + assert config.batching_type == trtllm.BatchingType.INFLIGHT + assert config.parallel_config is None + assert isinstance(config.peft_cache_config, trtllm.PeftCacheConfig) + + kwargs = { + "max_beam_width": + 2, + "scheduler_config": + trtllm.SchedulerConfig(trtllm.SchedulerPolicy.MAX_UTILIZATION), + "kv_cache_config": + trtllm.KvCacheConfig(), + "enable_chunked_context": + True, + "normalize_log_probs": + False, + "iter_stats_max_iterations": + 100, + "batching_type": + trtllm.BatchingType.STATIC, + "parallel_config": + trtllm.ParallelConfig(), + "peft_cache_config": + trtllm.PeftCacheConfig(10) + } + config = trtllm.ExecutorConfig(**kwargs) + for k, v in kwargs.items(): + if "config" not in k: + assert getattr(config, k) == v + assert isinstance(config.scheduler_config, trtllm.SchedulerConfig) + assert config.scheduler_config.policy == trtllm.SchedulerPolicy.MAX_UTILIZATION + assert isinstance(config.kv_cache_config, trtllm.KvCacheConfig) + assert isinstance(config.parallel_config, trtllm.ParallelConfig) + assert 
isinstance(config.peft_cache_config, trtllm.PeftCacheConfig) + + +def test_parallel_config(): + comm_type = trtllm.CommunicationType.MPI + comm_mode = trtllm.CommunicationMode.LEADER + device_ids = [0, 1, 2, 3] + participant_ids = [4, 5, 6, 7] + parallel_config = trtllm.ParallelConfig(comm_type, comm_mode, device_ids, + participant_ids) + assert parallel_config.communication_type == comm_type + assert parallel_config.communication_mode == comm_mode + assert parallel_config.device_ids == device_ids + assert parallel_config.participant_ids == participant_ids + + +def test_peft_cache_config(): + num_host_module_layer = 1 + num_device_module_layer = 2 + optimal_adapter_size = 3 + max_adapter_size = 4 + num_put_workers = 5 + num_ensure_workers = 6 + num_copy_streams = 7 + max_pages_per_block_host = 8 + max_pages_per_block_device = 9 + device_cache_percent = 0.9 + host_cache_size = 1024 + peft_cache_config = trtllm.PeftCacheConfig( + num_host_module_layer, num_device_module_layer, optimal_adapter_size, + max_adapter_size, num_put_workers, num_ensure_workers, num_copy_streams, + max_pages_per_block_host, max_pages_per_block_device, + device_cache_percent, host_cache_size) + + assert peft_cache_config.num_host_module_layer == num_host_module_layer + assert peft_cache_config.num_device_module_layer == num_device_module_layer + assert peft_cache_config.optimal_adapter_size == optimal_adapter_size + assert peft_cache_config.max_adapter_size == max_adapter_size + assert peft_cache_config.num_put_workers == num_put_workers + assert peft_cache_config.num_ensure_workers == num_ensure_workers + assert peft_cache_config.num_copy_streams == num_copy_streams + assert peft_cache_config.max_pages_per_block_host == max_pages_per_block_host + assert peft_cache_config.max_pages_per_block_device == max_pages_per_block_device + assert np.isclose(peft_cache_config.device_cache_percent, + device_cache_percent) + assert peft_cache_config.host_cache_size == host_cache_size diff --git a/tests/examples_unit/gpt/test_build.py b/tests/examples_unit/gpt/test_build.py deleted file mode 100644 index fc432bfbd..000000000 --- a/tests/examples_unit/gpt/test_build.py +++ /dev/null @@ -1,123 +0,0 @@ -import argparse -import os -import sys -import tempfile -from contextlib import contextmanager -from pathlib import Path -from typing import Union - - -@contextmanager -def prepend_to_sys_path(path: Union[str, os.PathLike]) -> None: - sys.path = [str(path)] + sys.path - try: - yield - finally: - sys.path = sys.path[1:] - - -# Using 2 separate context managers instead of 1 to avoid pre-commit isort problem. -# In "Build TRT-LLM" job pre-commit is run twice for 2 consecutive builds. -# If 1 context manager is used the first pre-commit check is passed for the first -# build but the second pre-commit check is failed with isort. 
-with prepend_to_sys_path(Path(__file__).parent / '../../../examples/gpt'): - from build import override_args_from_model_dir - -with prepend_to_sys_path(Path(__file__).parent / '../../../examples/gpt'): - from utils.nemo import nemo_config_to_ini_config - - -class TestOverridingOfRotaryParameters: - nemo_configs = { - "rotary_base_overriding": { - "position_embedding_type": "rope", - "rotary_percentage": 1.0, - "seq_len_interpolation_factor": 4.0, - "rotary_base": 8888, - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - "rotary_scaling_overriding": { - "position_embedding_type": "rope", - "rotary_percentage": 1.0, - "seq_len_interpolation_factor": 3.33333, - "rotary_base": 10000, - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - "rotary_pct_overriding": { - "position_embedding_type": "rope", - "rotary_percentage": 0.3, - "seq_len_interpolation_factor": 3.33333, - "rotary_base": 10000, - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - "no_overriding": { - "position_embedding_type": "learned_absolute", - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - } - - @staticmethod - def create_args_with_model_dir(model_dir) -> argparse.Namespace: - args = argparse.Namespace() - args.model_dir = model_dir - return args - - def test_rotary_base_overriding(self): - nemo_config = self.nemo_configs["rotary_base_overriding"] - ini_config = nemo_config_to_ini_config(nemo_config, 100, 101, 103, - "float32") - with tempfile.TemporaryDirectory() as model_dir: - with open(Path(model_dir) / "config.ini", "w") as f: - ini_config.write(f) - args = self.create_args_with_model_dir(model_dir) - args.rotary_base = 1111 - override_args_from_model_dir(args) - assert args.rotary_base == nemo_config["rotary_base"] - - def test_rotary_scaling_overriding(self): - nemo_config = self.nemo_configs["rotary_scaling_overriding"] - ini_config = nemo_config_to_ini_config(nemo_config, 100, 101, 103, - "float32") - with tempfile.TemporaryDirectory() as model_dir: - model_dir = Path(model_dir) - with open(model_dir / "config.ini", "w") as f: - ini_config.write(f) - args = self.create_args_with_model_dir(model_dir) - args.scaling = "Scaling?" 
- override_args_from_model_dir(args) - assert args.rotary_scaling == [ - "linear", - str(nemo_config["seq_len_interpolation_factor"]) - ] - - def test_rotary_pct_overriding(self): - nemo_config = self.nemo_configs["rotary_pct_overriding"] - ini_config = nemo_config_to_ini_config(nemo_config, 100, 101, 103, - "float32") - with tempfile.TemporaryDirectory() as model_dir: - with open(Path(model_dir) / "config.ini", "w") as f: - ini_config.write(f) - args = self.create_args_with_model_dir(model_dir) - args.rotary_pct = 'foo' - override_args_from_model_dir(args) - assert args.rotary_pct == nemo_config["rotary_percentage"] - - def test_no_overriding(self): - nemo_config = self.nemo_configs["no_overriding"] - ini_config = nemo_config_to_ini_config(nemo_config, 100, 101, 103, - "float32") - with tempfile.TemporaryDirectory() as model_dir: - with open(Path(model_dir) / "config.ini", "w") as f: - ini_config.write(f) - args = self.create_args_with_model_dir(model_dir) - args.rotary_scaling = 'foo' - args.rotary_pct = "bar" - args.rotary_base = "baz" - override_args_from_model_dir(args) - assert args.rotary_scaling == "foo" - assert args.rotary_pct == 0.0 - assert args.rotary_base == "baz" diff --git a/tests/examples_unit/gpt/utils/test_nemo.py b/tests/examples_unit/gpt/utils/test_nemo.py deleted file mode 100644 index c99473ab2..000000000 --- a/tests/examples_unit/gpt/utils/test_nemo.py +++ /dev/null @@ -1,168 +0,0 @@ -import os -import sys -from contextlib import contextmanager -from pathlib import Path -from typing import Union - -import pytest - - -@contextmanager -def prepend_to_sys_path(path: Union[str, os.PathLike]) -> None: - sys.path = [str(path)] + sys.path - try: - yield - finally: - sys.path = sys.path[1:] - - -with prepend_to_sys_path(Path(__file__).parent / '../../../../../examples/gpt'): - from utils.nemo import nemo_config_to_ini_config - - -class TestRotaryParametersSetting: - nemo_configs = { - "learned_absolute": { - "position_embedding_type": "learned_absolute", - "rotary_percentage": 0.0, - "seq_len_interpolation_factor": 4.0, - "rotary_base": 10000, - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - "relative": { - "position_embedding_type": "relative", - "rotary_percentage": 0.0, - "seq_len_interpolation_factor": None, - "rotary_base": 10000, - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - "nemo_rotary_pct_default_is_used": { - "position_embedding_type": "rope", - "seq_len_interpolation_factor": None, - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - "nemo_rotary_base_default_is_used": { - "position_embedding_type": "rope", - "seq_len_interpolation_factor": None, - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - "scaling_is_set": { - "position_embedding_type": "rope", - "seq_len_interpolation_factor": 3.5, - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - "wrong_seq_len_interpolation_factor_value": { - "position_embedding_type": "rope", - "seq_len_interpolation_factor": 1.0, - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - "no_scaling": { - "position_embedding_type": "rope", - "seq_len_interpolation_factor": None, - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - "rotary_base": { - "position_embedding_type": "rope", - "rotary_base": 9999, - "max_position_embedding": 1024, - "num_attention_heads": 48, - }, - "rotary_percentage_equals_0": { - "position_embedding_type": "rope", - "rotary_base": 9999, - "max_position_embedding": 1024, - 
"num_attention_heads": 48, - "rotary_percentage": 0.0, - }, - "rotary_percentage_gt_1": { - "position_embedding_type": "rope", - "rotary_base": 9999, - "max_position_embedding": 1024, - "num_attention_heads": 48, - "rotary_percentage": 1.1, - }, - "good_rotary_percentage": { - "position_embedding_type": "rope", - "rotary_base": 9999, - "max_position_embedding": 1024, - "num_attention_heads": 48, - "rotary_percentage": 0.4, - }, - } - - @pytest.mark.parametrize("nemo_config_name", - ["learned_absolute", "relative"]) - def test_no_rope(self, nemo_config_name): - nemo_config = self.nemo_configs[nemo_config_name] - vocab_size = 103 - ini_config = nemo_config_to_ini_config(nemo_config, 100, 101, - vocab_size, "float32") - assert float(ini_config["gpt"]["rotary_pct"]) == 0.0 - assert "rotary_scaling" not in ini_config["gpt"] - assert "rotary_base" not in ini_config["gpt"] - assert "n_head" in ini_config["gpt"] - assert "n_positions" in ini_config["gpt"] - assert int(ini_config["gpt"]["vocab_size"]) == vocab_size - - def test_nemo_rotary_pct_default_is_used(self): - nemo_config = self.nemo_configs["nemo_rotary_pct_default_is_used"] - ini_config = nemo_config_to_ini_config(nemo_config, 100, 101, 103, - "float32") - assert float(ini_config["gpt"]["rotary_pct"]) == 1.0 - - def test_rotary_base_default(self): - nemo_config = self.nemo_configs["nemo_rotary_base_default_is_used"] - ini_config = nemo_config_to_ini_config(nemo_config, 100, 101, 103, - "float32") - assert int(ini_config["gpt"]["rotary_base"]) == 10000 - - def test_scaling_is_set(self): - nemo_config = self.nemo_configs["scaling_is_set"] - ini_config = nemo_config_to_ini_config(nemo_config, 100, 101, 103, - "float32") - assert float(ini_config["gpt"]["rotary_scaling_factor"] - ) == nemo_config["seq_len_interpolation_factor"] - assert ini_config["gpt"]["rotary_scaling_type"] == "linear" - - def test_wrong_seq_len_interpolation_factor_value(self): - nemo_config = self.nemo_configs[ - "wrong_seq_len_interpolation_factor_value"] - with pytest.raises(ValueError): - nemo_config_to_ini_config(nemo_config, 100, 101, 103, "float32") - - def test_no_scaling(self): - nemo_config = self.nemo_configs["no_scaling"] - ini_config = nemo_config_to_ini_config(nemo_config, 100, 101, 103, - "float32") - assert "rotary_scaling" not in ini_config["gpt"] - - def test_rotary_base(self): - nemo_config = self.nemo_configs["rotary_base"] - ini_config = nemo_config_to_ini_config(nemo_config, 100, 101, 103, - "float32") - assert int( - ini_config["gpt"]["rotary_base"]) == nemo_config["rotary_base"] - - def test_rotary_percentage_equals_0(self): - nemo_config = self.nemo_configs["rotary_percentage_equals_0"] - with pytest.raises(ValueError): - nemo_config_to_ini_config(nemo_config, 100, 101, 103, "float32") - - def test_rotary_percentage_gt_1(self): - nemo_config = self.nemo_configs["rotary_percentage_gt_1"] - with pytest.raises(ValueError): - nemo_config_to_ini_config(nemo_config, 100, 101, 103, "float32") - - def test_good_rotary_percentage(self): - nemo_config = self.nemo_configs["good_rotary_percentage"] - ini_config = nemo_config_to_ini_config(nemo_config, 100, 101, 103, - "float32") - assert float( - ini_config["gpt"]["rotary_pct"]) == nemo_config["rotary_percentage"] diff --git a/tests/functional/test_selective_scan.py b/tests/functional/test_selective_scan.py index 52e29871f..41018074d 100644 --- a/tests/functional/test_selective_scan.py +++ b/tests/functional/test_selective_scan.py @@ -19,7 +19,7 @@ import numpy as np import torch -from einops import rearrange 
+import torch.nn.functional as F from parameterized import parameterized from torch_ref import selective_scan_ref, selective_state_update_ref @@ -46,7 +46,7 @@ def test_selective_scan(self, dim, dstate, req_type, dtype): skip_bf16_pre_ampere(dtype) # configs - batch_size = 1 + batch_size = 4 device = "cuda" seq_len = 16 if req_type == 'context' else 1 is_variable_B = True @@ -55,6 +55,16 @@ def test_selective_scan(self, dim, dstate, req_type, dtype): # test data torch.random.manual_seed(0) + if req_type == 'context': + last_token_ids = torch.randint(1, + seq_len + 1, + size=(batch_size, ), + dtype=torch.int32, + device=device) + last_token_ids[0] = seq_len + else: + last_token_ids = torch.ones( + [batch_size], dtype=torch.int32, device=device) * seq_len state = torch.randn(batch_size, dstate, dim, @@ -91,15 +101,15 @@ def test_selective_scan(self, dim, dstate, req_type, dtype): device=device, dtype=str_dtype_to_torch(dtype)) - state_ref = state.detach().clone().permute(0, 2, 1).contiguous() - x_ref = x.detach().clone().permute(0, 2, 1).contiguous() - dt_ref = dt.detach().clone().permute(0, 2, 1).contiguous() + state_ref = state.detach().clone() + x_ref = x.detach().clone() + dt_ref = dt.detach().clone() dt_bias_ref = dt_bias.detach().clone() - A_ref = A.detach().clone().permute(1, 0).contiguous() - B_ref = B.detach().clone().permute(0, 2, 1).contiguous() - C_ref = C.detach().clone().permute(0, 2, 1).contiguous() + A_ref = A.detach().clone() + B_ref = B.detach().clone() + C_ref = C.detach().clone() D_ref = D.detach().clone() - z_ref = z.detach().clone().permute(0, 2, 1).contiguous() + z_ref = z.detach().clone() # construct trt network builder = tensorrt_llm.Builder() @@ -137,11 +147,15 @@ def test_selective_scan(self, dim, dstate, req_type, dtype): name='host_request_types', shape=host_request_types.shape, dtype=tensorrt_llm.str_dtype_to_trt('int32')) + last_token_ids_tensor = Tensor( + name='last_token_ids', + shape=last_token_ids.shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) outputs = tensorrt_llm.functional.selective_scan( x_tensor, state_tensor, dt_tensor, dt_bias_tensor, A_tensor, B_tensor, C_tensor, D_tensor, z_tensor, - host_request_types_tensor, dim, dstate, is_variable_B, - is_variable_C, delta_softplus, dtype) + host_request_types_tensor, last_token_ids_tensor, dim, dstate, + is_variable_B, is_variable_C, delta_softplus, dtype) net._mark_output(outputs[0], 'output', dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -160,7 +174,8 @@ def test_selective_scan(self, dim, dstate, req_type, dtype): 'C': C, 'D': D, 'z': z, - 'host_request_types': host_request_types + 'host_request_types': host_request_types, + 'last_token_ids': last_token_ids } outputs = {'output': output, 'present_state': state} stream = torch.cuda.current_stream() @@ -171,41 +186,61 @@ def test_selective_scan(self, dim, dstate, req_type, dtype): out_ref = None if req_type == 'context': # pytorch run - out_ref, state_ref = selective_scan_ref(x_ref, - dt_ref, - A_ref, - B_ref, - C_ref, - D=D_ref, - z=z_ref, - delta_bias=dt_bias_ref, - delta_softplus=True) - + out_ref, state_ref = [], [] + for i in range(batch_size): + seq_len_i = last_token_ids[i] + out_ref_i, state_ref_i = selective_scan_ref( + x_ref[i:i + 1, 0:seq_len_i, :], + dt_ref[i:i + 1, 0:seq_len_i, :], + A_ref, + B_ref[i:i + 1, 0:seq_len_i, :], + C_ref[i:i + 1, 0:seq_len_i, :], + D=D_ref, + z=z_ref[i:i + 1, 0:seq_len_i, :], + delta_bias=dt_bias_ref, + delta_softplus=True) + out_ref_i = F.pad(out_ref_i, + (0, 0, 0, seq_len - out_ref_i.shape[1], 0, 0), + 
value=0) + out_ref.append(out_ref_i) + state_ref.append(state_ref_i) + out_ref = torch.concat(out_ref, dim=0) + state_ref = torch.concat(state_ref, dim=0) elif req_type == 'generation': # pytorch run out_ref = selective_state_update_ref(state_ref, - x_ref.squeeze(2), - dt_ref.squeeze(2), + x_ref.squeeze(1), + dt_ref.squeeze(1), A_ref, - B_ref.squeeze(2), - C_ref.squeeze(2), + B_ref.squeeze(1), + C_ref.squeeze(1), D=D_ref, - z=z_ref.squeeze(2), + z=z_ref.squeeze(1), dt_bias=dt_bias_ref, dt_softplus=True) - out_ref = out_ref.unsqueeze(2) + out_ref = out_ref.unsqueeze(1) - dtype_atol = {"float16": 5e-3, "float32": 2e-3, "bfloat16": 5e-2} + # get output mask + if req_type == 'context': + out_mask = torch.zeros(batch_size, seq_len, device=device) + for i in range(batch_size): + for j in range(last_token_ids[i]): + out_mask[i, j] = 1 + out_mask = out_mask.unsqueeze(2).expand([batch_size, seq_len, dim]) + else: + out_mask = torch.ones(batch_size, seq_len, dim, device=device) - output_cpu = outputs['output'].to(torch.float32).cpu() - present_state_cpu = outputs['present_state'].to(torch.float32).cpu() - output_cpu = rearrange(output_cpu, 'b s d -> b d s').contiguous() - present_state_cpu = rearrange(present_state_cpu, - 'b d n -> b n d').contiguous() + dtype_atol = {"float16": 5e-3, "float32": 2e-3, "bfloat16": 5e-2} - np.testing.assert_allclose(out_ref.to(torch.float32).cpu().numpy(), - output_cpu.numpy(), - atol=dtype_atol[dtype]) - np.testing.assert_allclose(state_ref.to(torch.float32).cpu().numpy(), - present_state_cpu.numpy(), + # compare out diff + outputs['output'][out_mask == 0] = 0 + out_trt_llm = outputs['output'].detach().to(torch.float32).cpu().numpy() + out_ref = (out_ref * out_mask).detach().to(torch.float32).cpu().numpy() + np.testing.assert_allclose(out_ref, out_trt_llm, atol=dtype_atol[dtype]) + + # compare present state diff + state_trt_llm = outputs['present_state'].detach().to(torch.float32) + state_ref = state_ref.detach().to(torch.float32) + np.testing.assert_allclose(state_ref.cpu().numpy(), + state_trt_llm.cpu().numpy(), atol=dtype_atol[dtype]) diff --git a/tests/functional/torch_ref.py b/tests/functional/torch_ref.py index 2acef943c..40ee4cb98 100644 --- a/tests/functional/torch_ref.py +++ b/tests/functional/torch_ref.py @@ -165,41 +165,41 @@ def selective_scan_ref(u, delta_bias=None, delta_softplus=False): """ - u: (B D L) - delta: (B D L) - A: (D N) - B: (B N L) - C: (B N L) + u: (B L D) + delta: (B L D) + A: (N D) + B: (B L N) + C: (B L N) D: (D) - z: (B D L) + z: (B L D) delta_bias: (D), fp32 - out: (B D L) - last_state (optional): (B D dstate), fp32 + out: (B L D) + last_state (optional): (B dstate D), fp32 """ dtype_in = u.dtype u = u.float() delta = delta.float() if delta_bias is not None: - delta = delta + delta_bias[..., None].float() + delta = delta + delta_bias.unsqueeze(0).unsqueeze(1).float() if delta_softplus: delta = F.softplus(delta) - batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + batch, dstate, dim = u.shape[0], A.shape[0], A.shape[1] B = B.float() C = C.float() - x = A.new_zeros((batch, dim, dstate)) + x = A.new_zeros((batch, dstate, dim)) ys = [] - deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) - deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + deltaA = torch.exp(torch.einsum('bld,nd->blnd', delta, A)) + deltaB_u = torch.einsum('bld,bln,bld->blnd', delta, B, u) last_state = None - for i in range(u.shape[2]): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) - if i == 
u.shape[2] - 1: + for i in range(u.shape[1]): + x = deltaA[:, i, :] * x + deltaB_u[:, i, :] + y = torch.einsum('bnd,bn->bd', x, C[:, i, :]) + if i == u.shape[1] - 1: last_state = x ys.append(y) - y = torch.stack(ys, dim=2) # (batch dim L) - out = y if D is None else y + u * rearrange(D, "d -> d 1") + y = torch.stack(ys, dim=1) # (batch L dim) + out = y if D is None else y + u * rearrange(D, "d -> 1 d") if z is not None: out = out * F.silu(z.float()) out = out.to(dtype=dtype_in) @@ -219,10 +219,10 @@ def selective_state_update_ref(state, dt_softplus=False): """ Argument: - state: (batch, dim, dstate) + state: (batch, dstate, dim) x: (batch, dim) dt: (batch, dim) - A: (dim, dstate) + A: (dstate, dim) B: (batch, dstate) C: (batch, dstate) D: (dim,) @@ -231,10 +231,10 @@ def selective_state_update_ref(state, Return: out: (batch, dim) """ - batch, dim, dstate = state.shape + batch, dstate, dim = state.shape assert x.shape == (batch, dim) assert dt.shape == x.shape - assert A.shape == (dim, dstate) + assert A.shape == (dstate, dim) assert B.shape == (batch, dstate) assert C.shape == B.shape if D is not None: @@ -245,13 +245,13 @@ def selective_state_update_ref(state, assert dt_bias.shape == (dim, ) dt = dt + dt_bias dt = F.softplus(dt) if dt_softplus else dt - dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate) - dB = rearrange(dt, "b d -> b d 1") * rearrange( - B.float(), "b n -> b 1 n") # (batch, dim, dstate) + dA = torch.exp(rearrange(dt, "b d -> b 1 d") * A) # (batch, dstate, dim) + dB = rearrange(dt, "b d -> b 1 d") * rearrange( + B.float(), "b n -> b n 1") # (batch, dstate, dim) state_new = state * dA + dB * rearrange( - x, "b d -> b d 1") # (batch, dim, dstate) + x, "b d -> b 1 d") # (batch, dstate, dim) state.copy_(state_new.to(state.dtype)) - out = torch.einsum("bdn,bn->bd", state_new, C.float()) + out = torch.einsum("bnd,bn->bd", state_new, C.float()) if D is not None: out += x * D return (out if z is None else out * F.silu(z.float())).to(x.dtype) @@ -322,14 +322,38 @@ def __init__(self, def forward(self, hidden_states, - conv_state=None, - ssm_state=None, + last_token_ids, + conv_state, + ssm_state, seqlen_offset=0): + batch, seqlen, _ = hidden_states.shape + out, present_conv_state, present_ssm_state = [], [], [] + for i in range(batch): + hidden_states_i = hidden_states[i:i + 1, 0:last_token_ids[i], :] + conv_state_i = conv_state[i:i + 1, :] + ssm_state_i = ssm_state[i:i + 1, :] + out_i, conv_state_i, ssm_state_i = self.forward_impl( + hidden_states_i, conv_state_i, ssm_state_i, seqlen_offset) + out_i = F.pad(out_i, (0, 0, 0, seqlen - out_i.shape[1], 0, 0), + value=0) + out.append(out_i) + present_conv_state.append(conv_state_i) + present_ssm_state.append(ssm_state_i) + out = torch.concat(out, dim=0) + present_conv_state = torch.concat(present_conv_state, dim=0) + present_ssm_state = torch.concat(present_ssm_state, dim=0) + return out, present_conv_state, present_ssm_state + + def forward_impl(self, + hidden_states, + conv_state, + ssm_state, + seqlen_offset=0): """ hidden_states: (B, L, D) Returns: same shape as hidden_states """ - batch, seqlen, dim = hidden_states.shape + _, seqlen, _ = hidden_states.shape if seqlen_offset > 0: # The states are updated inplace @@ -339,13 +363,13 @@ def forward(self, # in_proj xz = torch.nn.functional.linear(hidden_states, self.in_proj.weight) - xz = xz.permute(0, 2, 1) if self.in_proj.bias is not None: xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), - "d -> d 1") + "d -> 1 d") # Conv - x, z = xz.chunk(2, dim=1) + x, 
z = xz.chunk(2, dim=2) + x = x.permute(0, 2, 1) if conv_state is not None: conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) x_conv = self.conv1d(x)[..., :seqlen] @@ -357,11 +381,12 @@ def forward(self, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt = self.dt_proj.weight @ dt.t() - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) - B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + dt = rearrange(dt, "d (b l) -> b l d", l=seqlen).contiguous() + B = rearrange(B, "(b l) dstate -> b l dstate", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b l dstate", l=seqlen).contiguous() # Selective scan + x = x.permute(0, 2, 1) y, last_state = selective_scan_ref(x, dt, self.A, @@ -374,7 +399,6 @@ def forward(self, ssm_state.copy_(last_state) # out_proj - y = rearrange(y, "b d l -> b l d") out = self.out_proj(y) return out, conv_state, ssm_state diff --git a/tests/hlapi/test_executor.py b/tests/hlapi/test_executor.py new file mode 100644 index 000000000..ea16b2f67 --- /dev/null +++ b/tests/hlapi/test_executor.py @@ -0,0 +1,203 @@ +import os +import sys +import unittest +from pathlib import Path +from typing import Optional + +import pytest +import torch +from mpi4py import MPI +from transformers import AutoTokenizer + +from tensorrt_llm.executor import (GenerationExecutor, GenerationExecutorWorker, + GenerationRequest) +from tensorrt_llm.hlapi.llm import LLM, ModelConfig + +WORLD_SIZE = MPI.COMM_WORLD.size + + +@pytest.fixture(scope="module") +def llm_root() -> Path: + environ_root = os.environ.get("LLM_ROOT", None) + return Path(environ_root) if environ_root is not None else Path( + __file__).parent.parent.parent + + +@pytest.fixture(scope="module") +def llm_model_root() -> Optional[Path]: + if "LLM_MODEL_ROOT" in os.environ: + return Path(os.environ["LLM_MODEL_ROOT"]) + + sys.path.append(str(Path(__file__).resolve().parent.parent)) + from utils.llm_data import llm_models_root + + return llm_models_root() + + +@pytest.fixture(scope="module") +def resource_path(llm_root: Path) -> Path: + return llm_root / "cpp" / "tests" / "resources" + + +@pytest.fixture(scope="module") +def engine_path(resource_path: Path) -> Path: + return resource_path / "models" / "rt_engine" + + +@pytest.fixture(scope="module") +def llama_7b_path(engine_path: Path, llm_model_root: Path) -> Path: + path = engine_path / "llama7b" + + if not path.exists(): + config = ModelConfig(str(llm_model_root / "llama-models/llama-7b-hf")) + llm = LLM(config) + llm.save(str(path)) + + return path + + +@pytest.fixture(scope="module") +def llama_7b_bs2_path(engine_path: Path, llm_model_root: Path) -> Path: + path = engine_path / "llama7b_bs2" + + if not path.exists(): + config = ModelConfig(str(llm_model_root / "llama-models/llama-7b-hf"), + max_beam_width=2) + llm = LLM(config) + llm.save(str(path)) + + return path + + +@pytest.fixture(scope="module") +def llama_7b_tp2_path(engine_path: Path, llm_model_root: Path) -> Path: + path = engine_path / "llama7b-tp2" + + if not path.exists(): + config = ModelConfig(str(llm_model_root / "llama-models/llama-7b-hf")) + config.parallel_config.tp_size = 2 + llm = LLM(config) + llm.save(str(path)) + + return path + + +@pytest.mark.skipif(WORLD_SIZE != 1, reason="Must run on single MPI rank") +def test_generation_bs2(llama_7b_bs2_path: Path): + tokenizer = llama_7b_bs2_path + prompt = "A B C D" + max_new_tokens = 4 + + with GenerationExecutorWorker(llama_7b_bs2_path, + tokenizer, + 
max_beam_width=2) as executor: + result = executor.generate(prompt, + max_new_tokens=max_new_tokens, + beam_width=2) + assert result.text[0] == " A B C D E F G H" + assert result.text[1] == " A B C D E F G I" + + +@pytest.mark.skipif(WORLD_SIZE != 1, reason="Must run on single MPI rank") +def test_sync_generation(llama_7b_path: Path): + tokenizer = llama_7b_path + prompt = "A B C D" + expected_output = " E F G H" + expected_long_output = " E F G H I J K L" + split_output = ["E", " F", " G", " H", " I", " J", " K", " L"] + max_new_tokens = 4 + with GenerationExecutorWorker(llama_7b_path, tokenizer) as executor: + # Simple generations (synchronous) + result = executor.generate(prompt, max_new_tokens=max_new_tokens) + assert result.text == " " + prompt + expected_output + + results = executor.generate( + [prompt, prompt], + max_new_tokens=[max_new_tokens, 2 * max_new_tokens]) + for result, expected in zip(results, + (expected_output, expected_long_output)): + assert result.text == " " + prompt + expected + + # Simple generations (asynchronous) + # + # Iterate the partial results when streaming + future = executor.generate_async(prompt, + streaming=True, + max_new_tokens=max_new_tokens) + for idx, partial_result in enumerate(future): + assert partial_result.text_diff == split_output[idx] + + # Iterate the partial results when streaming + # Streaming results in nested loop + futures = executor.generate_async( + [prompt, prompt], + streaming=True, + max_new_tokens=[max_new_tokens, 2 * max_new_tokens]) + for future in futures: + for idx, partial_result in enumerate(future): + assert partial_result.text_diff == split_output[idx] + + # Low-level api with .submit + # Submit a batch of requests + tokenizer = AutoTokenizer.from_pretrained("gpt2") + futures = [] + for _ in range(5): + futures.append( + executor.submit( + GenerationRequest( + prompt, + tokenizer=AutoTokenizer.from_pretrained(llama_7b_path), + max_new_tokens=max_new_tokens))) + + for future in executor.wait_first_completed(futures): + assert future.done + assert future.result().text == "".join(split_output[:4]) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2 or WORLD_SIZE != 2, + reason="Must run on 2 MPI ranks with at least 2 GPUs") +def test_sync_generation_tp_all_nodes(llama_7b_tp2_path: Path): + prompt = "deep learning" + max_new_tokens = 4 + + # Normal execution, all nodes live + executor = GenerationExecutorWorker(llama_7b_tp2_path, llama_7b_tp2_path) + result = executor.generate(prompt, max_new_tokens=max_new_tokens) + assert result.text == " deep learning, neural network," + executor.shutdown() + + +@pytest.mark.skipif(torch.cuda.device_count() < 2 or WORLD_SIZE != 2, + reason="Must run on 2 MPI ranks with at least 2 GPUs") +def test_sync_generation_tp_main_node_only(llama_7b_tp2_path: Path): + prompt = "deep learning" + max_new_tokens = 4 + + with GenerationExecutorWorker(llama_7b_tp2_path, + llama_7b_tp2_path) as executor: + + executor.block_subordinates() + # from now on, only rank0 lives in the with statement + # other nodes wait at the "end" of the with statement + + result = executor.generate(prompt, max_new_tokens=max_new_tokens) + assert result.text == " deep learning, neural network," + + +@pytest.mark.skipif(torch.cuda.device_count() < 2 or WORLD_SIZE != 1, + reason="Must run on 1 MPI rank with at least 2 GPUs") +def test_sync_generation_tp_inner(llama_7b_tp2_path: Path): + prompt = "deep learning" + max_new_tokens = 4 + tp_size = 2 + + executor = GenerationExecutor.create(llama_7b_tp2_path, + llama_7b_tp2_path, 
+ model_world_size=tp_size) + result = executor.generate(prompt, max_new_tokens=max_new_tokens) + assert result.text == " deep learning, neural network," + executor.shutdown() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/hlapi/test_mpi_session.py b/tests/hlapi/test_mpi_session.py deleted file mode 100644 index 4b0d98c46..000000000 --- a/tests/hlapi/test_mpi_session.py +++ /dev/null @@ -1,41 +0,0 @@ -from dataclasses import dataclass -from typing import List - -from tensorrt_llm.hlapi.mpi_session import SocketListener - - -@dataclass -class ComplexData: - a: str - b: int - c: List[int] - - -def test_SocketServer(): - - messages = [ - "hello", # str - 123, # int - ComplexData("hello", 123, [1, 2, 3]) # complex - ] - - offset = 0 - - def callback(data): - nonlocal offset - print('get data', data) - assert data == messages[offset] - offset += 1 - - server = SocketListener(callback=callback) - - client = server.get_client() - - for data in messages: - client.send(data) - - server.shutdown() - - -if __name__ == '__main__': - test_SocketServer() diff --git a/tests/model/test_gpt.py b/tests/model/test_gpt.py index f00956c2f..46b7a3b85 100644 --- a/tests/model/test_gpt.py +++ b/tests/model/test_gpt.py @@ -24,7 +24,6 @@ # isort: off import torch -import tensorrt as trt # isort: on from parameterized import parameterized from transformers import GPT2Config, GPT2LMHeadModel @@ -42,8 +41,7 @@ KVCacheManager) sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) - -from examples.gpt.weight import load_from_hf_gpt +from examples.gpt.convert_checkpoint import convert_hf_gpt sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from utils.util import skip_fp32_accum_pre_ampere, unittest_name_func @@ -58,6 +56,7 @@ def _gen_hf_gpt(self, hidden_act, n_layer, max_length, dtype): max_length=max_length, torch_dtype=dtype, ) + gpt_config.n_kv_head = gpt_config.n_head hf_gpt = GPT2LMHeadModel(gpt_config).cuda().eval() return gpt_config, hf_gpt @@ -67,28 +66,39 @@ def _gen_tensorrt_llm_network(self, network, builder, hf_gpt, gpt_config, apply_query_key_layer_scaling, gather_context_logits, gather_generation_logits): - num_layers = gpt_config.n_layer - num_heads = gpt_config.n_head - hidden_size = gpt_config.n_embd - vocab_size = gpt_config.vocab_size - hidden_act = gpt_config.activation_function - n_positions = gpt_config.n_positions - tensor_parallel_group = list(range(tensor_parallel)) + dtype = 'float16' if fp16 else 'float32' + config = { + 'architecture': 'GPTForCausalLM', + 'dtype': dtype, + 'num_hidden_layers': gpt_config.n_layer, + 'num_attention_heads': gpt_config.n_head, + 'num_key_value_heads': gpt_config.n_head, + 'hidden_size': gpt_config.n_embd, + 'intermediate_size': gpt_config.n_embd * 4, + 'norm_epsilon': 1e-5, + 'vocab_size': gpt_config.vocab_size, + 'position_embedding_type': 'learned_absolute', + 'max_position_embeddings': gpt_config.n_positions, + 'hidden_act': gpt_config.activation_function, + 'mapping': { + 'world_size': tensor_parallel, + 'tp_size': tensor_parallel, + }, + 'bias': getattr(gpt_config, 'bias', True), + 'apply_query_key_layer_scaling': apply_query_key_layer_scaling, + } + config = tensorrt_llm.models.PretrainedConfig.from_dict(config) + weights = convert_hf_gpt(hf_gpt, + gpt_config, + "gpt2", + config.mapping, + dtype=dtype) + tensorrt_llm_gpt = tensorrt_llm.models.GPTForCausalLM(config) + tensorrt_llm_gpt.load(weights) with net_guard(network): - kv_dtype = trt.float16 if fp16 else trt.float32 # Initialize model - tensorrt_llm_gpt = 
tensorrt_llm.models.GPTLMHeadModel( - num_layers=num_layers, - num_heads=num_heads, - hidden_size=hidden_size, - vocab_size=vocab_size, - hidden_act=hidden_act, - max_position_embeddings=n_positions, - dtype=kv_dtype, - mapping=tensorrt_llm.Mapping(world_size=tensor_parallel, - tp_size=tensor_parallel), - apply_query_key_layer_scaling=apply_query_key_layer_scaling) + network.set_named_parameters(tensorrt_llm_gpt.named_parameters()) inputs = tensorrt_llm_gpt.prepare_inputs( max_batch_size=batch_size, max_input_len=input_len, @@ -97,14 +107,9 @@ def _gen_tensorrt_llm_network(self, network, builder, hf_gpt, gpt_config, max_beam_width=1, gather_context_logits=gather_context_logits, gather_generation_logits=gather_generation_logits) - load_from_hf_gpt(tensorrt_llm_gpt, - hf_gpt, - dtype="float16" if fp16 else "float32") # Prepare - network.set_named_parameters(tensorrt_llm_gpt.named_parameters()) - - tensorrt_llm_gpt(*inputs) + tensorrt_llm_gpt(**inputs) return network @@ -932,33 +937,38 @@ def test_greedy_search_float32(self, test_partition, use_refit, streaming): @parameterized.expand(["other"], name_func=unittest_name_func) def test_rope_scaling_is_set_in_attention(self, test_partition): num_layers = 2 - position_embedding_type = PositionEmbeddingType.rope_gpt_neox + position_embedding_type = 'rope_gpt_neox' rotary_embedding_percentage = 0.3 rotary_base = 99999.1 rotary_scaling = {"type": "linear", "factor": 2.72} - tensorrt_llm_gpt = tensorrt_llm.models.GPTLMHeadModel( - num_layers=num_layers, - num_heads=4, - hidden_size=128, - vocab_size=256, - hidden_act='gelu', - max_position_embeddings=1024, - dtype=trt.float16, - position_embedding_type=position_embedding_type, - rotary_embedding_percentage=rotary_embedding_percentage, - rotary_base=rotary_base, - rotary_scaling=rotary_scaling, - ) + + config = { + 'architecture': 'GPTForCausalLM', + 'dtype': 'float16', + 'num_hidden_layers': num_layers, + 'num_attention_heads': 4, + 'hidden_size': 128, + 'vocab_size': 256, + 'max_position_embeddings': 1024, + 'hidden_act': 'gelu', + 'position_embedding_type': position_embedding_type, + 'rotary_pct': rotary_embedding_percentage, + 'rotary_base': rotary_base, + 'rotary_scaling': rotary_scaling, + } + config = tensorrt_llm.models.PretrainedConfig.from_dict(config) + tensorrt_llm_gpt = tensorrt_llm.models.GPTForCausalLM(config) + for layer_i in range(num_layers): - assert tensorrt_llm_gpt.layers[ + assert tensorrt_llm_gpt.transformer.layers[ layer_i].attention.rotary_embedding_base == rotary_base - assert tensorrt_llm_gpt.layers[ + assert tensorrt_llm_gpt.transformer.layers[ layer_i].attention.rotary_embedding_scale == rotary_scaling[ "factor"] - assert tensorrt_llm_gpt.layers[ + assert tensorrt_llm_gpt.transformer.layers[ layer_i].attention.rotary_embedding_scale_type == RotaryScalingType.linear - assert tensorrt_llm_gpt.layers[ - layer_i].attention.position_embedding_type == position_embedding_type + assert tensorrt_llm_gpt.transformer.layers[ + layer_i].attention.position_embedding_type == PositionEmbeddingType.rope_gpt_neox if __name__ == '__main__': diff --git a/tests/model/test_gpt_e2e.py b/tests/model/test_gpt_e2e.py index 7b93ae681..c7a72d194 100644 --- a/tests/model/test_gpt_e2e.py +++ b/tests/model/test_gpt_e2e.py @@ -12,24 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import json import os import subprocess import sys import unittest from pathlib import Path +from typing import Sequence import numpy as np import torch import torch.multiprocessing as multiprocessing import tensorrt_llm -from tensorrt_llm.runtime import ModelConfig, SamplingConfig - -sys.path.insert( - 0, str(Path(__file__).resolve().parent.parent.parent / "examples/gpt")) -from build import get_engine_name, run_build # isort:skip -from hf_gpt_convert import ProgArgs, run_conversion +from tensorrt_llm.runtime import ModelRunner, SamplingConfig END_ID = 50256 PAD_ID = 50256 @@ -40,22 +35,44 @@ from llm_data import llm_models_root from util import getSMVersion +gpt_example_root = os.path.join(os.path.dirname(__file__), '../../examples/gpt') + + +def run_command(command: Sequence[str], *, cwd=None, **kwargs) -> None: + print(f"Running: cd %s && %s" % (str(cwd or Path.cwd()), " ".join(command)), + flush=True) + subprocess.check_call(command, cwd=cwd, **kwargs) + + +def convert_ckpt(model_dir: str, output_dir: str, *args): + convert_cmd = [ + sys.executable, f"{gpt_example_root}/convert_checkpoint.py", + f"--model_dir={model_dir}", f"--output_dir={output_dir}" + ] + list(args) + run_command(convert_cmd) + -def build_engine(weight_dir: Path, engine_dir: Path, *args): - print( - f"== Build engine from {weight_dir} to {engine_dir}, with args {args}") - run_build([ - '--model_dir', - str(weight_dir), - '--output_dir', - str(engine_dir), +def build_engine(checkpoint_dir: str, engine_dir: str, *args): + build_cmd = [ + "trtllm-build", + f"--checkpoint_dir={checkpoint_dir}", + f"--output_dir={engine_dir}", '--log_level=verbose', '--max_batch_size=256', '--max_input_len=40', '--max_output_len=20', '--max_beam_width=2', '--builder_opt=0', - ] + list(args)) + ] + legacy_args = [ + "--gpt_attention_plugin=disable", + "--context_fmha=disable", + "--paged_kv_cache=disable", + "--remove_input_padding=disable", + "--enable_xqa=disable", + ] + build_cmd = build_cmd + legacy_args + list(args) + run_command(build_cmd) def build_engines(): @@ -77,108 +94,59 @@ def build_engines(): pytorch_model ]) - weight_dir = work_dir / 'c-model/gpt2' + ckpt_dir = work_dir / 'c-model/gpt2' engine_dir = work_dir / 'rt_engine/gpt2' print("\nConverting to fp32") - fp32_weight_dir = weight_dir / 'fp32/1-gpu' - run_conversion( - ProgArgs(in_file=str(gpt2_dir), - out_dir=str(fp32_weight_dir), - storage_type='float32')) + fp32_ckpt_dir = ckpt_dir / 'fp32/1-gpu' + convert_ckpt(str(gpt2_dir), str(fp32_ckpt_dir), "--dtype=float32") print("\nBuilding fp32 engines") - fp32_weight_dir_1_gpu = fp32_weight_dir / '1-gpu' - build_engine(fp32_weight_dir_1_gpu, engine_dir / 'fp32-default/1-gpu', - '--dtype=float32') - build_engine(fp32_weight_dir_1_gpu, engine_dir / 'fp32-plugin/1-gpu', - '--dtype=float32', '--use_gpt_attention_plugin=float32') + + build_engine(str(fp32_ckpt_dir), str(engine_dir / 'fp32-default/1-gpu')) + build_engine(str(fp32_ckpt_dir), str(engine_dir / 'fp32-plugin/1-gpu'), + '--gpt_attention_plugin=float32') print("\nConverting to fp16") - fp16_weight_dir = weight_dir / 'fp16/1-gpu' - run_conversion( - ProgArgs(in_file=str(gpt2_dir), - out_dir=str(fp16_weight_dir), - storage_type='float16')) + fp16_ckpt_dir = ckpt_dir / 'fp16/1-gpu' + convert_ckpt(str(gpt2_dir), str(fp16_ckpt_dir), "--dtype=float16") print("\nBuilding fp16 engines") - fp16_weight_dir_1_gpu = fp16_weight_dir / '1-gpu' - build_engine(fp16_weight_dir_1_gpu, engine_dir / 'fp16-default/1-gpu', - '--dtype=float16', '--strongly_typed') - 
build_engine(fp16_weight_dir_1_gpu, engine_dir / 'fp16-plugin/1-gpu', - '--dtype=float16', '--use_gpt_attention_plugin=float16', + build_engine(str(fp16_ckpt_dir), str(engine_dir / 'fp16-default/1-gpu'), '--strongly_typed') + build_engine(str(fp16_ckpt_dir), str(engine_dir / 'fp16-plugin/1-gpu'), + '--gpt_attention_plugin=float16', '--strongly_typed') # Skip tests that are not supported in pre-ampere architecture if getSMVersion() >= 80: - build_engine(fp16_weight_dir_1_gpu, - engine_dir / 'fp16-plugin-fmha/1-gpu', '--dtype=float16', - '--use_gpt_attention_plugin=float16', - '--enable_context_fmha', '--strongly_typed') + build_engine(str(fp16_ckpt_dir), + str(engine_dir / 'fp16-plugin-fmha/1-gpu'), + '--gpt_attention_plugin=float16', '--context_fmha=enable', + '--strongly_typed') - build_engine(fp16_weight_dir_1_gpu, engine_dir / 'fp16-plugin-packed/1-gpu', - '--dtype=float16', '--use_gpt_attention_plugin=float16', - '--remove_input_padding', '--strongly_typed') + build_engine(str(fp16_ckpt_dir), + str(engine_dir / 'fp16-plugin-packed/1-gpu'), + '--gpt_attention_plugin=float16', + '--remove_input_padding=enable', '--strongly_typed') # Skip tests that are not supported in pre-ampere architecture if getSMVersion() >= 80: - build_engine(fp16_weight_dir_1_gpu, - engine_dir / 'fp16-plugin-packed-fmha/1-gpu', - '--dtype=float16', '--use_gpt_attention_plugin=float16', - '--remove_input_padding', '--enable_context_fmha', + build_engine(fp16_ckpt_dir, + str(engine_dir / 'fp16-plugin-packed-fmha/1-gpu'), + '--gpt_attention_plugin=float16', + '--remove_input_padding=enable', '--context_fmha=enable', '--strongly_typed') print("Done.") def check_accuracy(engine_dir, input_tokens, max_output_len): - config_path = os.path.join(engine_dir, 'config.json') - with open(config_path, 'r') as f: - config = json.load(f) - dtype = config['builder_config']['precision'] - use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin'] - remove_input_padding = config['plugin_config']['remove_input_padding'] - dtype = config['builder_config']['precision'] - world_size = config['builder_config']['tensor_parallel'] - assert world_size == tensorrt_llm.mpi_world_size(), \ - f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})' - num_heads = config['builder_config']['num_heads'] // world_size - hidden_size = config['builder_config']['hidden_size'] // world_size - max_batch_size = config['builder_config']['max_batch_size'] - max_beam_width = config['builder_config']['max_beam_width'] - vocab_size = config['builder_config']['vocab_size'] - num_layers = config['builder_config']['num_layers'] - num_kv_heads = config['builder_config']['num_kv_heads'] - runtime_rank = tensorrt_llm.mpi_rank() - runtime_mapping = tensorrt_llm.Mapping(world_size, - runtime_rank, - tp_size=world_size) - torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node) - - model_config = ModelConfig(max_batch_size=max_batch_size, - max_beam_width=max_beam_width, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - hidden_size=hidden_size, - vocab_size=vocab_size, - num_layers=num_layers, - gpt_attention_plugin=use_gpt_attention_plugin, - remove_input_padding=remove_input_padding, - dtype=dtype) + runner = ModelRunner.from_dir(engine_dir, rank=runtime_rank) sampling_config = SamplingConfig(end_id=END_ID, pad_id=END_ID) - engine_name = get_engine_name('gpt', dtype, world_size, runtime_rank) - serialize_path = os.path.join(engine_dir, engine_name) - with open(serialize_path, 'rb') as f: - 
engine_buffer = f.read() - decoder = tensorrt_llm.runtime.GenerationSession(model_config, - engine_buffer, - runtime_mapping) - - input_lengths = torch.tensor([len(x) for x in input_tokens], - dtype=torch.int, - device='cuda') + all_input_ids = [torch.tensor(x, dtype=torch.int32) for x in input_tokens] + all_input_lengths = [len(x) for x in input_tokens] num_samples = len(input_tokens) expect_output = None @@ -188,50 +156,45 @@ def check_accuracy(engine_dir, input_tokens, max_output_len): output_with_fake_dim = [] print(f"Running batch size: {batch_size}") for i in range(num_samples // batch_size): - samples = input_tokens[i * batch_size:(i + 1) * batch_size] - sample_lengths = input_lengths[i * batch_size:(i + 1) * batch_size] - if remove_input_padding: - input_ids = np.concatenate(samples) - input_ids = torch.tensor(input_ids, - dtype=torch.int, - device='cuda') - input_ids_with_fake_dim = input_ids.unsqueeze(0) - max_input_length = torch.max(sample_lengths).item() - else: - input_ids = torch.nested.to_padded_tensor( - torch.nested.nested_tensor(samples, dtype=torch.int32), - PAD_ID).cuda() - max_input_length = input_ids.size(1) - - decoder.setup(batch_size, max_input_length, max_output_len) - output_ids = decoder.decode(input_ids, sample_lengths, - sampling_config) + batch_input_ids = all_input_ids[i * batch_size:(i + 1) * batch_size] + batch_input_lengths = all_input_lengths[i * batch_size:(i + 1) * + batch_size] + max_input_length = max(batch_input_lengths) + output_ids = runner.generate(batch_input_ids, + sampling_config=sampling_config, + max_new_tokens=max_output_len) torch.cuda.synchronize() - if remove_input_padding: - decoder.setup(batch_size, max_input_length, max_output_len) - output_ids_with_fake_dim = decoder.decode( - input_ids_with_fake_dim, sample_lengths, sampling_config) + if runner.remove_input_padding: + runner.session.setup(batch_size, max_input_length, + max_output_len) + batch_input_ids_with_fake_dim = torch.concat( + batch_input_ids).unsqueeze(0) + + output_ids_with_fake_dim = runner.session.decode( + batch_input_ids_with_fake_dim.cuda(), + torch.tensor(batch_input_lengths, dtype=torch.int32).cuda(), + sampling_config) outputs_with_fake_dim_list = [ - output_ids_with_fake_dim[ - batch_idx, :, - sample_lengths[batch_idx]:sample_lengths[batch_idx] + - max_output_len].cpu() + output_ids_with_fake_dim[batch_idx, :, + batch_input_lengths[batch_idx]: + batch_input_lengths[batch_idx] + + max_output_len].cpu() for batch_idx in range(output_ids_with_fake_dim.shape[0]) ] outputs_with_fake_dim = torch.cat(outputs_with_fake_dim_list) output_with_fake_dim.append(outputs_with_fake_dim) outputs_list = [ - output_ids[batch_idx, :, - sample_lengths[batch_idx]:sample_lengths[batch_idx] + + output_ids[batch_idx, :, batch_input_lengths[batch_idx]: + batch_input_lengths[batch_idx] + max_output_len].cpu() for batch_idx in range(output_ids.shape[0]) ] outputs = torch.cat(outputs_list) output.append(outputs) output = torch.stack(output, dim=0) - if remove_input_padding: + if runner.remove_input_padding: output_with_fake_dim = torch.stack(output_with_fake_dim, dim=0) error = np.mean(output.cpu().numpy().flatten() != output_with_fake_dim.cpu().numpy().flatten()) diff --git a/tests/model/test_llama.py b/tests/model/test_llama.py index 75b4ef322..855d30af2 100644 --- a/tests/model/test_llama.py +++ b/tests/model/test_llama.py @@ -82,7 +82,6 @@ def _gen_tensorrt_llm_network(self, network, hf_llama, }, 'use_parallel_embedding': False, 'embedding_sharding_dim': 0, - 'use_prompt_tuning': False, 
'moe_num_experts': 0, 'moe_top_k': 0, 'moe_tp_mode': 2, @@ -529,7 +528,6 @@ def print_layers(m: tensorrt_llm.models.LLaMAForCausalLM): }, 'use_parallel_embedding': use_parallel_embedding, 'embedding_sharding_dim': embedding_sharding_dim, - 'use_prompt_tuning': False, 'moe_num_experts': 0, 'moe_top_k': 0, 'moe_tp_mode': 1, diff --git a/tests/model/test_mamba.py b/tests/model/test_mamba.py index 27ff3a484..8ec22e1b8 100644 --- a/tests/model/test_mamba.py +++ b/tests/model/test_mamba.py @@ -146,6 +146,7 @@ def test_mamba(self, gemm_plugin, dtype): load_mode = 'from_model' hf_path = '' hf_config = MambaConfig(d_model=128, n_layer=2, vocab_size=128) + pad_id = 0 # get hf mamba hf_mamba = MambaLMHeadModel(hf_config, @@ -161,16 +162,10 @@ def test_mamba(self, gemm_plugin, dtype): mamba_d_inner = hf_mamba.backbone.layers[0].mixer.d_inner mamba_d_conv = hf_mamba.backbone.layers[0].mixer.d_conv mamba_d_state = hf_mamba.backbone.layers[0].mixer.d_state - ctx_conv_state_shape = ( + conv_state_shape = ( batch_size, mamba_d_inner, - mamba_d_conv - 1 + input_len, - ) - - gen_conv_state_shape = ( - batch_size, - mamba_d_inner, - mamba_d_conv, + mamba_d_conv - 1, ) ssm_state_shape = ( @@ -183,40 +178,66 @@ def test_mamba(self, gemm_plugin, dtype): present_ssm_states = [] for _ in range(hf_config.n_layer): present_conv_states.append( - torch.zeros(ctx_conv_state_shape, + torch.zeros(conv_state_shape, dtype=str_dtype_to_torch(dtype), device='cuda')) present_conv_states_1.append( - torch.empty(gen_conv_state_shape, + torch.empty(conv_state_shape, dtype=str_dtype_to_torch(dtype), device='cuda')) present_ssm_states.append( - torch.empty(ssm_state_shape, dtype=torch.float32, + torch.empty(ssm_state_shape, + dtype=str_dtype_to_torch(dtype), device='cuda')) # compare context - ctx_ids = torch.randint(100, (batch_size, input_len)).int().cuda() - ctx_last_token_ids = input_len * torch.ones( - (batch_size), dtype=torch.int32, device='cuda') + ctx_len = torch.randint(1, + input_len + 1, + size=(batch_size, ), + dtype=torch.int32, + device='cuda') + ctx_len[0] = input_len + ctx_ids = [ + torch.randint(100, + size=(ctx_len[i], ), + dtype=torch.int32, + device='cuda') for i in range(batch_size) + ] + paddings = [ + torch.ones(input_len - l, dtype=torch.int32, device='cuda') * pad_id + for l in ctx_len + ] + ctx_ids = [torch.cat([x, pad]) for x, pad in zip(ctx_ids, paddings)] + ctx_ids = torch.stack(ctx_ids) + ctx_last_token_ids = ctx_len.detach().clone() + ctx_conv_token_ids = torch.zeros((batch_size, ), + dtype=torch.int32, + device='cuda') + ctx_host_request_types = torch.tensor([0] * batch_size, dtype=torch.int32) - infer_params = InferenceParams(max_seqlen=input_len + output_len, - max_batch_size=batch_size) + infer_params = [ + InferenceParams(max_seqlen=input_len + output_len, max_batch_size=1) + for i in range(batch_size) + ] with torch.no_grad(): - hf_outputs = hf_mamba.forward(ctx_ids, - inference_params=infer_params) - infer_params.seqlen_offset += ctx_ids.shape[1] + hf_outputs = [] + for i in range(batch_size): + hf_output = hf_mamba.forward(ctx_ids[i:i + 1, 0:ctx_len[i]], + inference_params=infer_params[i]) + hf_outputs.append(hf_output.logits[:, -1, :]) + infer_params[i].seqlen_offset += ctx_len[i] torch.cuda.synchronize() - ref = hf_outputs.logits[:, -1, :] + ref = torch.concat(hf_outputs, dim=0) ctx_buffer = { 'input_ids': ctx_ids, 'last_token_ids': ctx_last_token_ids, 'host_request_types': ctx_host_request_types, + 'conv_token_ids': ctx_conv_token_ids, } for idx in range(hf_config.n_layer): - 
conv_state_shape = (batch_size, mamba_d_inner, mamba_d_conv - 1) conv_state = torch.zeros(conv_state_shape, dtype=str_dtype_to_torch(dtype), device='cuda') @@ -239,21 +260,27 @@ def test_mamba(self, gemm_plugin, dtype): # compare generation step1_id = torch.randint(100, (batch_size, 1)).int().cuda() - gen_last_token_ids = torch.zeros((batch_size), - dtype=torch.int32, - device='cuda') + gen_last_token_ids = torch.ones((batch_size), + dtype=torch.int32, + device='cuda') + gen_conv_token_ids = ctx_last_token_ids.detach().clone() gen_host_request_types = torch.tensor([1] * batch_size, dtype=torch.int32) with torch.no_grad(): - hf_outputs = hf_mamba.forward(step1_id, - inference_params=infer_params) - infer_params.seqlen_offset += step1_id.shape[1] + hf_outputs = [] + for i in range(batch_size): + hf_output = hf_mamba.forward(step1_id[i:i + 1, ], + inference_params=infer_params[i]) + hf_outputs.append(hf_output.logits[:, -1, :]) + infer_params[i].seqlen_offset += step1_id.shape[1] torch.cuda.synchronize() - ref = hf_outputs.logits[:, -1, :] + ref = torch.concat(hf_outputs, dim=0) + step1_buffer = { 'input_ids': step1_id, 'last_token_ids': gen_last_token_ids, 'host_request_types': gen_host_request_types, + 'conv_token_ids': gen_conv_token_ids, } for idx in range(hf_config.n_layer): step1_buffer[f'past_conv_state_{idx}'] = present_conv_states[idx] diff --git a/tests/model/test_mistral.py b/tests/model/test_mistral.py index 576c2374e..b6621d276 100644 --- a/tests/model/test_mistral.py +++ b/tests/model/test_mistral.py @@ -78,7 +78,6 @@ def _gen_tensorrt_llm_network(self, network, hf_mistral, }, 'use_parallel_embedding': False, 'embedding_sharding_dim': 0, - 'use_prompt_tuning': False, 'moe_num_experts': 0, 'moe_top_k': 0, 'moe_tp_mode': 1, @@ -501,7 +500,6 @@ def print_layers(m: tensorrt_llm.models.LLaMAForCausalLM): }, 'use_parallel_embedding': use_parallel_embedding, 'embedding_sharding_dim': embedding_sharding_dim, - 'use_prompt_tuning': False, 'moe_num_experts': 0, 'moe_top_k': 0, 'moe_tp_mode': 1, diff --git a/tests/model_api/test_model_api_multi_gpu.py b/tests/model_api/test_model_api_multi_gpu.py index d607a7612..98b08ef38 100644 --- a/tests/model_api/test_model_api_multi_gpu.py +++ b/tests/model_api/test_model_api_multi_gpu.py @@ -15,13 +15,14 @@ from tensorrt_llm.auto_parallel.config import (AutoParallelConfig, infer_cluster_key) from tensorrt_llm.builder import BuildConfig, build -from tensorrt_llm.executor import GenerationExecutor +from tensorrt_llm.executor import GenerationExecutorWorker from tensorrt_llm.hlapi.utils import print_traceback_on_error from tensorrt_llm.models import LLaMAForCausalLM sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from utils import llm_data from utils.llm_data import llm_models_root +from utils.util import force_ampere # MPIPoolExecutor only serializes function name and let workers find it in Python path. # Since all tests are not installed in Python path, workers will fail. 
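The hunk that follows moves this multi-GPU test from constructing `GenerationExecutor` directly onto the `GenerationExecutorWorker` context manager combined with `block_subordinates()`. A minimal sketch of that usage pattern, distilled from the hunk itself; `engine_dir`, `tokenizer_dir`, and `prompts` are placeholder values, and the helper name is hypothetical:

```
from tensorrt_llm.executor import GenerationExecutorWorker


def run_on_two_ranks(engine_dir, tokenizer_dir, prompts):
    # Launched under MPI with one process per rank; every rank enters the
    # context manager, but only rank 0 continues past block_subordinates().
    with GenerationExecutorWorker(engine_dir, tokenizer_dir) as executor:
        executor.block_subordinates()
        # Rank 0 drives generation; the other ranks serve requests and are
        # released when the `with` block exits.
        return [output.text for output in executor.generate(prompts, 10)]
```

Compared with the previous direct construction, the context manager makes resource release explicit when the block exits, which is the same motivation noted in the model-level API test further below.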
@@ -99,18 +100,20 @@ def build_and_run_tp2(rank, model_name, engine_dir, use_auto_parallel): engine.save(engine_dir) mpi_barrier() tensorrt_llm.logger.warning(f"Build finished for rank {rank}") - executor = GenerationExecutor(engine_dir, tokenizer_dir) - mpi_barrier() - for idx, output in enumerate(executor.generate(input_text, 10)): - tensorrt_llm.logger.info(f"{rank} input: {input_text[idx]}") - tensorrt_llm.logger.info(f"{rank} output: {output.text}") - assert output.text.endswith( - expected_output[idx] - ), f"Expecting {expected_output[idx]}, got {output.text}" + with GenerationExecutorWorker(engine_dir, tokenizer_dir) as executor: + executor.block_subordinates() + + for idx, output in enumerate(executor.generate(input_text, 10)): + tensorrt_llm.logger.info(f"{rank} input: {input_text[idx]}") + tensorrt_llm.logger.info(f"{rank} output: {output.text}") + assert output.text.endswith( + expected_output[idx] + ), f"Expecting {expected_output[idx]}, got {output.text}" mpi_barrier() return True +@force_ampere @pytest.mark.parametrize("use_auto_parallel", [True, False], ids=["enable_auto_parallel", "disable_auto_parallel"]) @pytest.mark.parametrize("model_name", diff --git a/tests/model_api/test_model_level_api.py b/tests/model_api/test_model_level_api.py index c7bf7f619..13daf0fb1 100644 --- a/tests/model_api/test_model_level_api.py +++ b/tests/model_api/test_model_level_api.py @@ -13,7 +13,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from utils.llm_data import llm_models_root -from utils.util import skip_pre_ampere +from utils.util import force_ampere, skip_pre_ampere tensorrt_llm.logger.set_level('verbose') @@ -41,7 +41,7 @@ def workspace(suffix, prefix="./trtllm_workspace"): # 233s on ipp1-1197: loading weights 37s, network/engine 27s, save engine: 35s, load engine (14GB) about 100s @profile("save-and-load") -@skip_pre_ampere +@force_ampere def test_save_load(): '''When the engine_dir parameter of to_trt and generate is not None to_trt() saves the engine to disk. @@ -51,50 +51,33 @@ def test_save_load(): max_batch_size, max_isl, max_osl = 8, 256, 256 hf_model_dir = llm_models_root() / "llama-models/llama-7b-hf" tokenizer_dir = hf_model_dir + with workspace("llama-save-load") as engine_dir: # build and run by one llama object llama = LLaMAForCausalLM.from_hugging_face(hf_model_dir, 'float16') - engine = build( - llama, - BuildConfig(max_batch_size=max_batch_size, - max_input_len=max_isl, - max_output_len=max_osl)) + build_config = BuildConfig(max_batch_size=max_batch_size, + max_input_len=max_isl, + max_output_len=max_osl, + plugin_config=llama.default_plugin_config()) + build_config.plugin_config.gemm_plugin = 'float16' # faster build + engine = build(llama, build_config) engine.save(engine_dir) - executor = GenerationExecutor(engine_dir, tokenizer_dir) - for idx, output in enumerate( - executor.generate(input_text, [10] * len(input_text))): - tensorrt_llm.logger.info(f"Input: {input_text[idx]}") - tensorrt_llm.logger.info(f'Output: {output.text}') - # note the output.text contains everything from the input, so only compare the suffix here. 
- assert output.text.endswith( - expected_output[idx] - ), f"Expecting and got:'{expected_output[idx]}' Got: '{output.text}'" - - -# 76s on ipp1-1197, loading weights 18s (varies based on network speed), network/engine creation 27s -@profile("all-in-one-step") -@skip_pre_ampere -def test_all_in_one_step(): - '''Do not save the engine, all in one LLaMAForCausalLM object - ''' - max_batch_size, max_isl, max_osl = 8, 256, 256 - hf_model_dir = llm_models_root() / "llama-models/llama-7b-hf" - - # build and run by one llama object - llama = LLaMAForCausalLM.from_hugging_face(hf_model_dir, 'float16') - build( - llama, - BuildConfig(max_batch_size=max_batch_size, - max_input_len=max_isl, - max_output_len=max_osl)) - - # TODO (tali): init the generation executor from the in-memory engine - # This is depending on WIP MR https://gitlab-master.nvidia.com/ftp/tekit/-/merge_requests/2785 + # use context manager to make sure the __exit__ can release the resources immediately + with GenerationExecutor.create(engine_dir, tokenizer_dir) as executor: + for idx, output in enumerate( + executor.generate(input_text, + max_new_tokens=[10] * len(input_text))): + tensorrt_llm.logger.info(f"Input: {input_text[idx]}") + tensorrt_llm.logger.info(f'Output: {output.text}') + # note the output.text contains everything from the input, so only compare the suffix here. + assert output.text.endswith( + expected_output[idx] + ), f"Expecting and got:'{expected_output[idx]}' Got: '{output.text}'" @profile(tag="fake-weights") -@skip_pre_ampere +@force_ampere def test_high_level_fake_weights(): '''sanity to make sure the flow works. The key is "skip_loading_weights" param ''' @@ -109,65 +92,65 @@ def test_high_level_fake_weights(): llama = LLaMAForCausalLM.from_hugging_face(hf_model_dir, 'float16', skip_loading_weights=True) - build( - llama, - BuildConfig(max_batch_size=max_batch_size, - max_input_len=max_isl, - max_output_len=max_osl)) + build_config = BuildConfig(max_batch_size=max_batch_size, + max_input_len=max_isl, + max_output_len=max_osl, + plugin_config=llama.default_plugin_config()) + build_config.plugin_config.gemm_plugin = 'float16' # faster build + build(llama, build_config) @skip_pre_ampere -def _test_inflight_batching(): - # TODO[chunweiy]: Enable it later +def test_inflight_batching(): max_batch_size, max_isl, max_osl = 8, 256, 256 hf_model_dir = llm_models_root() / "llama-models/llama-7b-hf" tokenizer_dir = hf_model_dir llama = LLaMAForCausalLM.from_hugging_face(hf_model_dir, 'float16') - engine = build( - llama, - BuildConfig(max_batch_size=max_batch_size, - max_input_len=max_isl, - max_output_len=max_osl)) + build_config = BuildConfig(max_batch_size=max_batch_size, + max_input_len=max_isl, + max_output_len=max_osl) + build_config.plugin_config.gemm_plugin = 'float16' # faster build + engine = build(llama, build_config) + engine_dir = "llama-ifb" engine_temp = tempfile.TemporaryDirectory(engine_dir) engine_dir = engine_temp.name engine.save(engine_dir) async def main(): - async_engine = GenerationExecutor(engine_dir, tokenizer_dir) - - async def generate_and_print(idx, inp): - result = async_engine.generate_async(inp, - streaming=False, - max_new_tokens=10) - await result.aresult() - tensorrt_llm.logger.info(result.text) - assert result.text.endswith(expected_output[idx]) - - output = "" - async for stream in async_engine.generate_async(inp, - streaming=True, - max_new_tokens=10): - output += stream.text + ' ' - tensorrt_llm.logger.info( - f"prompt: '{inp}', generation: '{output}'") - - loop = 
asyncio.get_running_loop() - tasks = [] - # submit many request concurrently - for idx, inp in enumerate(input_text): - task = loop.create_task(generate_and_print(idx, inp)) - tasks.append(task) - - # wait all task done - await asyncio.gather(*tasks) + with GenerationExecutor.create(engine_dir, + tokenizer_dir) as async_engine: + + async def generate_and_print(idx, inp): + result = async_engine.generate_async(inp, + streaming=False, + max_new_tokens=10) + await result.aresult() + tensorrt_llm.logger.info(result.text) + assert result.text.endswith(expected_output[idx]) + + output = "" + async for stream in async_engine.generate_async( + inp, streaming=True, max_new_tokens=10): + output += stream.text + ' ' + tensorrt_llm.logger.info( + f"prompt: '{inp}', generation: '{output}'") + + loop = asyncio.get_running_loop() + tasks = [] + # submit many request concurrently + for idx, inp in enumerate(input_text): + task = loop.create_task(generate_and_print(idx, inp)) + tasks.append(task) + + # wait all task done + await asyncio.gather(*tasks) asyncio.run(main()) if __name__ == "__main__": - test_all_in_one_step() - test_high_level_fake_weights() test_save_load() test_inflight_batching() + test_high_level_fake_weights() diff --git a/tests/model_api/test_model_quantization.py b/tests/model_api/test_model_quantization.py index fd2e63d04..e0a7b84c5 100644 --- a/tests/model_api/test_model_quantization.py +++ b/tests/model_api/test_model_quantization.py @@ -1,21 +1,23 @@ import os import sys import tempfile +from pathlib import Path import tensorrt_llm +import tensorrt_llm.quantization.mode as quant_algo from tensorrt_llm.builder import BuildConfig, build from tensorrt_llm.executor import GenerationExecutor from tensorrt_llm.models import LLaMAForCausalLM -from tensorrt_llm.quantization.mode import QuantMode +from tensorrt_llm.models.modeling_utils import QuantizationConfig sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from utils.llm_data import llm_models_root -from utils.util import skip_no_ammo, skip_pre_ada, skip_pre_ampere +from utils.util import force_ampere, skip_no_ammo, skip_pre_ada tensorrt_llm.logger.set_level('info') -@skip_pre_ampere +@force_ampere @skip_no_ammo def test_int4_awq_quantization(): input_text = [ @@ -25,19 +27,14 @@ def test_int4_awq_quantization(): max_batch_size, max_isl, max_osl = 8, 256, 256 hf_model_dir = llm_models_root() / "llama-models/llama-7b-hf" tokenizer_dir = hf_model_dir - - quant_mode_int4_awq = QuantMode.from_description(quantize_weights=True, - quantize_activations=False, - per_token=False, - per_channel=False, - per_group=True, - use_int4_weights=True) - - hf_model_dir = llm_models_root() / "llama-models/llama-7b-hf" - llama = LLaMAForCausalLM.from_hugging_face(hf_model_dir, - 'float16', - quant_mode=quant_mode_int4_awq, - quantize_lm_head=True) + checkpoint_dir = tempfile.TemporaryDirectory("llama-checkpoint").name + quant_config = QuantizationConfig(quant_algo.W4A16_AWQ) + LLaMAForCausalLM.quantize(hf_model_dir, + checkpoint_dir, + quant_config=quant_config, + calib_batches=32, + calib_batch_size=32) + llama = LLaMAForCausalLM.from_checkpoint(checkpoint_dir) engine = build( llama, BuildConfig(max_batch_size=max_batch_size, @@ -48,11 +45,12 @@ def test_int4_awq_quantization(): engine_temp = tempfile.TemporaryDirectory(engine_dir) engine_dir = engine_temp.name engine.save(engine_dir) - executor = GenerationExecutor(engine_dir, tokenizer_dir) - for idx, output in enumerate(executor.generate(input_text, 10)): - print(f"Input: 
{input_text[idx]}") - print(f'Output: {output.text}') - # TODO: TRTLLM-185, check the score when the test infra is ready, hard coded value is not stable, cause flaky tests in L0 + with GenerationExecutor.create(Path(engine_dir), tokenizer_dir) as executor: + for idx, output in enumerate( + executor.generate(input_text, max_new_tokens=10)): + print(f"Input: {input_text[idx]}") + print(f'Output: {output.text}') + # TODO: TRTLLM-185, check the score when the test infra is ready, hard coded value is not stable, cause flaky tests in L0 @skip_pre_ada @@ -66,28 +64,30 @@ def test_fp8_quantization(): hf_model_dir = llm_models_root() / "llama-models/llama-7b-hf" tokenizer_dir = hf_model_dir - quant_mode = QuantMode(0) - quant_mode = quant_mode.set_fp8_qdq() - quant_mode = quant_mode.set_fp8_kv_cache() + checkpoint_dir = tempfile.TemporaryDirectory("llama-checkpoint").name + quant_config = QuantizationConfig(quant_algo.FP8) + LLaMAForCausalLM.quantize(hf_model_dir, + checkpoint_dir, + quant_config=quant_config, + calib_batches=32) + llama = LLaMAForCausalLM.from_checkpoint(checkpoint_dir) - hf_model_dir = llm_models_root() / "llama-models/llama-7b-hf" - llama = LLaMAForCausalLM.from_hugging_face(hf_model_dir, - 'float16', - quant_mode=quant_mode) engine = build( llama, BuildConfig(max_batch_size=max_batch_size, max_input_len=max_isl, - max_output_len=max_osl)) + max_output_len=max_osl, + strongly_typed=True)) engine_dir = "llama-fp8-quantized" engine_temp = tempfile.TemporaryDirectory(engine_dir) engine_dir = engine_temp.name engine.save(engine_dir) - executor = GenerationExecutor(engine_dir, tokenizer_dir) - for idx, output in enumerate(executor.generate(input_text, 10)): - print(f"Input: {input_text[idx]}") - print(f'Output: {output.text}') - # TODO: TRTLLM-185, check the score when the test infra is ready, hard coded value is not stable, cause flaky tests in L0 + with GenerationExecutor.create(Path(engine_dir), tokenizer_dir) as executor: + for idx, output in enumerate( + executor.generate(input_text, max_new_tokens=10)): + print(f"Input: {input_text[idx]}") + print(f'Output: {output.text}') + # TODO: TRTLLM-185, check the score when the test infra is ready, hard coded value is not stable, cause flaky tests in L0 if __name__ == "__main__": diff --git a/tests/quantization/test_quant.py b/tests/quantization/test_quant.py index a6b88e342..0bf65bdef 100644 --- a/tests/quantization/test_quant.py +++ b/tests/quantization/test_quant.py @@ -15,109 +15,124 @@ import unittest from tensorrt_llm.layers import ColumnLinear, RowLinear -from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models import GPTLMHeadModel, quantize_model +from tensorrt_llm.models import GPTForCausalLM, PretrainedConfig from tensorrt_llm.quantization import QuantMode from tensorrt_llm.quantization.layers import (SmoothQuantAttention, SmoothQuantLayerNorm, SmoothQuantMLP, WeightOnlyQuantColumnLinear, WeightOnlyQuantRowLinear) +from tensorrt_llm.quantization.mode import W8A8_SQ_PER_TENSOR_PLUGIN +from tensorrt_llm.quantization.quantize import quantize class TestQuant(unittest.TestCase): def test_weight_only_quant(self): mode = QuantMode.use_weight_only() - - model = GPTLMHeadModel(num_layers=2, - num_heads=12, - hidden_size=768, - vocab_size=51200, - hidden_act='relu', - max_position_embeddings=1024, - dtype='float16') - - quant_model = quantize_model(model, mode) + config = { + 'architecture': 'GPTForCausalLM', + 'dtype': 'float16', + 'num_hidden_layers': 2, + 'num_attention_heads': 12, + 'hidden_size': 768, + 'vocab_size': 
51200, + 'max_position_embeddings': 1024, + 'hidden_act': 'relu', + } + config = PretrainedConfig.from_dict(config) + model = GPTForCausalLM(config) + + quant_model = quantize(model, mode) self.assertTrue(hasattr(quant_model, 'quant_mode')) self.assertTrue( - isinstance(quant_model.layers[0].attention.qkv, + isinstance(quant_model.transformer.layers[0].attention.qkv, WeightOnlyQuantColumnLinear)) self.assertTrue( - isinstance(quant_model.layers[0].attention.dense, + isinstance(quant_model.transformer.layers[0].attention.dense, WeightOnlyQuantRowLinear)) self.assertTrue( - isinstance(quant_model.layers[0].mlp.fc, + isinstance(quant_model.transformer.layers[0].mlp.fc, WeightOnlyQuantColumnLinear)) self.assertTrue( - isinstance(quant_model.layers[0].mlp.proj, + isinstance(quant_model.transformer.layers[0].mlp.proj, WeightOnlyQuantRowLinear)) self.assertTrue( - isinstance(quant_model.layers[1].attention.qkv, + isinstance(quant_model.transformer.layers[1].attention.qkv, WeightOnlyQuantColumnLinear)) self.assertTrue( - isinstance(quant_model.layers[1].attention.dense, + isinstance(quant_model.transformer.layers[1].attention.dense, WeightOnlyQuantRowLinear)) self.assertTrue( - isinstance(quant_model.layers[1].mlp.fc, + isinstance(quant_model.transformer.layers[1].mlp.fc, WeightOnlyQuantColumnLinear)) self.assertTrue( - isinstance(quant_model.layers[1].mlp.proj, + isinstance(quant_model.transformer.layers[1].mlp.proj, WeightOnlyQuantRowLinear)) self.assertTrue(isinstance(quant_model.lm_head, ColumnLinear)) def test_weight_only_quant_exclude_modules(self): mode = QuantMode.use_weight_only() - - model = GPTLMHeadModel(num_layers=1, - num_heads=12, - hidden_size=768, - vocab_size=51200, - hidden_act='relu', - max_position_embeddings=1024, - dtype='float16') - - quant_model = quantize_model(model, - mode, - exclude_modules=['fc', 'dense']) + config = { + 'architecture': 'GPTForCausalLM', + 'dtype': 'float16', + 'num_hidden_layers': 1, + 'num_attention_heads': 12, + 'hidden_size': 768, + 'vocab_size': 51200, + 'max_position_embeddings': 1024, + 'hidden_act': 'relu', + } + config = PretrainedConfig.from_dict(config) + model = GPTForCausalLM(config) + + quant_model = quantize(model, mode, exclude_modules=['fc', 'dense']) self.assertTrue(hasattr(quant_model, 'quant_mode')) self.assertTrue( - isinstance(quant_model.layers[0].attention.qkv, + isinstance(quant_model.transformer.layers[0].attention.qkv, WeightOnlyQuantColumnLinear)) self.assertTrue( - isinstance(quant_model.layers[0].attention.dense, RowLinear)) - self.assertTrue(isinstance(quant_model.layers[0].mlp.fc, ColumnLinear)) + isinstance(quant_model.transformer.layers[0].attention.dense, + RowLinear)) + self.assertTrue( + isinstance(quant_model.transformer.layers[0].mlp.fc, ColumnLinear)) self.assertTrue( - isinstance(quant_model.layers[0].mlp.proj, + isinstance(quant_model.transformer.layers[0].mlp.proj, WeightOnlyQuantRowLinear)) self.assertTrue( isinstance(quant_model.lm_head, WeightOnlyQuantColumnLinear)) def test_convert_GPT_to_smooth_quant(self): - gpt = GPTLMHeadModel(num_layers=1, - num_heads=1, - hidden_size=128, - vocab_size=1024, - hidden_act='gelu', - max_position_embeddings=256, - dtype='float16', - mapping=Mapping(world_size=1, rank=0, tp_size=1)) + config = { + 'architecture': 'GPTForCausalLM', + 'dtype': 'float16', + 'num_hidden_layers': 1, + 'num_attention_heads': 1, + 'hidden_size': 128, + 'vocab_size': 1024, + 'max_position_embeddings': 256, + 'hidden_act': 'gelu', + } + config = PretrainedConfig.from_dict(config) + model = 
GPTForCausalLM(config) quant_mode = QuantMode.use_smooth_quant() - sq_gpt = quantize_model(gpt, quant_mode) - for layer in sq_gpt.layers: + quant_model = quantize(model, + quant_mode, + quant_algo=W8A8_SQ_PER_TENSOR_PLUGIN) + for layer in quant_model.transformer.layers: assert isinstance(layer.input_layernorm, SmoothQuantLayerNorm) assert isinstance(layer.post_layernorm, SmoothQuantLayerNorm) assert isinstance(layer.mlp, SmoothQuantMLP) assert isinstance(layer.attention, SmoothQuantAttention) - assert sq_gpt.quant_mode == quant_mode + assert quant_model.quant_mode == quant_mode if __name__ == '__main__': diff --git a/tests/test_export.py b/tests/test_export.py index 3bfcd5eb5..83c3e67f4 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -20,8 +20,8 @@ import torch sys.path.append(str(Path(__file__).parent.resolve() / - "../examples/gpt/utils")) # more precise, avoid confusion -from convert import generate_int8 + "../examples/gpt")) # more precise, avoid confusion +from convert_checkpoint import generate_int8 def dist(x, y): diff --git a/tests/test_layer.py b/tests/test_layer.py index 7aad9ef84..14a3258c4 100644 --- a/tests/test_layer.py +++ b/tests/test_layer.py @@ -925,15 +925,15 @@ def test_attention(self, atol=a_tol, verbose=True) - @parameterized.expand([(1, 16, 1024, 16, 'context', 'float32'), - (1, 16, 1024, 16, 'context', 'float16'), - (1, 16, 1024, 16, 'context', 'bfloat16'), - (1, 1, 1024, 16, 'generation', 'float32'), - (1, 1, 1024, 16, 'generation', 'float16'), - (1, 1, 1024, 16, 'generation', 'bfloat16')], + @parameterized.expand([(3, 16, 1, 1024, 16, 'context', 'float32'), + (3, 16, 1, 1024, 16, 'context', 'float16'), + (3, 16, 1, 1024, 16, 'context', 'bfloat16'), + (3, 16, 1, 1024, 16, 'generation', 'float32'), + (3, 16, 1, 1024, 16, 'generation', 'float16'), + (3, 16, 1, 1024, 16, 'generation', 'bfloat16')], name_func=unittest_name_func) - def test_mamba(self, batch_size, seq_len, d_model, d_state, req_type, - dtype): + def test_mamba(self, batch_size, in_seq_len, out_seq_len, d_model, d_state, + req_type, dtype): # Skip tests that are not supported in pre-ampere architecture skip_bf16_pre_ampere(dtype) @@ -946,13 +946,33 @@ def test_mamba(self, batch_size, seq_len, d_model, d_state, req_type, conv_bias = True bias = False d_inner = int(expand * d_model) - seqlen_offset = 0 if req_type == 'context' else seq_len + seqlen_offset = 0 if req_type == 'context' else in_seq_len + seq_len = in_seq_len if req_type == 'context' else out_seq_len # test data torch_dtype = str_dtype_to_torch(dtype) mean = 0.0 std_dev = 0.1 if dtype == "float32" else 0.05 + if req_type == 'context': + last_token_ids = torch.randint(1, + in_seq_len + 1, + size=(batch_size, ), + dtype=torch.int32, + device=device) + last_token_ids[0] = in_seq_len + else: + last_token_ids = torch.ones(size=[batch_size], + dtype=torch.int32, + device=device) + offsets = last_token_ids.view([batch_size, 1, 1]) + conv_indices = torch.arange(0, + d_conv - 1, + dtype=torch.int32, + device=device).view([1, 1, d_conv - 1]) + conv_indices = conv_indices.expand([batch_size, d_inner, d_conv - 1 + ]) + offsets + hidden_states = torch.empty(size=[batch_size, seq_len, d_model], dtype=torch_dtype, device=device) @@ -980,10 +1000,9 @@ def test_mamba(self, batch_size, seq_len, d_model, d_state, req_type, output = torch.zeros(size=[batch_size, seq_len, d_model], dtype=torch_dtype, device=device) - present_conv_state = torch.zeros( - size=[batch_size, d_inner, d_conv - 1 + seq_len], - dtype=torch_dtype, - device=device) + 
present_conv_state = torch.zeros(size=[batch_size, d_inner, d_conv - 1], + dtype=torch_dtype, + device=device) hidden_states_ref = hidden_states.detach().clone() if req_type == 'context': @@ -996,7 +1015,7 @@ def test_mamba(self, batch_size, seq_len, d_model, d_state, req_type, dtype=torch_dtype, device=device), conv_state), dim=2).detach().clone() - ssm_state_ref = ssm_state.detach().clone().permute(0, 2, 1).contiguous() + ssm_state_ref = ssm_state.detach().clone() # get torch layer mamba_torch = mamba_ref(d_model, @@ -1020,7 +1039,7 @@ def test_mamba(self, batch_size, seq_len, d_model, d_state, req_type, D = torch.randn(d_inner, device=device) dt_bias = torch.rand(d_inner, device=device) - 4.0 - mamba_torch.A.data = A.detach().clone().permute(1, 0).contiguous() + mamba_torch.A.data = A.detach().clone() mamba_torch.D.data = D.detach().clone() mamba_torch.dt_proj.bias.data = dt_bias.detach().clone() @@ -1044,6 +1063,14 @@ def test_mamba(self, batch_size, seq_len, d_model, d_state, req_type, name='host_request_types', shape=host_request_types.shape, dtype=tensorrt_llm.str_dtype_to_trt('int32')) + conv_indices_tensor = Tensor( + name='conv_indices', + shape=conv_indices.shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + last_token_ids_tensor = Tensor( + name='last_token_ids', + shape=last_token_ids.shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) mamba_layer = tensorrt_llm.layers.Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, @@ -1075,7 +1102,8 @@ def test_mamba(self, batch_size, seq_len, d_model, d_state, req_type, mamba_torch.dt_proj.weight.detach().cpu()) outputs = mamba_layer(hidden_states_tensor, conv_state_tensor, - ssm_state_tensor, host_request_types_tensor) + ssm_state_tensor, host_request_types_tensor, + conv_indices_tensor, last_token_ids_tensor) net._mark_output(outputs[0], 'output', dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -1091,7 +1119,9 @@ def test_mamba(self, batch_size, seq_len, d_model, d_state, req_type, 'hidden_states': hidden_states, 'conv_state': conv_state, 'ssm_state': ssm_state, - 'host_request_types': host_request_types + 'host_request_types': host_request_types, + 'conv_indices': conv_indices, + 'last_token_ids': last_token_ids, } outputs = { 'output': output, @@ -1108,25 +1138,43 @@ def test_mamba(self, batch_size, seq_len, d_model, d_state, req_type, # pytorch run out_ref, conv_state_ref, ssm_state_ref = mamba_torch( - hidden_states_ref, conv_state_ref, ssm_state_ref, seqlen_offset) - - ssm_state_cpu = outputs['present_ssm_state'].to(torch.float32).cpu() - ssm_state_cpu = ssm_state_cpu.permute(0, 2, 1).contiguous() - dtype_atol = {"float16": 5e-3, "float32": 2e-3, "bfloat16": 5e-2} - np.testing.assert_allclose( - out_ref.detach().to(torch.float32).cpu().numpy(), - outputs['output'].to(torch.float32).cpu().numpy(), - atol=dtype_atol[dtype]) - - np.testing.assert_allclose( - conv_state_ref[:, :, 1:].detach().to(torch.float32).cpu().numpy(), - outputs['present_conv_state'][:, :, - -3:].to(torch.float32).cpu().numpy(), - atol=dtype_atol[dtype]) - - np.testing.assert_allclose(ssm_state_ref.detach().to( - torch.float32).cpu().numpy(), - ssm_state_cpu.numpy(), + hidden_states_ref, last_token_ids, conv_state_ref, ssm_state_ref, + seqlen_offset) + + dtype_atol = {"float16": 5e-3, "float32": 3e-3, "bfloat16": 5e-2} + + # get out_mask + if req_type == 'context': + out_mask = torch.zeros(batch_size, seq_len, device=device) + for i in range(batch_size): + for j in range(last_token_ids[i]): + out_mask[i, j] = 1 + out_mask = 
out_mask.unsqueeze(2).expand( + [batch_size, seq_len, d_model]) + else: + out_mask = torch.ones(batch_size, seq_len, d_model, device=device) + + # compare out diff + out_ref = (out_ref * out_mask).detach().to(torch.float32).cpu().numpy() + outputs['output'][out_mask == 0] = 0 + out_trt_llm = outputs['output'].to(torch.float32).cpu().numpy() + np.testing.assert_allclose(out_ref, out_trt_llm, atol=dtype_atol[dtype]) + + # compare conv state diff + conv_state_ref = conv_state_ref[:, :, 1:].detach().to( + torch.float32).cpu().numpy() + conv_state_trt_llm = outputs['present_conv_state'].detach().to( + torch.float32).cpu().numpy() + np.testing.assert_allclose(conv_state_ref, + conv_state_trt_llm, + atol=dtype_atol[dtype]) + + # compare ssm state diff + ssm_state_ref = ssm_state_ref.detach().to(torch.float32).cpu().numpy() + ssm_state_trt_llm = outputs['present_ssm_state'] + ssm_state_trt_llm = ssm_state_trt_llm.to(torch.float32).cpu().numpy() + np.testing.assert_allclose(ssm_state_ref, + ssm_state_trt_llm, atol=dtype_atol[dtype]) diff --git a/tests/test_llama_conversion.sh b/tests/test_llama_conversion.sh index 063dc860f..972230321 100755 --- a/tests/test_llama_conversion.sh +++ b/tests/test_llama_conversion.sh @@ -70,7 +70,7 @@ test_wo_int8() { } test_sq() { - python3 convert_checkpoint.py --model_dir ${MODEL} --output_dir ./tllm_checkpoint/sq --dtype float16 --smoothquant 0.5 + python3 convert_checkpoint.py --model_dir ${MODEL} --output_dir ./tllm_checkpoint/sq --dtype float16 --smoothquant 0.5 --int8_kv_cache trtllm-build --checkpoint_dir ./tllm_checkpoint/sq --output_dir ./trt_engines/sq --gemm_plugin float16 python ../summarize.py --test_trt_llm --hf_model_dir ${MODEL} --data_type fp16 --engine_dir trt_engines/sq --test_hf } @@ -84,7 +84,8 @@ test_gptq() { --use_weight_only \ --weight_only_precision int4_gptq \ --per_group \ - --tp_size 2 + --tp_size 2 \ + --workers 2 trtllm-build --checkpoint_dir ./tllm_checkpoint/2gpu_gptq \ --output_dir ./trt_engines/gptq \ @@ -103,13 +104,13 @@ test_lora() { python convert_checkpoint.py --model_dir /home/scratch.trt_llm_data/llm-models/llama-models-v2/llama-v2-13b-hf \ --output_dir ./tllm_checkpoint/2gpu_lora \ --dtype float16 \ - --tp_size 2 \ - --hf_lora_dir ${hf_lora_dir} + --tp_size 2 trtllm-build --checkpoint_dir ./tllm_checkpoint/2gpu_lora \ --output_dir ./trt_engines/llama-v2-13b-with-lora \ --gemm_plugin float16 \ --lora_plugin float16 \ + --lora_dir ${hf_lora_dir} \ --max_batch_size 1 \ --max_input_len 512 \ --max_output_len 50 @@ -118,8 +119,7 @@ test_lora() { python ../run.py --engine_dir ./trt_engines/llama-v2-13b-with-lora \ --max_output_len 50 \ --tokenizer_dir ${hf_lora_dir} \ - --input_text "今天天气很好,我到公园的时后," \ - --lora_dir ${hf_lora_dir} \ + --input_text "今天天气很好,我到公园的时候," \ --lora_task_uids 0 \ --no_add_special_tokens \ --use_py_session @@ -128,7 +128,7 @@ test_lora() { test_mixtral() { python convert_checkpoint.py --model_dir /home/scratch.trt_llm_data/llm-models/Mixtral-8x7B-v0.1/ \ --output_dir ./tllm_checkpoint/mixtral_2gpu \ - --dtype float16 --load_model_on_cpu \ + --dtype float16 \ --pp_size 2 \ trtllm-build --checkpoint_dir ./tllm_checkpoint/mixtral_2gpu \ diff --git a/tests/test_model_dtype.py b/tests/test_model_dtype.py index 8fa130628..95f2fe3bf 100644 --- a/tests/test_model_dtype.py +++ b/tests/test_model_dtype.py @@ -19,7 +19,7 @@ import tensorrt_llm from tensorrt_llm._utils import str_dtype_to_np -from tensorrt_llm.models import GPTLMHeadModel +from tensorrt_llm.models import GPTForCausalLM, PretrainedConfig class 
TestModelDtype(unittest.TestCase): @@ -27,20 +27,25 @@ class TestModelDtype(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') - @parameterized.expand([(GPTLMHeadModel, 'float32'), - (GPTLMHeadModel, 'bfloat16'), - (GPTLMHeadModel, 'float16')], + @parameterized.expand([(GPTForCausalLM, 'float32'), + (GPTForCausalLM, 'bfloat16'), + (GPTForCausalLM, 'float16')], name_func=unittest_name_func) def test_model_dtype(self, model_cls, dtype): ''' Every parameter in the model should have the same dtype as the model initialized to ''' - tiny_model = model_cls(num_layers=6, - num_heads=4, - hidden_size=128, - vocab_size=128, - hidden_act='relu', - max_position_embeddings=128, - dtype=dtype) + config = { + 'architecture': 'GPTForCausalLM', + 'dtype': dtype, + 'num_hidden_layers': 6, + 'num_attention_heads': 4, + 'hidden_size': 128, + 'vocab_size': 128, + 'max_position_embeddings': 128, + 'hidden_act': 'relu', + } + config = PretrainedConfig.from_dict(config) + tiny_model = model_cls(config) for p in tiny_model.parameter(): self.assertEqual(p.raw_value.dtype, str_dtype_to_np(dtype)) diff --git a/tests/utils/util.py b/tests/utils/util.py index 5ac2ec0ec..d155b0805 100644 --- a/tests/utils/util.py +++ b/tests/utils/util.py @@ -50,6 +50,10 @@ def getSMVersion(): getSMVersion() < 90, reason="This test is not supported in pre-Hopper architecture") +force_ampere = pytest.mark.skipif( + getSMVersion() < 80 or getSMVersion() > 89, + reason="This test is only enabled in Ampere architecture") + def is_bf16(dtype): return dtype == 'bfloat16' or dtype == 'bf16' or dtype == torch.bfloat16
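Several of the test hunks above (`test_gpt.py`, `test_quant.py`, `test_model_dtype.py`) replace the old keyword-argument `GPTLMHeadModel` constructor with a `PretrainedConfig`-driven `GPTForCausalLM`. A condensed sketch of that construction path, using the same small illustrative sizes the tests use:

```
from tensorrt_llm.models import GPTForCausalLM, PretrainedConfig

# Field values are illustrative; they mirror the small configs in the tests.
config = PretrainedConfig.from_dict({
    'architecture': 'GPTForCausalLM',
    'dtype': 'float16',
    'num_hidden_layers': 2,
    'num_attention_heads': 12,
    'hidden_size': 768,
    'vocab_size': 51200,
    'max_position_embeddings': 1024,
    'hidden_act': 'relu',
})
model = GPTForCausalLM(config)

# Decoder blocks are now addressed through model.transformer.layers
# rather than model.layers, hence the updated assertions above.
print(type(model.transformer.layers[0].attention.qkv))
```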
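The quantization-test hunks likewise drop the old `QuantMode`-flag path in favour of quantizing at the checkpoint level and rebuilding from that checkpoint. A rough end-to-end sketch of the flow shown in `test_model_quantization.py`; the model path, prompt, and sizes are placeholders rather than required values:

```
import tempfile
from pathlib import Path

import tensorrt_llm.quantization.mode as quant_algo
from tensorrt_llm.builder import BuildConfig, build
from tensorrt_llm.executor import GenerationExecutor
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.models.modeling_utils import QuantizationConfig

hf_model_dir = "llama-models/llama-7b-hf"  # placeholder path
checkpoint_dir = tempfile.TemporaryDirectory("llama-checkpoint").name

# Quantize the HF weights into a TRT-LLM checkpoint, then rebuild from it.
LLaMAForCausalLM.quantize(hf_model_dir,
                          checkpoint_dir,
                          quant_config=QuantizationConfig(quant_algo.FP8),
                          calib_batches=32)
llama = LLaMAForCausalLM.from_checkpoint(checkpoint_dir)

engine = build(
    llama,
    BuildConfig(max_batch_size=8,
                max_input_len=256,
                max_output_len=256,
                strongly_typed=True))
engine_dir = tempfile.TemporaryDirectory("llama-fp8-quantized").name
engine.save(engine_dir)

# Generation goes through the executor used as a context manager.
with GenerationExecutor.create(Path(engine_dir), hf_model_dir) as executor:
    for output in executor.generate(["example prompt"],  # hypothetical prompt
                                    max_new_tokens=10):
        print(output.text)
```

The W4A16 AWQ variant in the same file follows the identical shape, swapping `quant_algo.FP8` for `quant_algo.W4A16_AWQ` and additionally passing `calib_batch_size`.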