2 changes: 1 addition & 1 deletion .gitmodules
@@ -27,4 +27,4 @@
url = https://gitcode.com/xLLM-AI/spdlog.git
[submodule "third_party/Mooncake"]
path = third_party/Mooncake
url = https://gitcode.com/xLLM-AI/Mooncake.git
url = https://gitcode.com/xLLM-AI/Mooncake.git
73 changes: 71 additions & 2 deletions CMakeLists.txt
@@ -3,6 +3,7 @@ set_property(GLOBAL PROPERTY USE_FOLDERS ON)

option(USE_NPU "Enable NPU support" OFF)
option(USE_MLU "Enable MLU support" OFF)
option(USE_CUDA "Enable CUDA support" OFF)

if(DEVICE_ARCH STREQUAL "ARM")
set(CMAKE_SYSTEM_PROCESSOR aarch64)
@@ -101,7 +102,7 @@ set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON)

if(USE_NPU)
if(USE_NPU OR USE_CUDA)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
elseif(USE_MLU)
@@ -178,6 +179,32 @@ if (DEFINED ENV{DEPENDENCES_ROOT})
message(STATUS "Using DEPENDENCES_ROOT: $ENV{DEPENDENCES_ROOT}")
endif()


# Build TORCH_CUDA_ARCH_LIST
if(USE_CUDA)
# set architecture for CUDA
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 80)
endif()
# Build TORCH_CUDA_ARCH_LIST
set(TORCH_CUDA_ARCH_LIST "")
foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
if(CUDA_ARCH MATCHES "^([0-9])([0-9])a$")
set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}a")
elseif(CUDA_ARCH MATCHES "^([0-9])([0-9])*$")
set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}")
elseif(CUDA_ARCH STREQUAL "native")
set(TORCH_ARCH "Auto")
else()
message(FATAL_ERROR "${CUDA_ARCH} is not supported")
endif()
list(APPEND TORCH_CUDA_ARCH_LIST ${TORCH_ARCH})
endforeach()

message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
message(STATUS "TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}")
endif()
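
# Worked example for the loop above: CMAKE_CUDA_ARCHITECTURES of "80;89;90"
# (the list setup.py passes for --device cuda below) yields a
# TORCH_CUDA_ARCH_LIST of "8.0;8.9;9.0", the dotted compute-capability form
# libtorch expects; "90a" maps to "9.0a" and "native" falls through to "Auto".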

# configure vcpkg
# have to set CMAKE_TOOLCHAIN_FILE before first project call.
# if (DEFINED ENV{VCPKG_ROOT} AND NOT DEFINED CMAKE_TOOLCHAIN_FILE)
@@ -217,7 +244,12 @@ endif()
set(CPPREST_EXCLUDE_WEBSOCKETS ON CACHE BOOL "Exclude websockets functionality." FORCE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format-truncation")

project("xllm" LANGUAGES C CXX)
if(USE_CUDA)
project("xllm" LANGUAGES C CXX CUDA)
find_package(CUDAToolkit REQUIRED)
else()
project("xllm" LANGUAGES C CXX)
endif()

# find_package(CUDAToolkit REQUIRED)

@@ -352,6 +384,43 @@ if(USE_MLU)
)
endif()

if(USE_CUDA)
add_definitions(-DUSE_CUDA)
add_compile_definitions(TORCH_CUDA=1)
set(CMAKE_VERBOSE_MAKEFILE ON)
include_directories(
$ENV{PYTHON_INCLUDE_PATH}
$ENV{PYTORCH_INSTALL_PATH}/include
$ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
)

link_directories(
$ENV{PYTHON_LIB_PATH}
$ENV{PYTORCH_INSTALL_PATH}/lib
$ENV{CUDA_TOOLKIT_ROOT_DIR}/lib64
)

set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -O3)
# The following definitions must be undefined since half-precision operation is required.
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS}
-U__CUDA_NO_HALF_OPERATORS__
-U__CUDA_NO_HALF_CONVERSIONS__
-U__CUDA_NO_HALF2_OPERATORS__
-U__CUDA_NO_BFLOAT16_CONVERSIONS__)
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} --use_fast_math -Xfatbin -compress-all)
message(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}")

# find_package(NCCL REQUIRED)

# find cudnn
execute_process(COMMAND python -c "import nvidia.cudnn; print(nvidia.cudnn.__file__)" OUTPUT_VARIABLE CUDNN_PYTHON_PATH)
get_filename_component(CUDNN_ROOT_DIR "${CUDNN_PYTHON_PATH}" DIRECTORY)
link_directories(
${CUDNN_ROOT_DIR}/lib64
${CUDNN_ROOT_DIR}/lib
)
endif()
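
# Note: the cuDNN lookup above assumes cuDNN was installed via the
# nvidia-cudnn pip package, so "import nvidia.cudnn" must resolve; the printed
# __file__ ends in nvidia/cudnn/__init__.py, and get_filename_component(...
# DIRECTORY) reduces it to the package directory, whose lib/ subdirectory
# carries the cuDNN shared libraries. If the import fails, CUDNN_PYTHON_PATH
# stays empty and the two link_directories entries silently collapse to
# /lib64 and /lib.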

# check if USE_CXX11_ABI is set correctly
# if (DEFINED USE_CXX11_ABI)
# parse_make_options(${TORCH_CXX_FLAGS} "TORCH_CXX_FLAGS")
32 changes: 26 additions & 6 deletions setup.py
@@ -98,7 +98,6 @@ def get_python_include_path():
return None


# PYTORCH_INSTALL_PATH and LIBTORCH_ROOT
def get_torch_root_path():
try:
import torch
@@ -115,6 +114,12 @@ def get_torch_mlu_root_path():
except ImportError:
return None

def get_nccl_root_path():
try:
from nvidia import nccl
return str(Path(nccl.__file__).parent)
except ImportError:
return None

def set_npu_envs():
PYTORCH_NPU_INSTALL_PATH = os.getenv("PYTORCH_NPU_INSTALL_PATH")
@@ -212,7 +217,16 @@ def set_mlu_envs():
os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
os.environ["PYTORCH_MLU_INSTALL_PATH"] = get_torch_mlu_root_path()


def set_cuda_envs():
os.environ["PYTHON_INCLUDE_PATH"] = get_python_include_path()
os.environ["PYTHON_LIB_PATH"] = get_torch_root_path()
os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
os.environ["CUDA_TOOLKIT_ROOT_DIR"] = "/usr/local/cuda"
os.environ["NCCL_ROOT"] = get_nccl_root_path()
os.environ["NCCL_VERSION"] = "2"

class CMakeExtension(Extension):
def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
super().__init__(name, sources=[])
@@ -223,7 +237,7 @@ def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
class ExtBuild(build_ext):
user_options = build_ext.user_options + [
("base-dir=", None, "base directory of xLLM project"),
("device=", None, "target device type (a3 or a2 or mlu)"),
("device=", None, "target device type (a3 or a2 or mlu or cuda)"),
("arch=", None, "target arch type (x86 or arm)"),
("install-xllm-kernels=", None, "install xllm_kernels RPM package (true/false)"),
]
@@ -302,8 +316,14 @@ def build_extension(self, ext: CMakeExtension):
cmake_args += ["-DUSE_MLU=ON"]
# set mlu environment variables
set_mlu_envs()
elif self.device == "cuda":
cuda_architectures = "80;89;90"
cmake_args += ["-DUSE_CUDA=ON",
f"-DCMAKE_CUDA_ARCHITECTURES={cuda_architectures}"]
# set cuda environment variables
set_cuda_envs()
else:
raise ValueError("Please set --device to a2 or a3 or mlu.")
raise ValueError("Please set --device to a2 or a3 or mlu or cuda.")


# Adding CMake arguments set as environment variable
@@ -353,7 +373,7 @@ def build_extension(self, ext: CMakeExtension):

class BuildDistWheel(bdist_wheel):
user_options = bdist_wheel.user_options + [
("device=", None, "target device type (a3 or a2 or mlu)"),
("device=", None, "target device type (a3 or a2 or mlu or cuda)"),
("arch=", None, "target arch type (x86 or arm)"),
]

@@ -530,7 +550,7 @@ def apply_patch():
idx = sys.argv.index('--device')
if idx + 1 < len(sys.argv):
device = sys.argv[idx+1].lower()
if device not in ('a2', 'a3', 'mlu'):
if device not in ('a2', 'a3', 'mlu', 'cuda'):
print("Error: --device must be a2 or a3 or mlu (case-insensitive)")
sys.exit(1)
# Remove the arguments so setup() doesn't see them
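
With these hooks in place, a CUDA build is presumably driven the same way as the existing targets, e.g. python setup.py bdist_wheel --device cuda --arch x86 (the exact invocation is inferred from the user_options and sys.argv handling above, not spelled out in this diff); the cuda branch then appends -DUSE_CUDA=ON and -DCMAKE_CUDA_ARCHITECTURES=80;89;90 to the CMake arguments and exports the environment prepared by set_cuda_envs().
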
16 changes: 15 additions & 1 deletion xllm/core/common/global_flags.cpp
@@ -276,6 +276,7 @@ DEFINE_int32(transfer_listen_port, 26000, "The KVCacheTranfer listen port.");
DEFINE_bool(enable_shm,
false,
"Whether to enable shared memory for executing model.");

// --- function call config ---

DEFINE_string(tool_call_parser,
@@ -353,6 +354,7 @@ DEFINE_int32(micro_batch_num,
"Default use two micro batches for multi-stream parallel.");

// --- dit config ---

DEFINE_int32(max_requests_per_batch, 1, "Max number of request per batch.");

// --- continuous kv cache config ---
@@ -377,22 +379,34 @@ DEFINE_int64(buffer_size_per_seq,
"Buffer size per sequence in bytes, default 0.");

// --- beam search config ---

DEFINE_bool(enable_beam_search_kernel,
false,
"Whether to enable beam search kernel.");

// --- reasoning parser config ---

DEFINE_string(reasoning_parser,
"",
"Specify the reasoning parser for handling reasoning "
"interactions(e.g. glm45, qwen3, deepseek-r1).");

// --- qwen3 reranker config ---

DEFINE_bool(enable_qwen3_reranker, false, "Whether to enable qwen3 reranker.");

// --- flashinfer config ---

DEFINE_int32(flashinfer_workspace_buffer_size,
128 * 1024 * 1024,
"The user reserved workspace buffer used to store intermediate "
"attention results in split-k algorithm for flashinfer.");

// --- prefetch weight config ---

DEFINE_bool(
enable_prefetch_weight,
false,
"Whether to enable prefetch weight,only applicable to Qwen3-dense model."
"The default prefetching ratio for gateup weight is 40%."
"If adjustments are needed, e.g. export PREFETCH_COEFFOCIENT=0.5");
"If adjustments are needed, e.g. export PREFETCH_COEFFOCIENT=0.5");
2 changes: 2 additions & 0 deletions xllm/core/common/global_flags.h
@@ -204,3 +204,5 @@ DECLARE_string(reasoning_parser);
DECLARE_bool(enable_shm);

DECLARE_bool(enable_prefetch_weight);

DECLARE_int32(flashinfer_workspace_buffer_size);
38 changes: 35 additions & 3 deletions xllm/core/framework/batch/batch_input_builder.cpp
@@ -216,7 +216,7 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
state_.q_seq_lens.insert(state_.q_seq_lens.end(),
state.q_seq_lens.begin(),
state.q_seq_lens.end());
#elif defined(USE_MLU)
#elif defined(USE_MLU) || defined(USE_CUDA)
int32_t seq_len_offset = state_.seq_lens.back();
// skip the first element which is 0
for (size_t i = 1; i < state.seq_lens.size(); ++i) {
@@ -248,6 +248,16 @@
state.kv_cache_start_offsets.begin(),
state.kv_cache_start_offsets.end());
}
// for flashinfer: rebase this shard's indptr on the blocks merged so far,
// skipping its leading 0 (same pattern as the seq_lens merge above)
int32_t indptr_offset = state_.paged_kv_indptr.back();
for (size_t i = 1; i < state.paged_kv_indptr.size(); ++i) {
state_.paged_kv_indptr.push_back(indptr_offset + state.paged_kv_indptr[i]);
}
state_.paged_kv_indices.insert(state_.paged_kv_indices.end(),
state.paged_kv_indices.begin(),
state.paged_kv_indices.end());
state_.paged_kv_last_page_len.insert(state_.paged_kv_last_page_len.end(),
state.paged_kv_last_page_len.begin(),
state.paged_kv_last_page_len.end());
}
for (const auto& write_block_ids : thread_write_block_ids) {
write_block_ids_.insert(write_block_ids.begin(), write_block_ids.end());
@@ -288,7 +298,7 @@ void BatchInputBuilder::process_single_sequence(
#if defined(USE_NPU)
state.seq_lens.push_back(seq_len);
state.q_seq_lens.push_back(q_seq_len);
#elif defined(USE_MLU)
#elif defined(USE_MLU) || defined(USE_CUDA)
state.seq_lens.push_back(state.seq_lens.back() + seq_len);
state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len);
#endif
@@ -448,7 +458,12 @@ void BatchInputBuilder::setup_kv_cache_info(
block_size = block.size();
block_ids.push_back(block.id());
u_block_ids.emplace_back(block.id());
state.paged_kv_indices.push_back(block.id());
}
state.paged_kv_indptr.push_back(state.paged_kv_indptr.back() + blocks.size());
int32_t last_page_len =
(seq_len % block_size == 0) ? block_size : seq_len % block_size;
state.paged_kv_last_page_len.push_back(last_page_len);

int32_t kv_cache_block_idx = n_kv_cache_tokens / block_size;
for (auto iter = block_ids.begin() + kv_cache_block_idx;
@@ -517,12 +532,15 @@ void BatchInputBuilder::padding_decode_batch_size(
#if defined(USE_NPU)
state_.seq_lens.push_back(num_decoding_tokens);
state_.q_seq_lens.push_back(num_decoding_tokens);
#elif defined(USE_MLU)
#elif defined(USE_MLU) || defined(USE_CUDA)
state_.seq_lens.push_back(state_.seq_lens.back() + num_decoding_tokens);
state_.q_seq_lens.push_back(state_.q_seq_lens.back() +
num_decoding_tokens);
#endif
state_.block_tables_vec.emplace_back();
state_.paged_kv_indices.push_back(0);
state_.paged_kv_indptr.push_back(state_.paged_kv_indptr.back() + 1);
state_.paged_kv_last_page_len.push_back(1);
}
}
}
@@ -560,6 +578,14 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
input_params.decode_seq_range =
util::find_ones_indices(input_params.q_seq_lens_vec);

// for flashinfer
input_params.paged_kv_indptr =
torch::tensor(state_.paged_kv_indptr, torch::kInt);
input_params.paged_kv_indices =
torch::tensor(state_.paged_kv_indices, torch::kInt);
input_params.paged_kv_last_page_len =
torch::tensor(state_.paged_kv_last_page_len, torch::kInt);

// Setup multimodal data
input_params.mm_data = MMData::batch(mm_data_vec_);

@@ -634,6 +660,12 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
raw_forward_input.transfer_kv_infos = std::move(state_.transfer_kv_infos);
raw_forward_input.prefill_seq_len = state_.prefill_seq_len;

// for flashinfer
raw_forward_input.paged_kv_indptr = std::move(state_.paged_kv_indptr);
raw_forward_input.paged_kv_indices = std::move(state_.paged_kv_indices);
raw_forward_input.paged_kv_last_page_len =
std::move(state_.paged_kv_last_page_len);

raw_forward_input.embedding_ids = std::move(state_.embedding_ids);
raw_forward_input.extra_token_ids = std::move(state_.extra_token_ids);
// beam search kernel input
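The paged_kv_* vectors introduced above are the CSR-style page-table layout flashinfer consumes: paged_kv_indices concatenates the block ids of every sequence in the batch, paged_kv_indptr (seeded with 0) is a running total of blocks per sequence, and paged_kv_last_page_len records how many tokens occupy each sequence's final block. A minimal, self-contained sketch of the construction, using hypothetical block tables and lengths rather than code from this PR:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int32_t block_size = 16;
  // Hypothetical batch: seq 0 holds 20 kv tokens in blocks {3, 7},
  // seq 1 holds 16 kv tokens in block {5}.
  const std::vector<std::vector<int32_t>> block_tables = {{3, 7}, {5}};
  const std::vector<int32_t> kv_lens = {20, 16};

  std::vector<int32_t> paged_kv_indptr = {0};   // CSR offsets, seeded with 0
  std::vector<int32_t> paged_kv_indices;        // concatenated block ids
  std::vector<int32_t> paged_kv_last_page_len;  // tokens in each final block

  for (size_t i = 0; i < block_tables.size(); ++i) {
    for (const int32_t id : block_tables[i]) {
      paged_kv_indices.push_back(id);
    }
    // Running total of blocks per sequence.
    paged_kv_indptr.push_back(paged_kv_indptr.back() +
                              static_cast<int32_t>(block_tables[i].size()));
    // A partially filled last block keeps the remainder; a full one keeps
    // block_size, matching the ternary in setup_kv_cache_info.
    const int32_t rem = kv_lens[i] % block_size;
    paged_kv_last_page_len.push_back(rem == 0 ? block_size : rem);
  }

  // Prints: indptr 0 2 3 / indices 3 7 5 / last_page_len 4 16
  for (int32_t v : paged_kv_indptr) std::cout << v << ' ';
  std::cout << '\n';
  for (int32_t v : paged_kv_indices) std::cout << v << ' ';
  std::cout << '\n';
  for (int32_t v : paged_kv_last_page_len) std::cout << v << ' ';
  std::cout << '\n';
}
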
7 changes: 6 additions & 1 deletion xllm/core/framework/batch/batch_input_builder.h
@@ -86,7 +86,7 @@ class BatchInputBuilder {
#if defined(USE_NPU)
std::vector<int32_t> seq_lens;
std::vector<int32_t> q_seq_lens;
#elif defined(USE_MLU)
#elif defined(USE_MLU) || defined(USE_CUDA)
std::vector<int32_t> seq_lens = {0}; // cu_seq_lens
std::vector<int32_t> q_seq_lens = {0}; // q_cu_seq_len
#endif
@@ -107,6 +107,11 @@
// for continuous kvcache
std::vector<int64_t> new_cache_slot_offsets; //[n_tokens]
std::vector<int64_t> kv_cache_start_offsets; //[n_seq]

// for flashinfer
std::vector<int32_t> paged_kv_indptr = {0};
std::vector<int32_t> paged_kv_indices;
std::vector<int32_t> paged_kv_last_page_len;
};

// Helper methods for sequence processing
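As the cu_seq_lens comments indicate, the {0}-seeded vectors are cumulative: for a batch of N sequences they hold N + 1 entries, and adjacent differences recover the per-sequence counts (for paged_kv_indptr, indptr[i+1] - indptr[i] is sequence i's block count).
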
4 changes: 2 additions & 2 deletions xllm/core/framework/batch/batch_test.cpp
@@ -152,15 +152,15 @@ TEST(BatchTest, Basic) {

#if defined(USE_NPU)
const std::vector<int32_t> q_seq_lens = {9, 1, 1, 4};
#elif defined(USE_MLU)
#else
const std::vector<int32_t> q_seq_lens = {0, 9, 10, 11, 15};
#endif
EXPECT_TRUE(equal(input_params.q_seq_lens, q_seq_lens));

// seq4's kv_seq_len = q_len + num_cached_tokens (q_len<=max_allowed_tokens)
#if defined(USE_NPU)
const std::vector<int32_t> kv_seq_lens = {9, 8, 16, 8};
#elif defined(USE_MLU)
#else
const std::vector<int32_t> kv_seq_lens = {0, 9, 17, 33, 41};
#endif
EXPECT_TRUE(equal(input_params.kv_seq_lens, kv_seq_lens));
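The test split follows from that representation: the NPU builder records per-sequence values, so q_seq_lens is {9, 1, 1, 4}, while the MLU/CUDA path (now the #else branch) records cumulative values seeded with 0, giving {0, 9, 10, 11, 15} (0+9=9, 9+1=10, 10+1=11, 11+4=15); likewise kv_seq_lens {0, 9, 17, 33, 41} accumulates the per-sequence kv lengths {9, 8, 16, 8}.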