forked from jvhs0706/zkllm-ccs2024
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathzkrelu.cu
67 lines (57 loc) · 2.41 KB
/
zkrelu.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include "zkrelu.cuh"

#include <stdexcept>
#include <string>
// Builds a ReLU gadget that rescales inputs by `scaling_factor`.
//
// tl_rem is constructed from (-(scaling_factor >> 1), scaling_factor); presumably this
// describes a lookup range of `scaling_factor` values starting at -(scaling_factor >> 1),
// which matches the remainder range [-(sf >> 1), sf - (sf >> 1)) produced by
// zkrelu_decomp_kernel -- NOTE(review): confirm against the tl_rem type in the header.
// All cached tensor pointers start out null; operator() populates them.
zkReLU::zkReLU(uint scaling_factor)
    : scaling_factor(scaling_factor),
      tl_rem(-static_cast<int>(scaling_factor >> 1), scaling_factor),
      sign_tensor_ptr(nullptr),
      abs_tensor_ptr(nullptr),
      rem_tensor_ptr(nullptr),
      m_tensor_ptr(nullptr)
{
    // Nothing else to do: every member is set in the initializer list above.
}
// Element-wise decomposition used by zkReLU::decomp.
//
// For each i < N, with sf = scaling_factor and h = sf >> 1, the input x = X[i] is split as
//     x = sf * q + r,   r in [-h, sf - h)
// i.e. q is x / sf rounded to the nearest integer (ties resolved by the half-open range).
// Outputs per element:
//     sign[i] = 1 if q >= 0, else 0
//     abs[i]  = |q|
//     rem[i]  = r
//     res[i]  = max(q, 0)          (the rescaled ReLU output)
//
// Launch layout: 1-D grid of 1-D blocks covering at least N threads; threads past N exit.
// Assumes sf > 0 (sf == 0 would divide by zero) and that each X[i] converts to a `long`
// via scalar_to_long without overflow.
KERNEL void zkrelu_decomp_kernel(Fr_t* X_ptr, Fr_t* sign_ptr, Fr_t* abs_ptr, Fr_t* rem_ptr, Fr_t* res_ptr, long scaling_factor, uint N)
{
    const uint idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= N) return;

    const long half = scaling_factor >> 1;
    const long value = scalar_to_long(X_ptr[idx]);

    // `%` truncates toward zero, so a negative result is shifted back into [0, sf)
    // before re-centering the remainder into [-half, sf - half).
    long shifted = (value + half) % scaling_factor;
    if (shifted < 0) shifted += scaling_factor;
    const long remainder = shifted - half;

    // value - remainder is an exact multiple of scaling_factor, so this division is exact.
    const long quotient = (value - remainder) / scaling_factor;
    const bool non_negative = !(quotient < 0);

    sign_ptr[idx] = {static_cast<uint>(non_negative), 0, 0, 0, 0, 0, 0, 0};
    abs_ptr[idx] = long_to_scalar(non_negative ? quotient : -quotient);
    rem_ptr[idx] = long_to_scalar(remainder);
    res_ptr[idx] = non_negative ? long_to_scalar(quotient) : blstrs__scalar__Scalar_ZERO;
}
// Runs zkrelu_decomp_kernel over X, filling `sign`, `abs` and `rem` in place and
// returning the rescaled ReLU output as a new tensor of X.size elements.
//
// Parameters:
//   X    - input tensor (read only).
//   sign - output; must hold at least X.size elements.
//   abs  - output; must hold at least X.size elements.
//   rem  - output; must hold at least X.size elements.
// Returns: tensor res with res[i] = max(round(X[i] / scaling_factor), 0).
// Throws:
//   std::invalid_argument if any output tensor is smaller than X.
//   std::runtime_error    if the kernel launch or its execution reports a CUDA error.
// Blocks the host until the kernel has finished (cudaDeviceSynchronize).
FrTensor zkReLU::decomp(const FrTensor& X, FrTensor& sign, FrTensor& abs, FrTensor& rem)
{
    const uint N = X.size;

    // The kernel writes N elements into each output; smaller buffers would be
    // written out of bounds on the device.
    if (sign.size < N || abs.size < N || rem.size < N)
        throw std::invalid_argument("zkReLU::decomp: output tensors are smaller than the input tensor");

    FrTensor res(N);

    // A zero-sized grid is an invalid launch configuration; there is nothing to compute.
    if (N == 0) return res;

    const uint block_size = 256;
    const uint grid_size = (N + block_size - 1) / block_size;
    zkrelu_decomp_kernel<<<grid_size, block_size>>>(X.gpu_data, sign.gpu_data, abs.gpu_data, rem.gpu_data, res.gpu_data, static_cast<long>(scaling_factor), N);

    // Launch-configuration errors surface via cudaGetLastError; faults during
    // execution surface at the synchronizing call.
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) err = cudaDeviceSynchronize();
    if (err != cudaSuccess)
        throw std::runtime_error(std::string("zkReLU::decomp: CUDA error: ") + cudaGetErrorString(err));

    return res;
}
// Applies the rescaled ReLU to X and returns the result.
//
// As a side effect, replaces the cached tensors from any previous call:
//   *sign_tensor_ptr, *abs_tensor_ptr, *rem_tensor_ptr - per-element decomposition of X
//                                                        (see zkReLU::decomp),
//   *m_tensor_ptr                                      - result of tl_rem.prep on the remainders.
FrTensor zkReLU::operator()(const FrTensor& X)
{
    // Release previous outputs and null each pointer immediately, so that if any
    // allocation, the kernel, or tl_rem.prep below throws, the destructor never
    // deletes a dangling pointer a second time. (delete on nullptr is a no-op.)
    delete sign_tensor_ptr; sign_tensor_ptr = nullptr;
    delete abs_tensor_ptr;  abs_tensor_ptr = nullptr;
    delete rem_tensor_ptr;  rem_tensor_ptr = nullptr;
    delete m_tensor_ptr;    m_tensor_ptr = nullptr;

    sign_tensor_ptr = new FrTensor(X.size);
    abs_tensor_ptr = new FrTensor(X.size);
    rem_tensor_ptr = new FrTensor(X.size);

    FrTensor res = decomp(X, *sign_tensor_ptr, *abs_tensor_ptr, *rem_tensor_ptr);
    m_tensor_ptr = new FrTensor(tl_rem.prep(*rem_tensor_ptr));
    return res;
}
// Placeholder for proof generation of the ReLU relation between Z and A.
// Currently a no-op: neither argument is read and nothing is returned or thrown.
// NOTE(review): callers receive no proof and no error -- confirm this stub is intentional.
void zkReLU::prove(const FrTensor& Z, const FrTensor& A)
{
    static_cast<void>(Z);
    static_cast<void>(A);
}
// Releases the tensors cached by the most recent operator() call.
// delete on a null pointer is a no-op, so unset pointers need no explicit check.
// NOTE(review): the class owns these raw pointers; unless the header deletes or defines
// the copy constructor/assignment, copying a zkReLU would double-free -- confirm.
zkReLU::~zkReLU()
{
    delete m_tensor_ptr;
    delete rem_tensor_ptr;
    delete abs_tensor_ptr;
    delete sign_tensor_ptr;
}