diff --git a/cuda_ext.py b/cuda_ext.py
index 5efd3d4a..1fe94ab9 100644
--- a/cuda_ext.py
+++ b/cuda_ext.py
@@ -58,7 +58,7 @@ def find_msvc():
     extra_include_paths = [os.path.join(library_dir, "exllama_ext")],
     verbose = verbose,
     extra_ldflags = (["cublas.lib"] + ([f"/LIBPATH:{os.path.join(sys.base_prefix, 'libs')}"] if sys.base_prefix != sys.prefix else [])) if windows else [],
-    extra_cuda_cflags = ["-lineinfo"] + (["-U__HIP_NO_HALF_CONVERSIONS__", "-O3"] if torch.version.hip else []),
+    extra_cuda_cflags = ["-lineinfo"] + (["-O3"] if torch.version.hip else []),
     extra_cflags = ["-O3"]
     # extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
 )
diff --git a/exllama_ext/cuda_func/half_matmul.cu b/exllama_ext/cuda_func/half_matmul.cu
index 76ee1e43..e97c2c3c 100644
--- a/exllama_ext/cuda_func/half_matmul.cu
+++ b/exllama_ext/cuda_func/half_matmul.cu
@@ -41,8 +41,8 @@ __global__ void half_matmul_kernel
     for (int k = k0; k < k0 + BLOCKSIZE / 2; k++)
     {
         half2 x_item = *x_ptr++;
-        half2 x_item_0 = __half2half2(x_item.x);
-        half2 x_item_1 = __half2half2(x_item.y);
+        half2 x_item_0 = __low2half2(x_item);
+        half2 x_item_1 = __high2half2(x_item);
         half2 w_item_0 = *w_ptr; w_ptr += w_.width / 2;
         half2 w_item_1 = *w_ptr; w_ptr += w_.width / 2;
         acc = __hfma2(x_item_0, w_item_0, acc);
@@ -184,7 +184,7 @@ __global__ void half_matmul_small_kernel
         r = __hfma2(x_23, w_23, r);
     }
 
-    half rh = __hadd(r.x, r.y);
+    half rh = __hadd(__low2half(r), __high2half(r));
 
     __shared__ half accum[MAX_DIM_SMALL / S_BLOCKSIZE][S_THREADS_X];
     accum[threadIdx.y][threadIdx.x] = rh;
diff --git a/exllama_ext/cuda_func/q4_matmul.cu b/exllama_ext/cuda_func/q4_matmul.cu
index 4bbe5883..e7a1b011 100644
--- a/exllama_ext/cuda_func/q4_matmul.cu
+++ b/exllama_ext/cuda_func/q4_matmul.cu
@@ -131,7 +131,7 @@ __global__ void q4_matmul_kernel
 
     if constexpr (use_half2)
     {
-        half result = __hadd(acc.x, acc.y);
+        half result = __hadd(__low2half(acc), __high2half(acc));
         atomicAdd(out_.item_ptr(x_row, w_column), result);
     }
     else
diff --git a/exllama_ext/cuda_func/q4_mlp.cu b/exllama_ext/cuda_func/q4_mlp.cu
index ac21e0e2..16b9bd9e 100644
--- a/exllama_ext/cuda_func/q4_mlp.cu
+++ b/exllama_ext/cuda_func/q4_mlp.cu
@@ -13,6 +13,7 @@
 const int THREADS_X = 32;
 const int THREADS_Y = 4;
 // const int MAX_DIMENSION = 8192;
+
 __device__ __forceinline__ half silu(half x)
 {
     half one = __float2half(1.0f);
diff --git a/exllama_ext/cuda_func/rms_norm.cu b/exllama_ext/cuda_func/rms_norm.cu
index 991390b4..0fc5bfaa 100644
--- a/exllama_ext/cuda_func/rms_norm.cu
+++ b/exllama_ext/cuda_func/rms_norm.cu
@@ -50,8 +50,8 @@ __global__ void rms_norm_row_product_kernel
     for (int k = 0; k < BLOCKSIZE_X / 2; k++)
    {
         half2 x2 = *x_ptr++;
-        float m0 = __half2float(x2.x);
-        float m1 = __half2float(x2.y);
+        float m0 = __low2float(x2);
+        float m1 = __high2float(x2);
         acc = fma(m0, m0, acc);
         acc = fma(m1, m1, acc);
     }
diff --git a/exllama_ext/hip_compat.cuh b/exllama_ext/hip_compat.cuh
index e650cd4a..1b4e0634 100644
--- a/exllama_ext/hip_compat.cuh
+++ b/exllama_ext/hip_compat.cuh
@@ -8,8 +8,8 @@ __device__ __forceinline__ __half __compat_hrcp(__half x) {
 }
 
 __device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
-    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
-                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
+    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half2_raw>(x).data.x)),
+                      static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half2_raw>(x).data.y))};
 }
 
 #define hrcp __compat_hrcp
diff --git a/model_init.py b/model_init.py
index 68901992..ae7e1e1c 100644
--- a/model_init.py
+++ b/model_init.py
@@ -1,7 +1,6 @@
 from model import ExLlama, ExLlamaCache, ExLlamaConfig
 from tokenizer import ExLlamaTokenizer
 import argparse, sys, os, glob
-from torch import version as torch_version
 
 def add_args(parser):
 
@@ -28,13 +27,12 @@ def add_args(parser):
     parser.add_argument("-mmnh2", "--matmul_no_half2", action = "store_true", help = "Don't use half2 in Q4 matmul kernel")
     parser.add_argument("-snh2", "--silu_no_half2", action = "store_true", help = "Don't use half2 in SiLU kernel")
     parser.add_argument("-nh2", "--no_half2", action = "store_true", help = "(All of the above) disable half2 in all kernela")
-    parser.add_argument("-fh2", "--force_half2", action = "store_true", help = "Force enable half2 even if unsupported")
     parser.add_argument("-cs", "--concurrent_streams", action = "store_true", help = "Use concurrent CUDA streams")
 
 
 def post_parse(args):
 
-    if args.no_half2 or torch_version.hip and not args.force_half2:
+    if args.no_half2:
         args.rmsnorm_no_half2 = True
         args.rope_no_half2 = True
         args.matmul_no_half2 = True
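Below is a minimal sketch, not part of the diff itself, of the half2 lane-access pattern the CUDA hunks converge on: reading the .x / .y members of a half2 compiles under NVCC but not against ROCm's __half2 (which wraps its payload in __half2_raw, as the hip_compat.cuh hunk shows), so lanes are extracted with the __low2* / __high2* intrinsics that both toolchains provide. The helper names here are illustrative only.

#include <cuda_fp16.h>

// Sum the two fp16 lanes of a half2 into one half, without touching .x/.y.
__device__ __forceinline__ half sum_lanes(half2 v)
{
    return __hadd(__low2half(v), __high2half(v));    // instead of __hadd(v.x, v.y)
}

// Broadcast each lane of a half2 into its own half2, without touching .x/.y.
__device__ __forceinline__ void split_lanes(half2 v, half2 &lo2, half2 &hi2)
{
    lo2 = __low2half2(v);                            // instead of __half2half2(v.x)
    hi2 = __high2half2(v);                           // instead of __half2half2(v.y)
}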