forked from jvhs0706/zkllm-ccs2024
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path zkrelu.cuh
32 lines (23 loc) · 869 Bytes
/
zkrelu.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#ifndef ZKRELU_CUH
#define ZKRELU_CUH
#include <cstddef>
#include <cuda_runtime.h>
#include "bls12-381.cuh" // adjust this to point to the blstrs header file
#include "fr-tensor.cuh"
#include "tlookup.cuh"
#include "proof.cuh"
// zkReLU: state and entry points for the ReLU gadget used by the prover.
//
// Only declarations are visible in this header; the member definitions live
// in the corresponding implementation file. Descriptions below that go beyond
// what the signatures show are marked NOTE(review) and should be confirmed
// against that file.
class zkReLU {
public:
    // Scaling factor passed to the constructor.
    // NOTE(review): presumably the fixed-point scale of the quantized inputs
    // (it likely also sizes the remainder range table) — confirm.
    uint scaling_factor;

    // Lookup-range table used for the remainder values.
    tLookupRange tl_rem;

    // Decomposes X. Results are written into the `sign`, `abs` and `rem`
    // out-parameters, and a further FrTensor is returned.
    // NOTE(review): the exact algebraic relation between X and the outputs is
    // not visible here — confirm in the implementation.
    FrTensor decomp(const FrTensor& X, FrTensor& sign, FrTensor& abs, FrTensor& rem);

    // Intermediate tensors (sign / absolute value / remainder / m), declared in
    // this order.
    //
    // They are default-initialized to nullptr so an instance is always in a
    // well-defined state, even on constructor paths that do not assign them.
    // Testing or deleting a null pointer is safe; an indeterminate pointer is
    // undefined behaviour.
    //
    // NOTE(review): the user-declared destructor suggests these may be owned
    // heap allocations. The class is still implicitly copyable, so copying an
    // instance that owns them would lead to a double free. Consider deleting
    // the copy constructor and copy assignment once callers are audited.
    FrTensor *sign_tensor_ptr = nullptr;
    FrTensor *abs_tensor_ptr = nullptr;
    FrTensor *rem_tensor_ptr = nullptr;
    FrTensor *m_tensor_ptr = nullptr;

    // Constructs the gadget for the given scaling factor.
    zkReLU(uint scaling_factor);

    // Applies the gadget to X and returns the resulting tensor.
    // NOTE(review): presumably computes ReLU(X) and records the intermediates
    // in the *_tensor_ptr members for a later prove() — confirm.
    FrTensor operator()(const FrTensor& X);

    // Produces the proof for a previous evaluation.
    // NOTE(review): Z and A are presumably the pre- and post-activation
    // tensors — confirm.
    void prove(const FrTensor& Z, const FrTensor& A);

    // Destructor, defined out of line.
    // NOTE(review): presumably releases the *_tensor_ptr members — confirm.
    ~zkReLU();
};
// DEVICE Fr_t ulong_to_scalar(unsigned long num);
// DEVICE unsigned long scalar_to_ulong(Fr_t num);
// KERNEL void relu_kernel(Fr_t* X, Fr_t* Z, Fr_t* sign, Fr_t* mag_bin, Fr_t* rem_bin, uint n);
#endif // ZKRELU_CUH