Commit 81b044f

Bugfix: pushing the correct top-level function for fully-unrolled LUTNet with MNIST

Erwei Wang committed May 6, 2020
1 parent c4f7706 commit 81b044f
Showing 2 changed files with 106 additions and 82 deletions.
unrolled-lutnet/lutnet/src/network/MNIST/hw/mnist_4lut_weights.h: 31 additions & 0 deletions

Large diffs are not rendered by default.
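
The regenerated weights header is not rendered here, but the new DoCompute in top.cpp below reads its per-layer constants from it: thresh_fc2..fc4, alpha_fc2..fc5, next_layer_means_fc2..fc4, and the rand_map_0/1/2 permutation tables. As a purely hypothetical sketch of the kind of declarations such a generated header carries (the names match the calls in the diff below; every array shape and initializer is an assumption, not file contents):

#include <ap_fixed.h>

// Hypothetical sketch only: the real mnist_4lut_weights.h is generated by the
// LUTNet training flow, and its contents are not rendered in this diff.
// Assumed per-neuron thresholds and scaling factors for the unrolled fc2 layer,
// taking 256 outputs and numRes = 2 residual levels as example dimensions:
static const ap_fixed<24, 16> thresh_fc2[256] = { /* generated values */ };
static const ap_fixed<24, 16> alpha_fc2[256] = { /* generated values */ };
static const ap_fixed<24, 16> next_layer_means_fc2[2] = { /* generated values */ };
// Assumed shape for the randomised input-permutation maps fed to the LUT4 units:
static const unsigned int rand_map_0_fc2[256] = { /* generated indices */ };
static const unsigned int rand_map_1_fc2[256] = { /* generated indices */ };
static const unsigned int rand_map_2_fc2[256] = { /* generated indices */ };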

unrolled-lutnet/lutnet/src/network/MNIST/hw/top.cpp: 75 additions & 82 deletions
@@ -40,35 +40,25 @@
 *
 * @file top.cpp
 *
-* HLS Description of the CNV BNN, with axi-lite based parameter loading (DoMemInit)
+* HLS Description of the BNN, with axi-lite based parameter loading (DoMemInit)
 * and dataflow architecture of the image inference (DoCompute)
 *
 *
 *
 *****************************************************************************/
 #define AP_INT_MAX_W 9216
 #include <ap_int.h>
 
 #include "bnn-library.h"
 #include "config.h"
+#include"mnist_4lut_weights.h"
 
 
 static ap_uint<L0_SIMD> weightMem0[L0_PE][L0_WMEM];
 static ap_fixed<24, 16> thresMem0[L0_PE][L0_TMEM];
 static ap_fixed<24, 16> alphaMem0[L0_PE][L0_TMEM];
 static ap_fixed<24,16> means_in0[numRes];
 static ap_fixed<24,16> means_out0[numRes];
-static ap_uint<L1_SIMD> weightMem1[L1_PE][L1_WMEM];
-static ap_fixed<24, 16> thresMem1[L1_PE][L1_TMEM];
-static ap_fixed<24, 16> alphaMem1[L1_PE][L1_TMEM];
-static ap_fixed<24,16> means_in1[numRes];
-static ap_fixed<24,16> means_out1[numRes];
-static ap_uint<L2_SIMD> weightMem2[L2_PE][L2_WMEM];
-static ap_fixed<24, 16> thresMem2[L2_PE][L2_TMEM];
-static ap_fixed<24, 16> alphaMem2[L2_PE][L2_TMEM];
-static ap_fixed<24,16> means_in2[numRes];
-static ap_fixed<24,16> means_out2[numRes];
-static ap_uint<L3_SIMD> weightMem3[L3_PE][L3_WMEM];
-static ap_fixed<24, 16> thresMem3[L3_PE][L3_TMEM];
-static ap_fixed<24, 16> alphaMem3[L3_PE][L3_TMEM];
-static ap_fixed<24,16> means_in3[numRes];
-static ap_fixed<24,16> means_out3[numRes];
 
 
 unsigned int paddedSizeHW(unsigned int in, unsigned int padTo) {
     if(in % padTo == 0)
@@ -77,94 +67,107 @@ unsigned int paddedSizeHW(unsigned int in, unsigned int padTo) {
     return in + padTo - (in % padTo);
 }
 
 void DoMemInit(unsigned int targetLayer, unsigned int targetMem, unsigned int targetInd, ap_uint<64> val, ap_fixed<24,16> fix_val) {
-    switch(targetLayer) {
+    switch (targetLayer) {
     case 0:
         weightMem0[targetMem][targetInd] = val;
         break;
     case 1:
         thresMem0[targetMem][targetInd] = val;
         break;
     case 2:
-        weightMem1[targetMem][targetInd] = val;
+        //weightMem1[targetMem][targetInd] = val;
         break;
     case 3:
-        thresMem1[targetMem][targetInd] = val;
+        //thresMem1[targetMem][targetInd] = val;
         break;
     case 4:
-        weightMem2[targetMem][targetInd] = val;
+        //weightMem2[targetMem][targetInd] = val;
         break;
     case 5:
-        thresMem2[targetMem][targetInd] = val;
+        //thresMem2[targetMem][targetInd] = val;
         break;
     case 6:
-        weightMem3[targetMem][targetInd] = val;
+        //weightMem3[targetMem][targetInd] = val;
        break;
     case 7:
-        thresMem3[targetMem][targetInd] = val;
+        //thresMem3[targetMem][targetInd] = val;
        break;
     case 8:
         alphaMem0[targetMem][targetInd] = val;
         break;
     case 9:
-        alphaMem1[targetMem][targetInd] = val;
+        //alphaMem1[targetMem][targetInd] = val;
        break;
     case 10:
-        alphaMem2[targetMem][targetInd] = val;
+        //alphaMem2[targetMem][targetInd] = val;
        break;
     case 11:
-        alphaMem3[targetMem][targetInd] = val;
+        //alphaMem3[targetMem][targetInd] = val;
        break;
     case 12:
-        means_in1[targetMem][targetInd] = val;
+        //means_in1[targetMem][targetInd] = val;
        break;
     case 13:
-        means_in2[targetMem][targetInd] = val;
+        //means_in2[targetMem][targetInd] = val;
        break;
     case 14:
-        means_in3[targetMem][targetInd] = val;
+        //means_in3[targetMem][targetInd] = val;
        break;
     case 15:
         means_out0[targetMem][targetInd] = val;
         break;
     case 16:
-        means_out1[targetMem][targetInd] = val;
+        //means_out1[targetMem][targetInd] = val;
        break;
     case 17:
-        means_out2[targetMem][targetInd] = val;
+        //means_out2[targetMem][targetInd] = val;
        break;
     case 18:
-        means_out3[targetMem][targetInd] = val;
+        //means_out3[targetMem][targetInd] = val;
        break;
     case 19:
         means_in0[targetMem][targetInd] = val;
         break;
     default:
         break;
     }
 }
 
 
 void DoCompute(ap_uint<64> * in, ap_uint<64> * out, const unsigned int numReps) {
+#pragma HLS DATAFLOW
 
     hls::stream<ap_uint<64>> memInStrm("DoCompute.memInStrm");
     hls::stream<ap_uint<L0_PE>> inter0("DoCompute.inter0");
-    hls::stream<ap_uint<L1_PE>> inter1("DoCompute.inter1");
-    hls::stream<ap_uint<L2_PE>> inter2("DoCompute.inter2");
-    hls::stream<ap_uint<64>> memOutStrm("DoCompute.memOutStrm");
+    // This is where we implement LUTNet
+    stream<ap_uint<256*2> > inter0_1("DoCompute.inter0_1");
+#pragma HLS STREAM variable=inter0_1 depth=1
+    stream<ap_uint<256*2> > inter0_2("DoCompute.inter0_2");
+#pragma HLS STREAM variable=inter0_2 depth=1
+    stream<ap_uint<256*2> > inter0_3("DoCompute.inter0_3");
+#pragma HLS STREAM variable=inter0_3 depth=1
+    stream<ap_uint<256*2> > inter0_4("DoCompute.inter0_4");
+#pragma HLS STREAM variable=inter0_4 depth=1
+    stream<ap_uint<10*24> > inter0_5("DoCompute.inter0_5");
+#pragma HLS STREAM variable=inter0_5 depth=1
+    stream<ap_uint<960> > inter0_6("DoCompute.inter0_6");
+#pragma HLS STREAM variable=inter0_6 depth=1
 
-    unsigned const L0_DEPTH = 128 / L0_PE;
     unsigned const L1_DEPTH = 128 / L1_PE;
     unsigned const L2_DEPTH = 128 / L2_PE;
 
+    // Back to REBNet
+
+    hls::stream<ap_uint<L1_PE>> inter1("DoCompute.inter1");
+    hls::stream<ap_uint<L2_PE>> inter2("DoCompute.inter2");
+    hls::stream<ap_uint<64>> memOutStrm("DoCompute.memOutStrm");
+
+    unsigned const L0_DEPTH = 128 / L0_PE;
 
-#pragma HLS DATAFLOW
 #pragma HLS stream depth=256 variable=memInStrm // mask memory latency
 #pragma HLS stream depth=L0_DEPTH variable=inter0
 #pragma HLS stream depth=L1_DEPTH variable=inter1
 #pragma HLS stream depth=L2_DEPTH variable=inter2
 #pragma HLS stream depth=256 variable=memOutStrm // mask memory latency
 
     const unsigned int inBits = 28*28;
     const unsigned int inBitsPadded = 832; // paddedSizeHW(inBits, 64)
@@ -177,32 +180,41 @@ void DoCompute(ap_uint<64> * in, ap_uint<64> * out, const unsigned int numReps)
 
 
 
     Mem2Stream_Batch<64, inBytesPadded>(in, memInStrm, numReps);
 
-    StreamingFCLayer_Batch<64, L0_PE, L0_SIMD, L0_PE, 16, 24, 16,L0_MW, L0_MH, L0_WMEM, L0_TMEM,numRes>
-        (memInStrm, inter0, weightMem0, thresMem0, alphaMem0, means_in0, means_out0, numReps);
-    StreamingFCLayer_Batch<L0_PE, L1_PE, L1_SIMD, L1_PE, 16, 24, 16,L1_MW, L1_MH, L1_WMEM, L1_TMEM,numRes>
-        (inter0, inter1, weightMem1, thresMem1, alphaMem1, means_in1, means_out1,numReps);
-    StreamingFCLayer_Batch<L1_PE, L2_PE, L2_SIMD, L2_PE, 16, 24, 16,L2_MW, L2_MH, L2_WMEM, L2_TMEM,numRes>
-        (inter1, inter2, weightMem2, thresMem2, alphaMem2, means_in2, means_out2,numReps);
-    StreamingFCLayer_Batch<L2_PE, 64, L3_SIMD, L3_PE, 16, 24, 16,L3_MW, L3_MH, L3_WMEM, L3_TMEM,numRes>
-        (inter2, memOutStrm, weightMem3, thresMem3, alphaMem3, means_in3, means_out3,numReps);
-    Stream2Mem_Batch<64, outBytesPadded>(memOutStrm, out, numReps);
-}
+    StreamingFCLayer_Batch<64, L0_PE, L0_SIMD, L0_PE, 16, 24, 16,L0_MW, L0_MH, L0_WMEM, L0_TMEM,numRes>(memInStrm, inter0, weightMem0, thresMem0, alphaMem0, means_in0, means_out0, numReps);
+
+    LUTNET_StreamingNumResConverter<L0_PE, L0_MH*2, 1*1, 2>(inter0, inter0_1);
+
+    LUTNET_LUT4MV<1, 1, 256, 1, 1, 256, 1, 1, numRes, 24, 16, hls::stream<ap_uint<256*2>>, hls::stream<ap_uint<256*2>>>(inter0_1, inter0_2, thresh_fc2, alpha_fc2, next_layer_means_fc2, rand_map_0_fc2, rand_map_1_fc2, rand_map_2_fc2);
+    LUTNET_LUT4MV<1, 1, 256, 1, 1, 256, 1, 1, numRes, 24, 16, hls::stream<ap_uint<256*2>>, hls::stream<ap_uint<256*2>>>(inter0_2, inter0_3, thresh_fc3, alpha_fc3, next_layer_means_fc3, rand_map_0_fc3, rand_map_1_fc3, rand_map_2_fc3);
+    LUTNET_LUT4MV<1, 1, 256, 1, 1, 256, 1, 1, numRes, 24, 16, hls::stream<ap_uint<256*2>>, hls::stream<ap_uint<256*2>>>(inter0_3, inter0_4, thresh_fc4, alpha_fc4, next_layer_means_fc4, rand_map_0_fc4, rand_map_1_fc4, rand_map_2_fc4);
+    LUTNET_LUT4MV_NOTH<1, 1, 256, 1, 1, 10, 1, 1, numRes, 24, 16, hls::stream<ap_uint<256*2>>, hls::stream<ap_uint<10*24>>>(inter0_4, inter0_5, alpha_fc5, rand_map_0_fc5, rand_map_1_fc5, rand_map_2_fc5);
+
+    LUTNET_StreamingNumResConverter<10*24, 960, 1*1, 2>(inter0_5, inter0_6);
+    LUTNET_StreamingNumResConverter<960, 64, 1*1, 2>(inter0_6, memOutStrm);
+
+    Stream2Mem_Batch<64, outBytesPadded>(memOutStrm, out, numReps);
+
+}
 
 void BlackBoxJam(ap_uint<64> * in, ap_uint<64> * out, bool doInit,
         unsigned int targetLayer, unsigned int targetMem,
-        unsigned int targetInd, ap_uint<64> val, unsigned int numReps, ap_fixed<24,16> fix_val) {
+        unsigned int targetInd, ap_uint<64> val, ap_fixed<24,16> fix_val) {
 
+    unsigned int numReps=2;
 //#pragma HLS RESOURCE variable=thresMem4 core=RAM_S2P_LUTRAM
 //#pragma HLS RESOURCE variable=thresMem5 core=RAM_S2P_LUTRAM
 //#pragma HLS RESOURCE variable=thresMem6 core=RAM_S2P_LUTRAM
 // pragmas for MLBP jam interface
 // signals to be mapped to the AXI Lite slave port
+    numReps=1;
 #pragma HLS INTERFACE s_axilite port=return bundle=control
 #pragma HLS INTERFACE s_axilite port=doInit bundle=control
 #pragma HLS INTERFACE s_axilite port=targetLayer bundle=control
 #pragma HLS INTERFACE s_axilite port=targetMem bundle=control
 #pragma HLS INTERFACE s_axilite port=targetInd bundle=control
 #pragma HLS INTERFACE s_axilite port=val bundle=control
 #pragma HLS INTERFACE s_axilite port=fix_val bundle=control
-#pragma HLS INTERFACE s_axilite port=numReps bundle=control
 // signals to be mapped to the AXI master port (hostmem)
 #pragma HLS INTERFACE m_axi offset=slave port=in bundle=hostmem depth=256
@@ -213,33 +225,14 @@ void BlackBoxJam(ap_uint<64> * in, ap_uint<64> * out, bool doInit,
 // partition PE arrays
 #pragma HLS ARRAY_PARTITION variable=weightMem0 complete dim=1
 #pragma HLS ARRAY_PARTITION variable=thresMem0 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=weightMem1 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=thresMem1 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=weightMem2 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=thresMem2 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=weightMem3 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=thresMem3 complete dim=1
-
-
 #pragma HLS ARRAY_PARTITION variable=alphaMem0 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=alphaMem1 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=alphaMem2 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=alphaMem3 complete dim=1
-
-
 #pragma HLS ARRAY_PARTITION variable=means_in0 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=means_in1 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=means_in2 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=means_in3 complete dim=1
-
-#pragma HLS ARRAY_PARTITION variable=means_out1 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=means_out2 complete dim=1
-#pragma HLS ARRAY_PARTITION variable=means_out3 complete dim=1
 #pragma HLS ARRAY_PARTITION variable=means_out0 complete dim=1
 
 if (doInit) {
     DoMemInit(targetLayer, targetMem, targetInd, val,fix_val);
 } else {
     DoCompute(in, out, numReps);
+
 }
 }
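
Note on the interface change visible in the last two hunks: BlackBoxJam no longer takes numReps as an argument (its s_axilite port pragma is deleted, and the count is now set inside the function), so a host passes one argument fewer. Below is a minimal host-side sketch against the committed prototype; the buffer depth follows the depth=256 m_axi pragmas, and the call values are illustrative rather than repository code:

#include <ap_int.h>
#include <ap_fixed.h>

// Prototype as committed in top.cpp (numReps removed from the port list).
void BlackBoxJam(ap_uint<64> *in, ap_uint<64> *out, bool doInit,
                 unsigned int targetLayer, unsigned int targetMem,
                 unsigned int targetInd, ap_uint<64> val, ap_fixed<24,16> fix_val);

int main() {
    static ap_uint<64> inBuf[256];   // hostmem input buffer (depth=256, as in the m_axi pragma)
    static ap_uint<64> outBuf[256];  // hostmem output buffer

    // Parameter loading: doInit=true routes (targetLayer, targetMem, targetInd,
    // val, fix_val) to DoMemInit; targetLayer 0 writes weightMem0, 1 thresMem0, etc.
    BlackBoxJam(inBuf, outBuf, true, 0, 0, 0, ap_uint<64>(0), ap_fixed<24,16>(0));

    // Inference: doInit=false runs DoCompute with the internally fixed numReps,
    // where the repetition count used to be supplied by the host before this commit.
    BlackBoxJam(inBuf, outBuf, false, 0, 0, 0, ap_uint<64>(0), ap_fixed<24,16>(0));
    return 0;
}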
