forked from jvhs0706/zkllm-ccs2024
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathzksoftmax.cuh
75 lines (59 loc) · 4.23 KB
/
zksoftmax.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#ifndef ZKSOFTMAX_CUH
#define ZKSOFTMAX_CUH
#include "tlookup.cuh"
#include "zkfc.cuh"
// zkSoftmax
//
// Declares the compute/prove interface of the softmax component together with
// the configuration it stores. Only declarations are given here; the member
// function definitions are not part of this header.
class zkSoftmax {
public:
    // Constructor. Note that there is no parameter for the member K.
    // NOTE(review): K is presumably derived from the other arguments in the
    // definition -- confirm there.
    zkSoftmax(const vector<uint>& bs,
              uint L,
              uint M,
              unsigned long scaling_factor_in,
              const vector<double>& thetas,
              uint m,
              uint n,
              uint d,
              uint E);

    // Takes X read-only and returns an FrTensor by value. shift, X_shifted and
    // the three segment vectors are non-const references.
    // NOTE(review): these look like outputs that are later handed to prove()
    // under the same names -- confirm against the definition.
    FrTensor compute(const FrTensor& X,
                     FrTensor& shift,
                     FrTensor& X_shifted,
                     vector<FrTensor>& X_segments,
                     vector<FrTensor>& Y_segments,
                     vector<FrTensor>& m_segments);

    // Takes the tensors named as in compute() read-only, plus the field-element
    // inputs u_Y, v_Y, r_seg, alpha_seg and beta_seg. proof is a non-const
    // reference to a vector<Polynomial>. Returns an Fr_t.
    Fr_t prove(const FrTensor& Y,
               const FrTensor& X,
               const FrTensor& shift,
               const FrTensor& X_shifted,
               const vector<FrTensor>& X_segments,
               const vector<FrTensor>& Y_segments,
               const vector<FrTensor>& m_segments,
               const vector<Fr_t>& u_Y,
               const vector<Fr_t>& v_Y,
               const Fr_t& r_seg,
               const Fr_t& alpha_seg,
               const Fr_t& beta_seg,
               vector<Polynomial>& proof);

protected:
    const vector<uint> bs;
    // Disabled member kept for reference:
    // const vector<unsigned long> Bs; // length of each segment, and its cumulative prod

    // Segment counts; per the original note these cover the number of most and
    // least significant segments. Declaration order is K, then L, then M.
    const uint K;
    const uint L;
    const uint M;

    // Scaling factor of the input (gamma**2).
    const unsigned long scaling_factor_in;

    // Output scaling factor for each segment.
    const vector<double> thetas;

    // Dimensions of the input.
    const uint m;
    const uint n;
    const uint d;

    // Error of the output.
    const uint E;

    // Lookup tables for the least significant segments.
    vector<tLookupRange> least_significant_segments;

    // Lookup tables for the remaining segments.
    vector<tLookupRangeMapping> other_segments;
};
// zkAttn
//
// Publicly derives from zkSoftmax and declares attention-level compute/prove
// entry points operating on Q, K and V. Because compute and prove are
// redeclared here, the zkSoftmax overloads of those names are hidden when
// looked up through a zkAttn unless explicitly qualified.
class zkAttn : public zkSoftmax {
public:
    // Same parameters as the zkSoftmax constructor except that sf_Q and sf_K
    // are taken in place of scaling_factor_in.
    // NOTE(review): presumably the two are combined into the base-class
    // scaling factor -- confirm in the definition.
    zkAttn(unsigned long sf_Q,
           unsigned long sf_K,
           const vector<uint>& bs,
           uint L,
           uint M,
           const vector<double>& thetas,
           uint m,
           uint n,
           uint d,
           uint E);

    // Q: m * d, K: n * d, V: n * d
    // Q, K and V are read-only; every sm_* argument is a non-const reference.
    // Returns an FrTensor by value.
    FrTensor compute(const FrTensor& Q,
                     const FrTensor& K,
                     const FrTensor& V,
                     FrTensor& sm_in,
                     FrTensor& sm_out,
                     FrTensor& sm_shift,
                     FrTensor& sm_in_shifted,
                     vector<FrTensor>& sm_in_segments,
                     vector<FrTensor>& sm_out_segments,
                     vector<FrTensor>& sm_m_segments);

    // Overload that takes the field-element inputs (u_/v_/w_ vectors, r_seg,
    // alpha_seg, beta_seg) explicitly, together with a non-const
    // vector<Polynomial>& proof. Returns an Fr_t.
    Fr_t prove(const FrTensor& Q,
               const FrTensor& K,
               const FrTensor& V,
               const FrTensor& out,
               const FrTensor& sm_out,
               const FrTensor& sm_in,
               const FrTensor& sm_shift,
               const FrTensor& sm_in_shifted,
               const vector<FrTensor>& sm_in_segments,
               const vector<FrTensor>& sm_out_segments,
               const vector<FrTensor>& sm_m_segments,
               const vector<Fr_t>& u_matmul_out,
               const vector<Fr_t>& v_matmul_out,
               const vector<Fr_t>& w_matmul_out,
               const vector<Fr_t>& v_sm,
               const Fr_t& r_seg,
               const Fr_t& alpha_seg,
               const Fr_t& beta_seg,
               const vector<Fr_t>& v_matmul_in,
               vector<Polynomial>& proof);

    // Overload that takes only the tensors (all read-only) and returns a
    // vector<Claim>.
    vector<Claim> prove(const FrTensor& Q,
                        const FrTensor& K,
                        const FrTensor& V,
                        const FrTensor& out,
                        const FrTensor& sm_out,
                        const FrTensor& sm_in,
                        const FrTensor& sm_shift,
                        const FrTensor& sm_in_shifted,
                        const vector<FrTensor>& sm_in_segments,
                        const vector<FrTensor>& sm_out_segments,
                        const vector<FrTensor>& sm_m_segments);
};
// zkAttnStacked
//
// Publicly derives from zkAttn, adds the stored count `num`, and redeclares
// both prove overloads. No compute is declared here, so zkAttn::compute is
// inherited unchanged.
class zkAttnStacked : public zkAttn {
public:
    // Same parameter list as the zkAttn constructor with a leading `num`.
    zkAttnStacked(uint num,
                  unsigned long sf_Q,
                  unsigned long sf_K,
                  const vector<uint>& bs,
                  uint L,
                  uint M,
                  const vector<double>& thetas,
                  uint m,
                  uint n,
                  uint d,
                  uint E);

    // Like the Fr_t-returning zkAttn::prove, with three additional read-only
    // vectors: u_matmul_out_num, v_matmul_out_num and v_matmul_in_num.
    // NOTE(review): the *_num vectors presumably address the stacking
    // dimension of size `num` -- confirm in the definition.
    Fr_t prove(const FrTensor& Q,
               const FrTensor& K,
               const FrTensor& V,
               const FrTensor& out,
               const FrTensor& sm_out,
               const FrTensor& sm_in,
               const FrTensor& sm_shift,
               const FrTensor& sm_in_shifted,
               const vector<FrTensor>& sm_in_segments,
               const vector<FrTensor>& sm_out_segments,
               const vector<FrTensor>& sm_m_segments,
               const vector<Fr_t>& u_matmul_out_num,
               const vector<Fr_t>& v_matmul_out_num,
               const vector<Fr_t>& u_matmul_out,
               const vector<Fr_t>& v_matmul_out,
               const vector<Fr_t>& w_matmul_out,
               const vector<Fr_t>& v_sm,
               const Fr_t& r_seg,
               const Fr_t& alpha_seg,
               const Fr_t& beta_seg,
               const vector<Fr_t>& v_matmul_in_num,
               const vector<Fr_t>& v_matmul_in,
               vector<Polynomial>& proof);

    // Tensor-only overload (all arguments read-only); returns a vector<Claim>.
    vector<Claim> prove(const FrTensor& Q,
                        const FrTensor& K,
                        const FrTensor& V,
                        const FrTensor& out,
                        const FrTensor& sm_out,
                        const FrTensor& sm_in,
                        const FrTensor& sm_shift,
                        const FrTensor& sm_in_shifted,
                        const vector<FrTensor>& sm_in_segments,
                        const vector<FrTensor>& sm_out_segments,
                        const vector<FrTensor>& sm_m_segments);

protected:
    // Count supplied as the first constructor argument.
    const uint num;
};
#endif