4
4
5
5
#include <mkl.h>
6
6
7
- #define MM 4000
8
- #define KK 4000
9
- #define NN 4000
7
+ /*
8
+ #define MM 4096
9
+ #define KK 4096
10
+ #define NN 4096
10
11
11
12
#define BLOCK_M 64
12
13
#define BLOCK_K 64
17
18
#define SIMD_BLOCK_M BLOCK_M/SIMD_SIZE
18
19
#define SIMD_BLOCK_K BLOCK_K/SIMD_SIZE
19
20
#define SIMD_BLOCK_N BLOCK_N/SIMD_SIZE
21
+ */
22
+
23
+ const unsigned long long MM = 4096 ;
24
+ const unsigned long long KK = 4096 ;
25
+ const unsigned long long NN = 4096 ;
26
+
27
+ const unsigned long long BLOCK_M = 64 ;
28
+ const unsigned long long BLOCK_K = 64 ;
29
+ const unsigned long long BLOCK_N = 64 ;
30
+
31
+ const unsigned long long SIMD_SIZE = 16 ;
32
+
33
+ const unsigned long long SIMD_BLOCK_M = BLOCK_M /SIMD_SIZE ;
34
+ const unsigned long long SIMD_BLOCK_K = BLOCK_K /SIMD_SIZE ;
35
+ const unsigned long long SIMD_BLOCK_N = BLOCK_N /SIMD_SIZE ;
20
36
21
37
// L1 (32K) should be enough for a 8K(16*512)
22
38
@@ -26,11 +42,11 @@ void aligned_MM(float *A, float *B, float *C, int M, int K, int N) {
26
42
27
43
for (I = 0 ; I < M ; I ++ )
28
44
for (J = 0 ; J < K ; J += SIMD_SIZE )
29
- for (L = 0 ; L < M ; L += SIMD_SIZE ) {
45
+ for (L = 0 ; L < N ; L += SIMD_SIZE ) {
30
46
__m512 a ;
31
47
__m512 b , c ;
32
48
float * A_ptr = A + I * K + J ;
33
- float * B_ptr = B + J * N + K ;
49
+ float * B_ptr = B + J * N + L ;
34
50
float * C_ptr = C + I * N + L ;
35
51
36
52
c = _mm512_load_ps (C_ptr );
@@ -43,30 +59,50 @@ void aligned_MM(float *A, float *B, float *C, int M, int K, int N) {
43
59
}
44
60
}
45
61
46
- void Block_C_SIMD (const float * A , const float * B , float * C , const int m , const int n , const int M , const int K , const int N ) {
62
+ void Block_AB (float * A , float * B , float * C , const int m , const int n , const int k , int M , int K , int N ) {
63
+ int i , j , l ;
64
+ float a [BLOCK_M * BLOCK_K ], b [BLOCK_N * BLOCK_K ];
65
+ for (i = 0 ; i < k ; i ++ )
66
+ for (l = 0 ; l < m ; l ++ )
67
+ a [i ] = A [l * K + i ];
68
+ for (l = 0 ; l < k ; l ++ )
69
+ for (i = 0 ; i < n ; i ++ )
70
+ b [i ] = B [l * N + i ];
71
+
72
+ }
73
+
74
+ void Block_C_SIMD (float * A , float * B , float * C , const int m , const int n , const int M , const int K , const int N ) {
47
75
int i , j , l ;
48
76
int align_K = BLOCK_K * (K /BLOCK_K );
49
77
50
78
__m512 b [SIMD_BLOCK_N ];
51
79
__m512 c [BLOCK_M * SIMD_BLOCK_N ] = {0. };
52
80
53
81
for (l = 0 ; l < K ; l ++ ) {
54
- for (j = 0 ; j < SIMD_BLOCK_M ; j ++ )
55
- b [j ] = _mm512_load_ps (& B [l * N + j * SIMD_SIZE ]);
82
+ float * a_ptr = & A [l ];
83
+ float * b_ptr = & B [l * N ];
84
+ for (j = 0 ; j < SIMD_BLOCK_N ; j ++ , b_ptr += SIMD_SIZE )
85
+ b [j ] = _mm512_load_ps (b_ptr );
56
86
87
+ __m512 * c_ptr = & c [0 ];
57
88
for (i = 0 ; i < m ; i ++ ) {
58
- __m512 a = _mm512_set1_ps (A [i * K + l ]);
59
- for (j = 0 ; j < SIMD_BLOCK_M ; j ++ ) {
60
- c [i * SIMD_BLOCK_N + j ] += a * b [j ];
89
+ __m512 a = _mm512_set1_ps (* a_ptr );
90
+ for (j = 0 ; j < SIMD_BLOCK_N ; j ++ ) {
91
+ * c_ptr += a * b [j ];
92
+ c_ptr ++ ;
61
93
}
94
+ a_ptr += K ;
62
95
}
63
96
}
64
97
65
- for (i = 0 ; i < m ; i ++ )
66
- for (j = 0 ; j < SIMD_BLOCK_M ; j ++ ) {
67
- __m512 cc = _mm512_load_ps (& C [i * M + j * SIMD_SIZE ]);
68
- cc += c [i * SIMD_BLOCK_N + j ];
69
- _mm512_store_ps (& C [i * M + j * SIMD_SIZE ],cc );
98
+ for (i = 0 ; i < m ; i ++ ) {
99
+ float * c_ptr = & C [i * M ];
100
+ for (j = 0 ; j < SIMD_BLOCK_N ; j ++ ) {
101
+ __m512 cc = _mm512_load_ps (c_ptr );
102
+ cc += c [i * SIMD_BLOCK_N + j ];
103
+ _mm512_store_ps (c_ptr ,cc );
104
+ c_ptr += SIMD_SIZE ;
105
+ }
70
106
}
71
107
}
72
108
@@ -82,18 +118,6 @@ void MatMul(float *A, float *B, float *C, int M, int K, int N) {
82
118
}
83
119
}
84
120
85
- void Block_AB (float * A , float * B , float * C , const int m , const int n , const int k , int M , int K , int N ) {
86
- int i , j , l ;
87
- float a [BLOCK_M * BLOCK_K ], b [BLOCK_N * BLOCK_K ];
88
- for (i = 0 ; i < k ; i ++ )
89
- for (l = 0 ; l < m ; l ++ )
90
- a [i ] = A [l * K + i ];
91
- for (l = 0 ; l < k ; l ++ )
92
- for (i = 0 ; i < n ; i ++ )
93
- b [i ] = B [l * N + i ];
94
-
95
- }
96
-
97
121
void Block_C (float * A , float * B , float * C , const int m , const int n , int M , int K , int N ) {
98
122
int i , j , l ;
99
123
int align_K = BLOCK_K * (K /BLOCK_K );
@@ -136,6 +160,7 @@ void MatMul_block_ins(float *A, float *B, float *C, int M, int K, int N) {
136
160
int align_M = BLOCK_M * (M /BLOCK_M );
137
161
int align_N = BLOCK_N * (N /BLOCK_N );
138
162
163
+ #pragma omp parallel for
139
164
for (I = 0 ; I < align_M ; I += BLOCK_M )
140
165
for (L = 0 ; L < align_N ; L += BLOCK_N )
141
166
Block_C_SIMD (& A [I * K ], & B [L ], & C [I * N + L ], BLOCK_M , BLOCK_N , M , K , N );
@@ -159,45 +184,48 @@ int main() {
159
184
float * A , * B , * C ;
160
185
161
186
struct timeval begin , end ;
162
- int timeuse ;
187
+ float timeuse ;
163
188
164
189
A = (float * )_mm_malloc (sizeof (float )* MM * KK , 64 );
165
190
B = (float * )_mm_malloc (sizeof (float )* KK * NN , 64 );
166
191
C = (float * )_mm_malloc (sizeof (float )* MM * NN , 64 );
167
192
168
- for (i = 0 ; i < MM * KK ; i ++ ) A [i ] = 1. ;
169
- for (i = 0 ; i < KK * NN ; i ++ ) B [i ] = 2. ;
193
+ for (i = 0 ; i < MM * KK ; i ++ ) A [i ] = i ;
194
+ for (i = 0 ; i < KK * NN ; i ++ ) B [i ] = 2. * i ;
170
195
for (i = 0 ; i < MM * NN ; i ++ ) C [i ] = 0. ;
171
196
172
- gettimeofday ( & begin , NULL );
173
197
cblas_sgemm (CblasRowMajor , CblasNoTrans , CblasNoTrans , MM , NN , KK , 1. , A , KK , B , NN , 1. , C , NN );
198
+ gettimeofday ( & begin , NULL );
199
+ for (int f = 0 ;f < 5 ;f ++ )
200
+ cblas_sgemm (CblasRowMajor , CblasNoTrans , CblasNoTrans , MM , NN , KK , 1. , A , KK , B , NN , 1. , C , NN );
174
201
gettimeofday ( & end , NULL );
175
- timeuse = 1000000 * ( end .tv_sec - begin .tv_sec ) + end .tv_usec - begin .tv_usec ;
176
- printf ("mkl time: %d us \n" , timeuse );
202
+ timeuse = ( 1000000. * ( end .tv_sec - begin .tv_sec ) + end .tv_usec - begin .tv_usec )/ 1000. ;
203
+ printf ("mkl time: %.2f ms \n" , timeuse / 5 );
177
204
/*
178
205
gettimeofday( &begin, NULL );
179
206
MatMul(A, B, C, MM, KK, NN);
180
207
gettimeofday( &end, NULL );
181
208
timeuse = 1000000 * ( end.tv_sec - begin.tv_sec ) + end.tv_usec - begin.tv_usec;
182
209
printf("org time: %d us\n", timeuse);
183
- */
210
+
184
211
gettimeofday( &begin, NULL );
185
- MatMul_block (A , B , C , MM , KK , NN );
212
+ MatMul_ins (A, B, C, MM, KK, NN);
186
213
gettimeofday( &end, NULL );
187
- timeuse = 1000000 * ( end .tv_sec - begin .tv_sec ) + end .tv_usec - begin .tv_usec ;
188
- printf ("opt time: %d us\n" , timeuse );
214
+ timeuse = (1000000 * ( end.tv_sec - begin.tv_sec ) + end.tv_usec - begin.tv_usec)/1000.;
215
+ printf("ins time: %.2f ms\n", timeuse);
216
+
189
217
gettimeofday( &begin, NULL );
190
- MatMul_block_ins (A , B , C , MM , KK , NN );
218
+ MatMul_block (A, B, C, MM, KK, NN);
191
219
gettimeofday( &end, NULL );
192
- timeuse = 1000000 * ( end .tv_sec - begin .tv_sec ) + end .tv_usec - begin .tv_usec ;
193
- printf ("opt time: %d us \n" , timeuse );
194
- /*
220
+ timeuse = ( 1000000. * ( end.tv_sec - begin.tv_sec ) + end.tv_usec - begin.tv_usec)/1000. ;
221
+ printf("block time: %.2f ms \n", timeuse);
222
+ */
195
223
gettimeofday ( & begin , NULL );
196
- MatMul_ins (A, B, C, MM, KK, NN);
224
+ MatMul_block_ins (A , B , C , MM , KK , NN );
197
225
gettimeofday ( & end , NULL );
198
- timeuse = 1000000 * ( end.tv_sec - begin.tv_sec ) + end.tv_usec - begin.tv_usec;
199
- printf("ins time: %d us \n", timeuse);
200
- */
226
+ timeuse = ( 1000000. * ( end .tv_sec - begin .tv_sec ) + end .tv_usec - begin .tv_usec )/ 1000. ;
227
+ printf ("block ins time: %.2f ms \n" , timeuse );
228
+
201
229
_mm_free (A );
202
230
_mm_free (B );
203
231
_mm_free (C );
0 commit comments