diff --git a/test_minimal.py b/test_minimal.py deleted file mode 100644 index 68d9aa4..0000000 --- a/test_minimal.py +++ /dev/null @@ -1,146 +0,0 @@ -import tvm -import numpy as np - -def decl_V(A): - temp_expr = {} - for j in range(4): - temp_expr[(0, j)] = A[0][j] - A[2][j] - temp_expr[(1, j)] = A[1][j] + A[2][j] - temp_expr[(2, j)] = A[2][j] - A[1][j] - temp_expr[(3, j)] = A[1][j] - A[3][j] - - def compute_temp(i, j): - now = tvm.const(0.0, "float32") - for ii in range(4): - for jj in range(4): - now = tvm.select(tvm.all(i == ii, j == jj), - temp_expr[(ii, jj)], - now) - return now - - T1 = tvm.compute((4,4), compute_temp, name="T1") - - v_expr = {} - for i in range(4): - v_expr[(i, 0)] = T1[i][0] - T1[i][2] - v_expr[(i, 1)] = T1[i][1] + T1[i][2] - v_expr[(i, 2)] = T1[i][2] - T1[i][1] - v_expr[(i, 3)] = T1[i][1] - T1[i][3] - - def compute_V(i, j): - now = tvm.const(0.0, "float32") - for ii in range(4): - for jj in range(4): - now = tvm.select(tvm.all(i == ii, j == jj), - v_expr[(ii, jj)], - now) - return now - - V = tvm.compute((4,4), compute_V) - - return V - -def decl_output(M): - temp_expr = {} - for j in range(4): - t0 = M[0][j] + M[1][j] - t1 = M[1][j] - M[2][j] - temp_expr[(0,j)] = t0 + M[2][j] - temp_expr[(1,j)] = t1 - M[3][j] - - def compute_temp(i, j): - now = tvm.const(0.0, "float32") - for ii in range(2): - for jj in range(4): - now = tvm.select(tvm.all(i == ii, j == jj), - temp_expr[(ii, jj)], - now) - return now - - T1 = tvm.compute((2,4), compute_temp, name="T1") - - output_expr = {} - for i in range(2): - t0 = T1[i][0] + T1[i][1] - t1 = T1[i][1] - T1[i][2] - output_expr[(i,0)] = t0 + T1[i][2] - output_expr[(i,1)] = t1 - T1[i][3] - - def compute_output(i, j): - now = tvm.const(0.0, "float32") - for ii in range(2): - for jj in range(2): - now = tvm.select(tvm.all(i == ii, j == jj), - output_expr[(ii, jj)], - now) - return now - - output = tvm.compute((2,2), compute_output) - - return output - -def schedule(outs): - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - T1 = s[output].op.input_tensors[0] - i, j = s[output].op.axis - s[output].unroll(i) - s[output].unroll(j) - i, j = s[T1].op.axis - s[T1].unroll(i) - s[T1].unroll(j) - - return s - -A = tvm.placeholder((4, 4), name="A") -M = tvm.placeholder((4, 4), name="M") -device = "llvm" -with tvm.target.create(device): - T = decl_V(A) - s = schedule([T]) - output = decl_output(M) - s2 = schedule([output]) - -#print(tvm.lower(s, [A, T], simple_mode=True)) -func = tvm.build(s, [A, T], device) - -ctx = tvm.context(device, 0) -a_np = np.random.uniform(size=(4,4)).astype("float32") -t_np = np.random.uniform(size=(4,4)).astype("float32") -a = tvm.nd.array(a_np, ctx) -t = tvm.nd.array(t_np, ctx) - -func(a,t) -#print(t) - -B_data = np.array([ - [1, 0, 0, 0], - [0, 1, -1, 1], - [-1, 1, 1, 0], - [0, 0, 0, -1] -], "float32") - -ref = np.dot(np.dot(B_data.transpose(), a_np), B_data) -# print(ref) - -print(tvm.lower(s2, [M, output], simple_mode=True)) -func = tvm.build(s2, [M, output], device) - -m_np = np.random.uniform(size=(4,4)).astype("float32") -output_np = np.random.uniform(size=(2,2)).astype("float32") -m = tvm.nd.array(m_np, ctx) -output = tvm.nd.array(output_np, ctx) - -func(m, output) - -A_data = np.array([ - [1, 0], - [1, 1], - [1, -1], - [0, -1], -], "float32") - -ref = np.dot(np.dot(A_data.transpose(), m_np), A_data) -print(np.max(np.abs(output.asnumpy() - ref))) - diff --git a/test_minimal_4x4.py b/test_minimal_4x4.py deleted file mode 100644 index 25749a9..0000000 --- 
a/test_minimal_4x4.py +++ /dev/null @@ -1,170 +0,0 @@ -import tvm -import numpy as np - -def decl_V(A): - temp_expr = {} - for j in range(6): - t0 = A[4][j] - A[2][j]*4.0 - t1 = A[3][j] - A[1][j]*4.0 - t2 = A[4][j] - A[2][j] - t3 = A[3][j] - A[1][j] - temp_expr[(0, j)] = A[0][j] * 4.0 - A[2][j] * 5.0 + A[4][j] - temp_expr[(1, j)] = t0 + t1 - temp_expr[(2, j)] = t0 - t1 - temp_expr[(3, j)] = t2 + t3*2.0 - temp_expr[(4, j)] = t2 - t3*2.0 - temp_expr[(5, j)] = A[1][j] * 4.0 - A[3][j] * 5.0 + A[5][j] - - def compute_temp(i, j): - now = tvm.const(0.0, "float32") - for ii in range(6): - for jj in range(6): - now = tvm.select(tvm.all(i == ii, j == jj), - temp_expr[(ii, jj)], - now) - return now - - T1 = tvm.compute((6,6), compute_temp, name="T1") - - v_expr = {} - for i in range(6): - t0 = T1[i][4] - T1[i][2]*4.0 - t1 = T1[i][3] - T1[i][1]*4.0 - t2 = T1[i][4] - T1[i][2] - t3 = T1[i][3] - T1[i][1] - v_expr[(i, 0)] = T1[i][0] * 4.0 - T1[i][2] * 5.0 + T1[i][4] - v_expr[(i, 1)] = t0 + t1 - v_expr[(i, 2)] = t0 - t1 - v_expr[(i, 3)] = t2 + t3*2.0 - v_expr[(i, 4)] = t2 - t3*2.0 - v_expr[(i, 5)] = T1[i][1] * 4.0 - T1[i][3] * 5.0 + T1[i][5] - - def compute_V(i, j): - now = tvm.const(0.0, "float32") - for ii in range(6): - for jj in range(6): - now = tvm.select(tvm.all(i == ii, j == jj), - v_expr[(ii, jj)], - now) - return now - - V = tvm.compute((6,6), compute_V) - - return V - -def decl_output(M): - temp_expr = {} - for j in range(6): - t0 = M[1][j] + M[2][j] - t1 = M[3][j] + M[4][j] - t2 = M[1][j] - M[2][j] - t3 = M[3][j] - M[4][j] - temp_expr[(0, j)] = t0 + t1 + M[0][j] - temp_expr[(1, j)] = t2 + t3*2.0 - temp_expr[(2, j)] = t0 + t1*4.0 - temp_expr[(3, j)] = t2 + t3*8.0 + M[5][j] - - def compute_temp(i, j): - now = tvm.const(0.0, "float32") - for ii in range(4): - for jj in range(6): - now = tvm.select(tvm.all(i == ii, j == jj), - temp_expr[(ii, jj)], - now) - return now - - T1 = tvm.compute((4,6), compute_temp, name="T1") - - output_expr = {} - for i in range(4): - t0 = T1[i][1] + T1[i][2] - t1 = T1[i][3] + T1[i][4] - t2 = T1[i][1] - T1[i][2] - t3 = T1[i][3] - T1[i][4] - output_expr[(i, 0)] = t0 + t1 + T1[i][0] - output_expr[(i, 1)] = t2 + t3 * 2.0 - output_expr[(i, 2)] = t0 + t1 * 4.0 - output_expr[(i, 3)] = t2 + t3 * 8.0 + T1[i][5] - - def compute_output(i, j): - now = tvm.const(0.0, "float32") - for ii in range(4): - for jj in range(4): - now = tvm.select(tvm.all(i == ii, j == jj), - output_expr[(ii, jj)], - now) - return now - - output = tvm.compute((4,4), compute_output) - - return output - -def schedule(outs): - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - T1 = s[output].op.input_tensors[0] - i, j = s[output].op.axis - s[output].unroll(i) - s[output].unroll(j) - i, j = s[T1].op.axis - s[T1].unroll(i) - s[T1].unroll(j) - - return s - -A = tvm.placeholder((6, 6), name="A") -M = tvm.placeholder((6, 6), name="M") -device = "llvm" -with tvm.target.create(device): - V = decl_V(A) - s = schedule([V]) - output = decl_output(M) - s2 = schedule([output]) - -print(tvm.lower(s, [A, V], simple_mode=True)) -func = tvm.build(s, [A, V], device) - -ctx = tvm.context(device, 0) -a_np = np.random.uniform(size=(6,6)).astype("float32") -t_np = np.random.uniform(size=(6,6)).astype("float32") -a = tvm.nd.array(a_np, ctx) -t = tvm.nd.array(t_np, ctx) - -func(a,t) -print(t) - -B_data = np.array([ - [4, 0, 0, 0, 0, 0], - [0, -4, 4, -2, 2, 4], - [-5, -4, -4, -1, -1, 0], - [0, 1, -1, 2, -2, -5], - [1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 1], -], "float32") - -ref = 
np.dot(np.dot(B_data.transpose(), a_np), B_data) -print(np.max(np.abs(t.asnumpy() - ref))) - -print(tvm.lower(s2, [M, output], simple_mode=True)) -func = tvm.build(s2, [M, output], device) - -m_np = np.random.uniform(size=(6,6)).astype("float32") -output_np = np.random.uniform(size=(4,4)).astype("float32") -m = tvm.nd.array(m_np, ctx) -output = tvm.nd.array(output_np, ctx) - -func(m, output) - -A_data = np.array([ - [1, 0, 0, 0], - [1, 1, 1, 1], - [1, -1, 1, -1], - [1, 2, 4, 8], - [1, -2, 4, -8], - [0, 0, 0, 1] -], "float32") - -ref = np.dot(np.dot(A_data.transpose(), m_np), A_data) -print(np.max(np.abs(output.asnumpy() - ref))) - diff --git a/wino_test_cpu.py b/wino_test_cpu.py index 38e3aee..0f6fa76 100644 --- a/wino_test_cpu.py +++ b/wino_test_cpu.py @@ -7,6 +7,8 @@ from topi import util from topi.nn import pad +bna = 8 +bnb = 8 def reference_direct(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): in_height = in_width = in_size @@ -37,359 +39,162 @@ def get_ref_data(): a = tvm.nd.array(a_np, ctx) w = tvm.nd.array(w_np, ctx) b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx) - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func = tvm.build(s1, [A, W, B], device) - #print(tvm.lower(s1, [A, W, B], simple_mode=True)) - func(a, w, b) - np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - num_runs = 100 - timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) - return timer(a, w, b).mean - -def const_array(data, name): - """ convert an const array to tvm tensor""" - row, col = data.shape - dtype = str(data.dtype) - - def select_array(i, j): - now = tvm.const(0.0, dtype) - for ii in range(row): - for jj in range(col): - now = tvm.select(tvm.all(i % row == ii, j % col == jj), - tvm.const(data[ii][jj], dtype), - now) - return now - return tvm.compute(data.shape, select_array, name=name) - -def decl_U(data, kernel, stride, padding, out_dtype): - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - G_data = np.array([ - [1, 0, 0], - [1.0/2, 1.0/2, 1.0/2], - [1.0/2, -1.0/2, 1.0/2], - [0, 0, 1], - ], out_dtype) - - m = 2 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - # transform kernel - G = const_array(G_data, 'G') - r_kh = tvm.reduce_axis((0, KH), 'r_kh') - r_kw = tvm.reduce_axis((0, KW), 'r_kw') - U = tvm.compute((C // bnb, K // bna, alpha, alpha, bna, bnb), lambda c, k, eps, nu, cc, kk: - tvm.sum(kernel[k * bna + kk][c * bnb + cc][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], - axis=[r_kh, r_kw]), name='U') - outs = [U] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - U = op.output(0) - kernel, G = s[U].op.input_tensors - s[G].compute_inline() - c, k, eps, nu, cc, kk = s[U].op.axis - r_kh, r_kw = s[U].op.reduce_axis - s[U].reorder(c, k, cc, kk, eps, nu, r_kh, r_kw) - _ = [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]] - kk = s[U].fuse(kk, cc) - s[U].vectorize(kk) - fused = s[U].fuse(k, c) - s[U].parallel(fused) - - return U, s - -def decl_V(data, kernel, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] 
- HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - B_data = np.array([ - [1, 0, 0, 0], - [0, 1, -1, 1], - [-1, 1, 1, 0], - [0, 0, 0, -1] - ], out_dtype) - - m = 2 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - - # transform image - B = const_array(B_data, 'B') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - V = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), lambda b, c, eps, nu, bb, cc: - tvm.sum(data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + r_eps][(b*bnb+bb) % nW * m + r_nu][cc] * B[r_eps][eps] * B[r_nu][nu], - axis=[r_eps, r_nu]), name='V') - - outs = [V] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - V = op.output(0) - data_pad, B = s[V].op.input_tensors - s[data_pad].compute_inline() - s[B].compute_inline() - b, c, eps, nu, bb, cc = s[V].op.axis - r_eps, r_nu = s[V].op.reduce_axis - s[V].reorder(b, c, bb, cc, eps, nu, r_nu, r_eps) - s[V].vectorize(cc) - _ = [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]] - fused = s[V].fuse(b, c, bb) - s[V].parallel(fused) - - return V, s - -def decl_M(data, kernel, U, V, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - m = 2 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - # batch gemm - c = tvm.reduce_axis((0, C), name='c') - M = tvm.compute((P //bnb, K // bna, alpha, alpha, bnb, bna), lambda b, k, eps, nu, bb, kk: - tvm.sum(V[b][c // bna][eps][nu][bb][c % bna] * - U[c // bna][k][eps][nu][c % bna][kk], axis=c), name='M') - - outs = [M] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - M = op.output(0) - b, k, eps, nu, bb, kk = s[M].op.axis - c = s[M].op.reduce_axis[0] - - fused = s[M].fuse(b, k) - s[M].parallel(fused) - co, ci = s[M].split(c, factor=8) - s[M].reorder(co, bb, ci, kk) - s[M].unroll(ci) - s[M].vectorize(kk) - - return M, s - -def decl_output(data, kernel, M, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 + # with tvm.build_config(auto_unroll_max_step=500, + # unroll_explicit=True): + func = tvm.build(s1, [A, W, B], device) + func(a, w, b) + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + num_runs = 1000 + timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) + return timer(a, w, b).mean + +def reference_direct_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): + in_height = in_width = in_size + ic_block = 8 + oc_block = 8 + A = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A') + W = 
tvm.placeholder((num_filter//oc_block, in_channel//ic_block, kernel, kernel, ic_block, oc_block), name='W') - A_data = np.array([ - [1, 0], - [1, 1], - [1, -1], - [0, -1], - ], out_dtype) + a_shape = util.get_const_tuple(A.shape) + w_shape = util.get_const_tuple(W.shape) + dtype = A.dtype - m = 2 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 + @memoize("topi.tests.test_topi_conv2d_nchw.reference_direct") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = topi.testing.conv2d_nchw_python(nchwc_to_nchw(a_np), nchwc_to_nchw_kernel(w_np), stride, padding) + c_np = np.maximum(b_np, 0) + return a_np, w_np, b_np, c_np - # inverse transform - A = const_array(A_data, 'A') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - output = tvm.compute((N, K // bna, H, W, bna), lambda n, k, h, w, kk: - tvm.sum(M[(n * nH * nW + (h//m) * nW + w//m)//bna][k][r_eps][r_nu][(n * nH * nW + (h//m) * nW + w//m)%bna][kk] * A[r_eps][h % m] * A[r_nu][w % m], - axis=[r_eps, r_nu]), name='output') - - outs = [output] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - _, A = s[output].op.input_tensors - s[A].compute_inline() + a_np, w_np, b_np, c_np = get_ref_data() - n, k, h, w, kk = s[output].op.axis - r_eps, r_nu = s[output].op.reduce_axis - ho, hi = s[output].split(h, factor=2) - wo, wi = s[output].split(w, factor=2) - s[output].reorder(n, k, ho, wo, hi, wi, r_eps, r_nu, kk) - s[output].vectorize(kk) - fused = s[output].fuse(n, k, ho, wo) - s[output].parallel(fused) - - return output, s + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + with tvm.target.create(device): + B = topi.nn.conv2d_NCHWc(A, W, num_filter=num_filter, kernel_size=(3,3), stride=1, padding=1, layout='NCHWc', out_layout='NCHWc', out_dtype='float32') + s1 = topi.generic.schedule_conv2d_NCHWc(num_filter=num_filter, kernel_size=(3,3), strides=1, padding=1, layout='NCHWc', out_layout='NCHWc', outs=[B]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx) + # with tvm.build_config(auto_unroll_max_step=500, + # unroll_explicit=True): + func = tvm.build(s1, [A, W, B], device) + func(a, w, b) + np.testing.assert_allclose(nchwc_to_nchw(b.asnumpy()), b_np, rtol=1e-5) + num_runs = 1000 + timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) + return timer(a, w, b).mean -def decl_winograd(data, kernel, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride +def decl_V_minimal(data_pad, P, C, alpha, bna, bnb, nH, nW, m): + def compute_temp(b, c, eps, nu, cc): + temp_expr = {} + batch_index = b // (nH*nW) + h = b // nW % nH * m + w = b % nW * m + for j in range(6): + t0 = data_pad[batch_index][c][h+4][w+j][cc] - data_pad[batch_index][c][h+2][w+j][cc]*4.0 + t1 = data_pad[batch_index][c][h+3][w+j][cc] - data_pad[batch_index][c][h+1][w+j][cc]*4.0 + t2 = data_pad[batch_index][c][h+4][w+j][cc] - data_pad[batch_index][c][h+2][w+j][cc] + t3 = data_pad[batch_index][c][h+3][w+j][cc] - 
data_pad[batch_index][c][h+1][w+j][cc] + temp_expr[(0, j)] = data_pad[batch_index][c][h+0][w+j][cc] * 4.0 - data_pad[batch_index][c][h+2][w+j][cc] * 5.0 + data_pad[batch_index][c][h+4][w+j][cc] + temp_expr[(1, j)] = t0 + t1 + temp_expr[(2, j)] = t0 - t1 + temp_expr[(3, j)] = t2 + t3*2.0 + temp_expr[(4, j)] = t2 - t3*2.0 + temp_expr[(5, j)] = data_pad[batch_index][c][h+1][w+j][cc] * 4.0 - data_pad[batch_index][c][h+3][w+j][cc] * 5.0 + data_pad[batch_index][c][h+5][w+j][cc] + + now = tvm.const(0.0, "float32") + for ii in range(alpha): + for jj in range(alpha): + now = tvm.select(tvm.all(eps == ii, nu == jj), + temp_expr[(ii, jj)], + now) + return now - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") + temp = tvm.compute((P, C // bna, alpha, alpha, bna), compute_temp, name="temp_V") + + def compute_V(b, c, eps, nu, cc): + v_expr = {} + for i in range(6): + t0 = temp[b][c][i][4][cc] - temp[b][c][i][2][cc]*4.0 + t1 = temp[b][c][i][3][cc] - temp[b][c][i][1][cc]*4.0 + t2 = temp[b][c][i][4][cc] - temp[b][c][i][2][cc] + t3 = temp[b][c][i][3][cc] - temp[b][c][i][1][cc] + v_expr[(i, 0)] = temp[b][c][i][0][cc] * 4.0 - temp[b][c][i][2][cc] * 5.0 + temp[b][c][i][4][cc] + v_expr[(i, 1)] = t0 + t1 + v_expr[(i, 2)] = t0 - t1 + v_expr[(i, 3)] = t2 + t3*2.0 + v_expr[(i, 4)] = t2 - t3*2.0 + v_expr[(i, 5)] = temp[b][c][i][1][cc] * 4.0 - temp[b][c][i][3][cc] * 5.0 + temp[b][c][i][5][cc] + + now = tvm.const(0.0, "float32") + for ii in range(6): + for jj in range(6): + now = tvm.select(tvm.all(eps == ii, nu == jj), + v_expr[(ii, jj)], + now) + return now - B_data = np.array([ - [1, 0, 0, 0], - [0, 1, -1, 1], - [-1, 1, 1, 0], - [0, 0, 0, -1] - ], out_dtype) - - A_data = np.array([ - [1, 0], - [1, 1], - [1, -1], - [0, -1], - ], out_dtype) - - G_data = np.array([ - [1, 0, 0], - [1.0/2, 1.0/2, 1.0/2], - [1.0/2, -1.0/2, 1.0/2], - [0, 0, 1], - ], out_dtype) - - m = 2 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 + V = tvm.compute((P, C // bna, alpha, alpha, bna), compute_V, name="V") + return V + +def decl_output_minimal(M, N, K, H, W, P, alpha, bna, bnb, nH, nW, m): + + def compute_temp(b, k, eps, nu, kk): + temp_expr = {} + for j in range(6): + t0 = M[b][k][1][j][kk] + M[b][k][2][j][kk] + t1 = M[b][k][3][j][kk] + M[b][k][4][j][kk] + t2 = M[b][k][1][j][kk] - M[b][k][2][j][kk] + t3 = M[b][k][3][j][kk] - M[b][k][4][j][kk] + temp_expr[(0, j)] = t0 + t1 + M[b][k][0][j][kk] + temp_expr[(1, j)] = t2 + t3*2.0 + temp_expr[(2, j)] = t0 + t1*4.0 + temp_expr[(3, j)] = t2 + t3*8.0 + M[b][k][5][j][kk] + + now = tvm.const(0.0, "float32") + for ii in range(4): + for jj in range(6): + now = tvm.select(tvm.all(eps == ii, nu == jj), + temp_expr[(ii, jj)], + now) + return now - # # transform kernel - G = const_array(G_data, 'G') - r_kh = tvm.reduce_axis((0, KH), 'r_kh') - r_kw = tvm.reduce_axis((0, KW), 'r_kw') - U = tvm.compute((C // bnb, K // bna, alpha, alpha, bna, bnb), lambda c, k, eps, nu, cc, kk: - tvm.sum(kernel[k * bna + kk][c * bnb + cc][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], - axis=[r_kh, r_kw]), name='U') + temp = tvm.compute((P, K // bna, m, alpha, bna), compute_temp, name="temp_Y") + + def compute_output(b, k, eps, nu, kk): + output_expr = {} + for i in range(4): + t0 = temp[b][k][i][1][kk] + temp[b][k][i][2][kk] + t1 = temp[b][k][i][3][kk] + temp[b][k][i][4][kk] + t2 = temp[b][k][i][1][kk] - temp[b][k][i][2][kk] + t3 = temp[b][k][i][3][kk] - temp[b][k][i][4][kk] + output_expr[(i, 0)] = 
t0 + t1 + temp[b][k][i][0][kk] + output_expr[(i, 1)] = t2 + t3 * 2.0 + output_expr[(i, 2)] = t0 + t1 * 4.0 + output_expr[(i, 3)] = t2 + t3 * 8.0 + temp[b][k][i][5][kk] + + now = tvm.const(0.0, "float32") + for ii in range(4): + for jj in range(4): + now = tvm.select(tvm.all(eps == ii, nu == jj), + output_expr[(ii, jj)], + now) + return now - # transform image - B = const_array(B_data, 'B') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - V = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), lambda b, c, eps, nu, bb, cc: - tvm.sum(data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + r_eps][(b*bnb+bb) % nW * m + r_nu][cc] * B[r_eps][eps] * B[r_nu][nu], - axis=[r_eps, r_nu]), name='V') - - # batch gemm - c = tvm.reduce_axis((0, C), name='c') - M = tvm.compute((P //bnb, K // bna, alpha, alpha, bnb, bna), lambda b, k, eps, nu, bb, kk: - tvm.sum(V[b][c // bna][eps][nu][bb][c % bna] * - U[c // bna][k][eps][nu][c % bna][kk], axis=c), name='M') - - - # inverse transform - A = const_array(A_data, 'A') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - output = tvm.compute((N, K // bna, H, W, bna), lambda n, k, h, w, kk: - tvm.sum(M[(n * nH * nW + (h//m) * nW + w//m)//bna][k][r_eps][r_nu][(n * nH * nW + (h//m) * nW + w//m)%bna][kk] * A[r_eps][h % m] * A[r_nu][w % m], - axis=[r_eps, r_nu]), name='output') + Y = tvm.compute((P, K // bna, m, m, bna), compute_output, name="Y") + output = tvm.compute((N, K // bna, H, W, bna), lambda n, k, h, w, kk: + Y[n * nH * nW + (h//m) * nW + w//m][k][h % m][w % m][kk], + name='output', tag='winograd_conv_output') return output - -def schedule_winograd(outs): - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - M, A = s[output].op.input_tensors - V, U = s[M].op.input_tensors - kernel, G = s[U].op.input_tensors - data_pad, B = s[V].op.input_tensors - data = s[data_pad].op.input_tensors[0] - - s[data_pad].compute_inline() - - # transform kernel - s[G].compute_inline() - c, k, eps, nu, cc, kk = s[U].op.axis - r_kh, r_kw = s[U].op.reduce_axis - s[U].reorder(c, k, cc, kk, eps, nu, r_kh, r_kw) - _ = [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]] - kk = s[U].fuse(kk, cc) - s[U].vectorize(kk) - fused = s[U].fuse(k, c) - s[U].parallel(fused) - - # transform image - s[B].compute_inline() - b, c, eps, nu, bb, cc = s[V].op.axis - r_eps, r_nu = s[V].op.reduce_axis - s[V].reorder(b, c, eps, nu, r_nu, r_eps, bb, cc) - s[V].vectorize(cc) - _ = [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]] - fused = s[V].fuse(b, c) - s[V].parallel(fused) - - # batch gemm - b, k, eps, nu, bb, kk = s[M].op.axis - c = s[M].op.reduce_axis[0] - fused = s[M].fuse(b, k) - s[M].parallel(fused) - co, ci = s[M].split(c, factor=8) - s[M].reorder(co, bb, ci, kk) - s[M].unroll(ci) - s[M].vectorize(kk) - -# # inverse transform - s[A].compute_inline() - n, k, h, w, kk = s[output].op.axis - r_eps, r_nu = s[output].op.reduce_axis - ho, hi = s[output].split(h, factor=2) - wo, wi = s[output].split(w, factor=2) - s[output].reorder(n, k, ho, wo, hi, wi, r_eps, r_nu, kk) - s[output].vectorize(kk) - fused = s[output].fuse(n, k, ho, wo) - s[output].parallel(fused) - - return s - def decl_winograd_without_filter_transform(data, U, stride, padding, out_dtype): N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - co, ko, _, _, ci, ki = [util.get_const_int(x) for x in U.shape] + ko, _, _, C, ki = [util.get_const_int(x) for x in U.shape] C = co * ci K = ko * ki HPAD, WPAD 
= 1,1 @@ -401,48 +206,22 @@ def decl_winograd_without_filter_transform(data, U, stride, padding, out_dtype): assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - B_data = np.array([ - [1, 0, 0, 0], - [0, 1, -1, 1], - [-1, 1, 1, 0], - [0, 0, 0, -1] - ], out_dtype) - - A_data = np.array([ - [1, 0], - [1, 1], - [1, -1], - [0, -1], - ], out_dtype) - - m = 2 + m = 4 r = 3 alpha = m + r - 1 nH, nW = (H + m-1) // m, (W + m-1) // m P = N * nH * nW - bna, bnb = 8, 8 - # transform image - B = const_array(B_data, 'B') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - V = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), lambda b, c, eps, nu, bb, cc: - tvm.sum(data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + r_eps][(b*bnb+bb) % nW * m + r_nu][cc] * B[r_eps][eps] * B[r_nu][nu], - axis=[r_eps, r_nu]), name='V') + V = decl_V_minimal(data_pad, P, C, alpha, bna, bnb, nH, nW, m) # batch gemm c = tvm.reduce_axis((0, C), name='c') - M = tvm.compute((P //bnb, K // bna, alpha, alpha, bnb, bna), lambda b, k, eps, nu, bb, kk: - tvm.sum(V[b][c // bna][eps][nu][bb][c % bna] * - U[c // bna][k][eps][nu][c % bna][kk], axis=c), name='M') + M = tvm.compute((P, K // bna, alpha, alpha, bna), lambda b, k, eps, nu, kk: + tvm.sum(V[b][c // bna][eps][nu][c % bna] * + U[k][eps][nu][c][kk], axis=c), name='M') # inverse transform - A = const_array(A_data, 'A') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - output = tvm.compute((N, K // bna, H, W, bna), lambda n, k, h, w, kk: - tvm.sum(M[(n * nH * nW + (h//m) * nW + w//m)//bna][k][r_eps][r_nu][(n * nH * nW + (h//m) * nW + w//m)%bna][kk] * A[r_eps][h % m] * A[r_nu][w % m], - axis=[r_eps, r_nu]), name='output') + output = decl_output_minimal(M, N, K, H, W, P, alpha, bna, bnb, nH, nW, m) return output @@ -450,58 +229,84 @@ def schedule_winograd_without_filter_transform(outs): s = tvm.create_schedule([x.op for x in outs]) op = outs[0].op output = op.output(0) - M, A = s[output].op.input_tensors + Y = s[output].op.input_tensors[0] + temp_output_transform = s[Y].op.input_tensors[0] + M = s[temp_output_transform].op.input_tensors[0] V, U = s[M].op.input_tensors - data_pad, B = s[V].op.input_tensors + temp_input_transform = s[V].op.input_tensors[0] + data_pad = s[temp_input_transform].op.input_tensors[0] data = s[data_pad].op.input_tensors[0] + b_factor = 8 + P = V.shape[0].value + if P == 16: + b_factor = 2 + # transform image s[data_pad].compute_inline() - s[B].compute_inline() - b, c, eps, nu, bb, cc = s[V].op.axis - r_eps, r_nu = s[V].op.reduce_axis - s[V].reorder(b, c, bb, cc, eps, nu, r_nu, r_eps) + b, c, eps, nu, cc = s[V].op.axis + bo, bi = s[V].split(b, factor=b_factor) + s[V].reorder(bo, c, bi, eps, nu, cc) s[V].vectorize(cc) - _ = [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]] - fused = s[V].fuse(b, c, bb) - s[V].parallel(fused) + _ = [s[V].unroll(x) for x in [eps, nu]] + + b, c, eps, nu, cc = s[temp_input_transform].op.axis + s[temp_input_transform].vectorize(cc) + _ = [s[temp_input_transform].unroll(x) for x in [eps, nu]] + s[temp_input_transform].compute_at(s[V], bi) # batch gemm - b, k, eps, nu, bb, kk = s[M].op.axis + b, k, eps, nu, kk = s[M].op.axis c = s[M].op.reduce_axis[0] - fused = s[M].fuse(b, k) - s[M].parallel(fused) - co, ci = s[M].split(c, factor=8) - s[M].reorder(co, bb, ci, kk) - s[M].unroll(ci) + co, ci = s[M].split(c, factor=8) + bo, bi = s[M].split(b, factor=b_factor) + 
s[M].reorder(bo, k, bi, eps, co, nu, ci, kk) + s[V].compute_at(s[M], bo) s[M].vectorize(kk) -# # inverse transform - s[A].compute_inline() + # inverse transform + b, k, eps, nu, kk = s[Y].op.axis + bo, bi = s[Y].split(b, factor=b_factor) + s[Y].reorder(bo, k, bi, eps, nu, kk) + #s[Y].parallel(bo) + s[Y].vectorize(kk) + _ = [s[Y].unroll(x) for x in [eps, nu]] + #s[M].compute_at(s[Y], bo) + + b, k, eps, nu, kk = s[temp_output_transform].op.axis + s[temp_output_transform].unroll(eps) + s[temp_output_transform].unroll(nu) + s[temp_output_transform].vectorize(kk) + s[temp_output_transform].compute_at(s[Y], bi) + n, k, h, w, kk = s[output].op.axis - r_eps, r_nu = s[output].op.reduce_axis - ho, hi = s[output].split(h, factor=2) - wo, wi = s[output].split(w, factor=2) - s[output].reorder(n, k, ho, wo, hi, wi, r_eps, r_nu, kk) + ho, hi = s[output].split(h, factor=4) + wo, wi = s[output].split(w, factor=4) + s[output].reorder(n, ho, wo, k, hi, wi, kk) + woo, bi = s[output].split(wo, factor=b_factor) + bo = s[output].fuse(n, ho, woo) + s[output].reorder(bo, k, bi, hi, wi, kk) s[output].vectorize(kk) - fused = s[output].fuse(n, k, ho, wo) - s[output].parallel(fused) + s[output].parallel(bo) + s[M].compute_at(s[output], bo) + s[Y].compute_at(s[output], bo) return s def transform_filter(w_np): num_filter, in_channel, kernel, kernel = w_np.shape G = np.array([ - [1, 0, 0], - [1.0/2, 1.0/2, 1.0/2], - [1.0/2, -1.0/2, 1.0/2], - [0, 0, 1], - ], w_np.dtype) - bna = 8 - out = np.empty((in_channel // bna, num_filter // bna, 4, 4, bna, bna), w_np.dtype) + [1 / 4.0, 0, 0], + [-1 / 6.0, -1 / 6.0, -1 / 6.0], + [-1 / 6.0, 1 / 6.0, -1 / 6.0], + [1 / 24.0, 1 / 12.0, 1 / 6.0], + [1 / 24.0, -1 / 12.0, 1 / 6.0], + [0, 0, 1] + ], dtype=np.float32) + out = np.empty((num_filter // bna, 6, 6, in_channel, bna), w_np.dtype) for i in range(in_channel): for j in range(num_filter): - out[i // bna, j // bna, :, :, i % bna, j % bna] = np.dot(G, np.dot(w_np[j, i], G.transpose())) + out[j // bna, :, :, i, j % bna] = np.dot(G, np.dot(w_np[j, i], G.transpose())) return out def nchwc_to_nchw(arr): @@ -511,150 +316,24 @@ def nchwc_to_nchw(arr): for i in range(channels): ret[:, i] = arr[:, i//cc, :, :, i%cc] return ret - -def test_components(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - m = 2 - r = 3 - alpha = m + r - 1 - K = num_filter - H = W = in_size - N = batch - C = in_channel - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - A = tvm.placeholder((batch, in_channel // bna, in_height, in_width, bna), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - U = tvm.placeholder((C // bna, K // bna, alpha, alpha, bna, bna), name='U') - V = tvm.placeholder((P // bnb, C // bna, alpha, alpha, bnb, bna), name='V') - M = tvm.placeholder((P // bnb, K // bna, alpha, alpha, bnb, bna), name='M') - - output = tvm.placeholder((N, K // bna, in_size, in_size, bna), name='output') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - dilation = 1 - - @memoize("topi.tests.test_topi_conv2d_nchw.wino") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - b_np = topi.testing.conv2d_nchw_python(nchwc_to_nchw(a_np), dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = 
get_ref_data() - u_np = np.random.uniform(size=util.get_const_tuple(U.shape)).astype(dtype) - v_np = np.random.uniform(size=util.get_const_tuple(V.shape)).astype(dtype) - m_np = np.zeros(util.get_const_tuple(M.shape), dtype=dtype) - output_np = np.zeros(util.get_const_tuple(output.shape), dtype=dtype) - - with tvm.target.create(device): - U_out, s_U = decl_U(A, W, stride, padding, dtype) - V_out, s_V = decl_V(A, W, stride, padding, dtype) - M_out, s_M = decl_M(A, W, U, V, stride, padding, dtype) - output_out, s_output = decl_output(A, W, M, stride, padding, dtype) - - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - u = tvm.nd.array(u_np, ctx) - v = tvm.nd.array(v_np, ctx) - m = tvm.nd.array(m_np, ctx) - output_tvm = tvm.nd.array(output_np, ctx) - num_runs = 100 - times = {} - - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func_kernel_transform = tvm.build(s_U, [W, U_out], device) - func_kernel_transform(w, u) - timer = func_kernel_transform.time_evaluator(func_kernel_transform.entry_name, ctx, number=num_runs) - times["U"] = timer(w, u).mean * 1000 - - func_input_transform = tvm.build(s_V, [A, V_out], device) - func_input_transform(a, v) - timer = func_input_transform.time_evaluator(func_input_transform.entry_name, ctx, number=num_runs) - times["V"] = timer(a, v).mean * 1000 - #print(tvm.lower(s_V, [A, V_out], simple_mode=True)) - - func_batch_mm = tvm.build(s_M, [U, V, M_out], device) - #print(tvm.lower(s_M, [U, V, M_out], simple_mode=True)) - func_batch_mm(u, v, m) - - timer = func_batch_mm.time_evaluator(func_batch_mm.entry_name, ctx, number=num_runs) - times["M"] = timer(u, v, m).mean * 1000 - #print(tvm.lower(s_M, [A, W, U, V, M], simple_mode=True)) - - func_inverse_transform = tvm.build(s_output, [M, output_out], device) - func_inverse_transform(m, output_tvm) - timer = func_inverse_transform.time_evaluator(func_inverse_transform.entry_name, ctx, number=num_runs) - times["output"] = timer(m, output_tvm).mean * 1000 - #print(tvm.lower(s_output, [A, W, M, output], simple_mode=True)) - - np.testing.assert_allclose(nchwc_to_nchw(output_tvm.asnumpy()), b_np, rtol=1e-5) - - return times - - -def test_winograd(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - bna, bnb = 8, 8 - - A = tvm.placeholder((batch, in_channel // bna, in_height, in_width, bna), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - dilation = 1 - - output = tvm.placeholder((batch, num_filter//bna, in_size, in_size, bna), name='output') - - @memoize("topi.tests.test_topi_conv2d_nchw.wino") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - b_np = topi.testing.conv2d_nchw_python(nchwc_to_nchw(a_np), dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - with tvm.target.create(device): - B = decl_winograd(A, W, stride, padding, dtype) - s = schedule_winograd([B]) +def nchwc_to_nchw_kernel(kernel): + n, c, h, w, ic, oc = kernel.shape + in_channels = c * ic + out_channels = n * oc + ret = np.zeros((out_channels, in_channels, h, w)) + for i in range(out_channels): + for j in range(in_channels): + ret[i, j] = kernel[i//oc, j//ic, :, :, 
j%ic, i%oc]
+    return ret
-    output_np = np.zeros(util.get_const_tuple(output.shape), dtype=dtype)
-
-    ctx = tvm.context(device, 0)
-    a = tvm.nd.array(a_np, ctx)
-    w = tvm.nd.array(w_np, ctx)
-    b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx)
-    output_tvm = tvm.nd.array(output_np, ctx)
-
-    with tvm.build_config(auto_unroll_max_step=500,
-                          unroll_explicit=True):
-        func = tvm.build(s, [A, W, B], device)
-        func(a, w, b)
-        #print(tvm.lower(s, [A, W, B], simple_mode=True))
-        num_runs = 100
-        timer = func.time_evaluator(func.entry_name, ctx, number=num_runs)
-        np.testing.assert_allclose(nchwc_to_nchw(b.asnumpy()), b_np, rtol=1e-5)
-        return timer(a, w, b).mean
 
 def test_winograd_without_filter_transform(batch, in_channel, in_size, num_filter, kernel, stride, padding, device):
     in_height = in_width = in_size
-    bna, bnb = 8, 8
 
     A = tvm.placeholder((batch, in_channel // bna, in_height, in_width, bna), name='A')
     W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
-    U = tvm.placeholder((in_channel // bna, num_filter // bna, 4, 4, bna, bna), name='U')
+    U = tvm.placeholder((num_filter // bna, 6, 6, in_channel, bna), name='U')
 
     a_shape = util.get_const_tuple(A.shape)
     w_shape = util.get_const_tuple(W.shape)
@@ -684,23 +363,25 @@ def get_ref_data():
     u = tvm.nd.array(u_np, ctx)
     b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx)
 
-    with tvm.build_config(auto_unroll_max_step=500,
+    with tvm.build_config(auto_unroll_max_step=100,
                           unroll_explicit=True):
         func = tvm.build(s, [A, U, B], device)
         func(a, u, b)
-        #print(tvm.lower(s, [A, W, B], simple_mode=True))
-        num_runs = 100
+        #print(tvm.lower(s, [A, U, B], simple_mode=True))
+        # with open("wino.s", "w") as fo:
+        #     fo.write(func.get_source("asm"))
+        num_runs = 1000
         timer = func.time_evaluator(func.entry_name, ctx, number=num_runs)
         np.testing.assert_allclose(nchwc_to_nchw(b.asnumpy()), b_np, rtol=1e-5)
         return timer(a, u, b).mean
-
+
 # for copy paste as markdown
-def generate_table(workloads, wino_times, direct_times):
-    print("| (batch,CI,size,CO) | TVM Winograd (This code) | TVM Direct |")
-    print("|------------- |:-------------:|:-------------:|")
+def generate_table(workloads, wino_times, direct_times, nchwc_times):
+    print("| (batch,CI,size,CO) | Winograd | NCHW | NCHWc |")
+    print("|------------- |:-------------:|:-------------:|:-------------:|")
-    for (workload, t_wino, t_direct) in zip(workloads, wino_times, direct_times):
-        print("|", workload, "| %.3f | %.3f |" % (t_wino, t_direct))
+    for (workload, t_wino, t_direct, t_nchwc) in zip(workloads, wino_times, direct_times, nchwc_times):
+        print("|", workload, "| %.3f | %.3f | %.3f | " % (t_wino, t_direct, t_nchwc))
 
 workloads1 = [(1, 32, 128, 16),
               (1, 16, 128, 8),
@@ -710,7 +391,6 @@ def generate_table(workloads, wino_times, direct_times):
               (1, 32, 64, 64),
               (1, 64, 32, 64),
               (1, 64, 16, 64),
-              (1, 64, 8, 64),
               (1, 128, 16, 64),
               (1, 128, 32, 64),
               (1, 96, 64, 32),
@@ -718,28 +398,37 @@
               (1, 16, 128, 16)
               ]
 
+vgg_workloads = [(1, 64, 224, 64), #relu, input and output transform slow
+                 (1, 64, 112, 128),#relu2
+                 (1, 128, 112, 128),
+                 (1, 128, 56, 256),
+                 (1, 256, 56, 256), #relu4
+                 (1, 256, 28, 512),
+                 (1, 512, 28, 512), # relu6
+                 (1, 512, 14, 512)] # relu7
+
 workloads2 = [(workload[0] * 10, *workload[1:]) for workload in workloads1] # 10 x 128 x 128
 workloads3 = [(workload[0], workload[1], workload[2] * 2, workload[3]) for workload in workloads1] # 1 x 256 x 256
 workloads4 = [(workload[0] * 10, workload[1], workload[2] * 2, workload[3]) for workload in workloads1] # 10 x 
256 x 256 + wino_times = [] direct_times = [] -device = "llvm" -workloads = workloads1 +nchwc_times = [] +device = "llvm -mcpu=core-avx2" +workloads = workloads3 for workload in workloads: - times = test_components(*workload, 3, 1, 1, device) + t_nchwc = reference_direct_NCHWc(*workload, 3, 1, 1, device) + nchwc_times.append(t_nchwc * 1000) + t_wino = test_winograd_without_filter_transform(*workload, 3, 1, 1, device) - wino_times.append(t_wino * 1000) + wino_times.append(t_wino * 1000) + t_direct = reference_direct(*workload, 3, 1, 1, device) direct_times.append(t_direct * 1000) - print("Workload: ", workload) - for (k,v) in times.items(): - print("%s: %f" % (k, v)) - print("Total: %f" % np.sum(list(times.values()))) - print("Wino time: ", wino_times[-1]) - print("Direct: %f\n" % direct_times[-1]) - + print("Wino time: ", wino_times[-1]) + print("NCHWc time: ", nchwc_times[-1]) -generate_table(workloads, wino_times, direct_times) +generate_table(workloads, wino_times, direct_times, nchwc_times) diff --git a/wino_test_cpu_4x4.py b/wino_test_cpu_4x4.py deleted file mode 100644 index cc4ef3a..0000000 --- a/wino_test_cpu_4x4.py +++ /dev/null @@ -1,521 +0,0 @@ -import os -import numpy as np -import tvm -import topi -import topi.testing -from tvm.contrib.pickle_memoize import memoize -from topi import util -from topi.nn import pad - -def reference_direct(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - - A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d_nchw.reference_direct") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - with tvm.target.create(device): - B = topi.nn.conv2d(A, W, stride, padding, layout='NCHW') - s1 = topi.generic.schedule_conv2d_nchw([B]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx) - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func = tvm.build(s1, [A, W, B], device) - #print(tvm.lower(s1, [A, W, B], simple_mode=True)) - func(a, w, b) - np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - num_runs = 100 - timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) - return timer(a, w, b).mean - -def const_array(data, name): - """ convert an const array to tvm tensor""" - row, col = data.shape - dtype = str(data.dtype) - - def select_array(i, j): - now = tvm.const(0.0, dtype) - for ii in range(row): - for jj in range(col): - now = tvm.select(tvm.all(i % row == ii, j % col == jj), - tvm.const(data[ii][jj], dtype), - now) - return now - return tvm.compute(data.shape, select_array, name=name) - -def decl_V(data, kernel, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, 
(tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - B_data = np.array([ - [4, 0, 0, 0, 0, 0], - [0, -4, 4, -2, 2, 4], - [-5, -4, -4, -1, -1, 0], - [0, 1, -1, 2, -2, -5], - [1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 1], - ], out_dtype) - m = 4 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - - # transform image - B = const_array(B_data, 'B') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - V = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), lambda b, c, eps, nu, bb, cc: - tvm.sum(data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + r_eps][(b*bnb+bb) % nW * m + r_nu][cc] * B[r_eps][eps] * B[r_nu][nu], - axis=[r_eps, r_nu]), name='V') - - outs = [V] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - V = op.output(0) - data_pad, B = s[V].op.input_tensors - s[data_pad].compute_inline() - s[B].compute_inline() - b, c, eps, nu, bb, cc = s[V].op.axis - r_eps, r_nu = s[V].op.reduce_axis - s[V].reorder(b, c, bb, cc, eps, nu, r_nu, r_eps) - s[V].vectorize(cc) - _ = [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]] - fused = s[V].fuse(b, c, bb) - s[V].parallel(fused) - - return V, s - -def decl_M(data, kernel, U, V, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - m = 4 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - # batch gemm - c = tvm.reduce_axis((0, C), name='c') - M = tvm.compute((P //bnb, K // bna, alpha, alpha, bnb, bna), lambda b, k, eps, nu, bb, kk: - tvm.sum(V[b][c // bna][eps][nu][bb][c % bna] * - U[c // bna][k][eps][nu][c % bna][kk], axis=c), name='M') - - outs = [M] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - M = op.output(0) - b, k, eps, nu, bb, kk = s[M].op.axis - c = s[M].op.reduce_axis[0] - - fused = s[M].fuse(b, k) - s[M].parallel(fused) - co, ci = s[M].split(c, factor=8) - s[M].reorder(co, bb, ci, kk) - s[M].unroll(ci) - s[M].vectorize(kk) - - return M, s - -def decl_output(data, kernel, M, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - A_data = np.array([ - [1, 0, 0, 0], - [1, 1, 1, 1], - [1, -1, 1, -1], - [1, 2, 4, 8], - [1, -2, 4, -8], - [0, 0, 0, 1] - ], out_dtype) - - m = 4 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - # inverse transform - A = const_array(A_data, 'A') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - output = tvm.compute((N, K // bna, H, W, bna), lambda n, k, h, w, kk: - tvm.sum(M[(n * nH * nW + (h//m) * nW + w//m)//bna][k][r_eps][r_nu][(n * nH * nW + (h//m) * nW + 
w//m)%bna][kk] * A[r_eps][h % m] * A[r_nu][w % m], - axis=[r_eps, r_nu]), name='output') - - outs = [output] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - _, A = s[output].op.input_tensors - s[A].compute_inline() - - n, k, h, w, kk = s[output].op.axis - r_eps, r_nu = s[output].op.reduce_axis - ho, hi = s[output].split(h, factor=8) - wo, wi = s[output].split(w, factor=8) - s[output].reorder(n, k, ho, wo, hi, r_eps, r_nu, wi, kk) - s[output].vectorize(kk) - fused = s[output].fuse(n, k, ho, wo) - s[output].parallel(fused) - _ = [s[output].unroll(x) for x in [r_eps, r_nu]] - - return output, s - -def decl_winograd_without_filter_transform(data, U, stride, padding, out_dtype): - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - co, ko, _, _, ci, ki = [util.get_const_int(x) for x in U.shape] - C = co * ci - K = ko * ki - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - - B_data = np.array([ - [4, 0, 0, 0, 0, 0], - [0, -4, 4, -2, 2, 4], - [-5, -4, -4, -1, -1, 0], - [0, 1, -1, 2, -2, -5], - [1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 1], - ], out_dtype) - - A_data = np.array([ - [1, 0, 0, 0], - [1, 1, 1, 1], - [1, -1, 1, -1], - [1, 2, 4, 8], - [1, -2, 4, -8], - [0, 0, 0, 1] - ], out_dtype) - - m = 4 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - # transform image - B = const_array(B_data, 'B') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - V = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), lambda b, c, eps, nu, bb, cc: - tvm.sum(data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + r_eps][(b*bnb+bb) % nW * m + r_nu][cc] * B[r_eps][eps] * B[r_nu][nu], - axis=[r_eps, r_nu]), name='V') - - # batch gemm - c = tvm.reduce_axis((0, C), name='c') - M = tvm.compute((P //bnb, K // bna, alpha, alpha, bnb, bna), lambda b, k, eps, nu, bb, kk: - tvm.sum(V[b][c // bna][eps][nu][bb][c % bna] * - U[c // bna][k][eps][nu][c % bna][kk], axis=c), name='M') - - # inverse transform - A = const_array(A_data, 'A') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - output = tvm.compute((N, K // bna, H, W, bna), lambda n, k, h, w, kk: - tvm.sum(M[(n * nH * nW + (h//m) * nW + w//m)//bna][k][r_eps][r_nu][(n * nH * nW + (h//m) * nW + w//m)%bna][kk] * A[r_eps][h % m] * A[r_nu][w % m], - axis=[r_eps, r_nu]), name='output') - - return output - -def schedule_winograd_without_filter_transform(outs): - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - M, A = s[output].op.input_tensors - V, U = s[M].op.input_tensors - data_pad, B = s[V].op.input_tensors - data = s[data_pad].op.input_tensors[0] - - # transform image - s[data_pad].compute_inline() - s[B].compute_inline() - b, c, eps, nu, bb, cc = s[V].op.axis - r_eps, r_nu = s[V].op.reduce_axis - s[V].reorder(b, c, bb, cc, eps, nu, r_nu, r_eps) - s[V].vectorize(cc) - _ = [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]] - fused = s[V].fuse(b, c) - s[V].parallel(fused) - - # batch gemm - b, k, eps, nu, bb, kk = s[M].op.axis - c = s[M].op.reduce_axis[0] - fused = s[M].fuse(b, k) - s[M].parallel(fused) - co, ci = s[M].split(c, factor=8) - s[M].reorder(co, bb, ci, kk) - s[M].unroll(ci) - s[M].vectorize(kk) - -# # inverse transform - 
s[A].compute_inline() - n, k, h, w, kk = s[output].op.axis - r_eps, r_nu = s[output].op.reduce_axis - ho, hi = s[output].split(h, factor=8) - wo, wi = s[output].split(w, factor=8) - s[output].reorder(n, k, ho, wo, hi, r_eps, r_nu, wi, kk) - s[output].vectorize(kk) - fused = s[output].fuse(n, k, ho, wo) - s[output].parallel(fused) - _ = [s[output].unroll(x) for x in [r_eps, r_nu]] - - return s - -def transform_filter(w_np): - num_filter, in_channel, kernel, kernel = w_np.shape - G = np.array([ - [1 / 4.0, 0, 0], - [-1 / 6.0, -1 / 6.0, -1 / 6.0], - [-1 / 6.0, 1 / 6.0, -1 / 6.0], - [1 / 24.0, 1 / 12.0, 1 / 6.0], - [1 / 24.0, -1 / 12.0, 1 / 6.0], - [0, 0, 1] - ], dtype=np.float32) - bna = 8 - out = np.empty((in_channel // bna, num_filter // bna, 6, 6, bna, bna), w_np.dtype) - for i in range(in_channel): - for j in range(num_filter): - out[i // bna, j // bna, :, :, i % bna, j % bna] = np.dot(G, np.dot(w_np[j, i], G.transpose())) - return out - -def nchwc_to_nchw(arr): - n, c, h, w, cc = arr.shape - channels = c * cc - ret = np.zeros((n, channels, h, w)) - for i in range(channels): - ret[:, i] = arr[:, i//cc, :, :, i%cc] - return ret - -def test_components(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - m = 4 - r = 3 - alpha = m + r - 1 - K = num_filter - H = W = in_size - N = batch - C = in_channel - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - A = tvm.placeholder((batch, in_channel // bna, in_height, in_width, bna), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - U = tvm.placeholder((C // bna, K // bna, alpha, alpha, bna, bna), name='U') - V = tvm.placeholder((P // bnb, C // bna, alpha, alpha, bnb, bna), name='V') - M = tvm.placeholder((P // bnb, K // bna, alpha, alpha, bnb, bna), name='M') - - output = tvm.placeholder((N, K // bna, in_size, in_size, bna), name='output') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - dilation = 1 - - @memoize("topi.tests.test_topi_conv2d_nchw.wino") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - b_np = topi.testing.conv2d_nchw_python(nchwc_to_nchw(a_np), dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - u_np = transform_filter(w_np) - v_np = np.random.uniform(size=util.get_const_tuple(V.shape)).astype(dtype) - m_np = np.zeros(util.get_const_tuple(M.shape), dtype=dtype) - output_np = np.zeros(util.get_const_tuple(output.shape), dtype=dtype) - - with tvm.target.create(device): - V_out, s_V = decl_V(A, W, stride, padding, dtype) - M_out, s_M = decl_M(A, W, U, V, stride, padding, dtype) - output_out, s_output = decl_output(A, W, M, stride, padding, dtype) - - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - u = tvm.nd.array(u_np, ctx) - v = tvm.nd.array(v_np, ctx) - m = tvm.nd.array(m_np, ctx) - output_tvm = tvm.nd.array(output_np, ctx) - num_runs = 100 - times = {} - - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func_input_transform = tvm.build(s_V, [A, V_out], device) - func_input_transform(a, v) - timer = func_input_transform.time_evaluator(func_input_transform.entry_name, ctx, number=num_runs) - times["V"] = timer(a, v).mean * 1000 - #print(tvm.lower(s_V, [A, 
V_out], simple_mode=True)) - - func_batch_mm = tvm.build(s_M, [U, V, M_out], device) - #print(tvm.lower(s_M, [U, V, M_out], simple_mode=True)) - func_batch_mm(u, v, m) - - timer = func_batch_mm.time_evaluator(func_batch_mm.entry_name, ctx, number=num_runs) - times["M"] = timer(u, v, m).mean * 1000 - #print(tvm.lower(s_M, [A, W, U, V, M], simple_mode=True)) - - func_inverse_transform = tvm.build(s_output, [M, output_out], device) - func_inverse_transform(m, output_tvm) - timer = func_inverse_transform.time_evaluator(func_inverse_transform.entry_name, ctx, number=num_runs) - times["output"] = timer(m, output_tvm).mean * 1000 - #print(tvm.lower(s_output, [A, W, M, output], simple_mode=True)) - - np.testing.assert_allclose(nchwc_to_nchw(output_tvm.asnumpy()), b_np, rtol=1e-5) - return times - -def test_winograd_without_filter_transform(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - bna, bnb = 8, 8 - - A = tvm.placeholder((batch, in_channel // bna, in_height, in_width, bna), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - U = tvm.placeholder((in_channel // bna, num_filter // bna, 6, 6, bna, bna), name='U') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - dilation = 1 - - output = tvm.placeholder((batch, num_filter//bna, in_size, in_size, bna), name='output') - - @memoize("topi.tests.test_topi_conv2d_nchw.wino") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - b_np = topi.testing.conv2d_nchw_python(nchwc_to_nchw(a_np), dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - with tvm.target.create(device): - B = decl_winograd_without_filter_transform(A, U, stride, padding, dtype) - s = schedule_winograd_without_filter_transform([B]) - - u_np = transform_filter(w_np) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - u = tvm.nd.array(u_np, ctx) - b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx) - - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func = tvm.build(s, [A, U, B], device) - func(a, u, b) - #print(tvm.lower(s, [A, W, B], simple_mode=True)) - num_runs = 100 - timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) - np.testing.assert_allclose(nchwc_to_nchw(b.asnumpy()), b_np, rtol=1e-5) - return timer(a, u, b).mean - - -# for copy paste as markdown -def generate_table(workloads, wino_times, direct_times): - print("| (batch,CI,size,CO) | TVM Winograd (This code) | TVM Direct |") - print("|------------- |:-------------:|:-------------:|") - for (workload, t_wino, t_direct) in zip(workloads, wino_times, direct_times): - print("|", workload, "| %.3f | %.3f |" % (t_wino, t_direct)) - -workloads1 = [(1, 32, 128, 16), - (1, 16, 128, 8), - (1, 8, 128, 16), - (1, 16, 128, 32), - (1, 32, 64, 32), - (1, 32, 64, 64), - (1, 64, 32, 64), - (1, 64, 16, 64), - (1, 128, 16, 64), - (1, 128, 32, 64), - (1, 96, 64, 32), - (1, 40, 128, 16), - (1, 16, 128, 16) - ] - -workloads2 = [(workload[0] * 10, *workload[1:]) for workload in workloads1] # 10 x 128 x 128 -workloads3 = [(workload[0], workload[1], workload[2] * 2, workload[3]) for workload in workloads1] # 1 x 256 x 256 -workloads4 = [(workload[0] * 10, workload[1], workload[2] * 2, workload[3]) 
for workload in workloads1] # 10 x 256 x 256 - -wino_times = [] -direct_times = [] -device = "llvm" -workloads = workloads1 - -for workload in workloads: - times = test_components(*workload, 3, 1, 1, device) - t_wino = test_winograd_without_filter_transform(*workload, 3, 1, 1, device) - wino_times.append(t_wino * 1000) - t_direct = reference_direct(*workload, 3, 1, 1, device) - direct_times.append(t_direct * 1000) - - print("Workload: ", workload) - for (k,v) in times.items(): - print("%s: %f" % (k, v)) - print("Total: %f" % np.sum(list(times.values()))) - print("Wino time: ", wino_times[-1]) - print("Direct: %f\n" % direct_times[-1]) - - -generate_table(workloads, wino_times, direct_times) diff --git a/wino_test_cpu_4x4_minimal.py b/wino_test_cpu_4x4_minimal.py deleted file mode 100644 index 17e6403..0000000 --- a/wino_test_cpu_4x4_minimal.py +++ /dev/null @@ -1,672 +0,0 @@ -import os -import numpy as np -import tvm -import topi -import topi.testing -from tvm.contrib.pickle_memoize import memoize -from topi import util -from topi.nn import pad - -def reference_direct(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - - A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d_nchw.reference_direct") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - with tvm.target.create(device): - B = topi.nn.conv2d(A, W, stride, padding, layout='NCHW') - s1 = topi.generic.schedule_conv2d_nchw([B]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx) - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func = tvm.build(s1, [A, W, B], device) - func(a, w, b) - np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - num_runs = 100 - timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) - return timer(a, w, b).mean - -def decl_V(data, kernel, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - m = 4 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - # transform image - def compute_temp(b, c, eps, nu, bb, cc): - temp_expr = {} - batch_index = (b*bnb+bb) // (nH*nW) - h = (b*bnb+bb) // nW % nH * m - w = (b*bnb+bb) % nW * m - for j in range(6): - t0 = data_pad[batch_index][c][h+4][w+j][cc] - data_pad[batch_index][c][h+2][w+j][cc]*4.0 - t1 = data_pad[batch_index][c][h+3][w+j][cc] - data_pad[batch_index][c][h+1][w+j][cc]*4.0 - t2 = data_pad[batch_index][c][h+4][w+j][cc] 
- data_pad[batch_index][c][h+2][w+j][cc] - t3 = data_pad[batch_index][c][h+3][w+j][cc] - data_pad[batch_index][c][h+1][w+j][cc] - - temp_expr[(0, j)] = data_pad[batch_index][c][h+0][w+j][cc] * 4.0 - data_pad[batch_index][c][h+2][w+j][cc] * 5.0 + data_pad[batch_index][c][h+4][w+j][cc] - temp_expr[(1, j)] = t0 + t1 - temp_expr[(2, j)] = t0 - t1 - temp_expr[(3, j)] = t2 + t3*2.0 - temp_expr[(4, j)] = t2 - t3*2.0 - temp_expr[(5, j)] = data_pad[batch_index][c][h+1][w+j][cc] * 4.0 - data_pad[batch_index][c][h+3][w+j][cc] * 5.0 + data_pad[batch_index][c][h+5][w+j][cc] - - now = tvm.const(0.0, "float32") - for ii in range(alpha): - for jj in range(alpha): - now = tvm.select(tvm.all(eps == ii, nu == jj), - temp_expr[(ii, jj)], - now) - return now - - temp = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), compute_temp, name="temp") - - def compute_V(b, c, eps, nu, bb, cc): - v_expr = {} - for i in range(6): - t0 = temp[b][c][i][4][bb][cc] - temp[b][c][i][2][bb][cc]*4.0 - t1 = temp[b][c][i][3][bb][cc] - temp[b][c][i][1][bb][cc]*4.0 - t2 = temp[b][c][i][4][bb][cc] - temp[b][c][i][2][bb][cc] - t3 = temp[b][c][i][3][bb][cc] - temp[b][c][i][1][bb][cc] - v_expr[(i, 0)] = temp[b][c][i][0][bb][cc] * 4.0 - temp[b][c][i][2][bb][cc] * 5.0 + temp[b][c][i][4][bb][cc] - v_expr[(i, 1)] = t0 + t1 - v_expr[(i, 2)] = t0 - t1 - v_expr[(i, 3)] = t2 + t3*2.0 - v_expr[(i, 4)] = t2 - t3*2.0 - v_expr[(i, 5)] = temp[b][c][i][1][bb][cc] * 4.0 - temp[b][c][i][3][bb][cc] * 5.0 + temp[b][c][i][5][bb][cc] - - now = tvm.const(0.0, "float32") - for ii in range(6): - for jj in range(6): - now = tvm.select(tvm.all(eps == ii, nu == jj), - v_expr[(ii, jj)], - now) - return now - - V = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), compute_V) - - outs = [V] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - V = op.output(0) - temp = s[V].op.input_tensors[0] - data_pad = s[temp].op.input_tensors[0] - s[data_pad].compute_inline() - - b, c, eps, nu, bb, cc = s[V].op.axis - s[V].reorder(b, c, bb, eps, nu, cc) - s[V].vectorize(cc) - _ = [s[V].unroll(x) for x in [eps, nu]] - - fused = s[V].fuse(b, c, bb) - s[V].parallel(fused) - - b, c, eps, nu, bb, cc = s[temp].op.axis - s[temp].reorder(b, c, bb, eps, nu, cc) - s[temp].vectorize(cc) - _ = [s[temp].unroll(x) for x in [eps, nu]] - s[temp].compute_at(s[V], fused) - - return V, s - -def decl_M(data, kernel, U, V, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - m = 4 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - # batch gemm - c = tvm.reduce_axis((0, C), name='c') - M = tvm.compute((P //bnb, K // bna, alpha, alpha, bnb, bna), lambda b, k, eps, nu, bb, kk: - tvm.sum(V[b][c // bna][eps][nu][bb][c % bna] * - U[c // bna][k][eps][nu][c % bna][kk], axis=c), name='M') - - outs = [M] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - M = op.output(0) - b, k, eps, nu, bb, kk = s[M].op.axis - c = s[M].op.reduce_axis[0] - - fused = s[M].fuse(b, k) - s[M].parallel(fused) - co, ci = s[M].split(c, factor=8) - s[M].reorder(co, bb, ci, kk) - s[M].unroll(ci) - s[M].vectorize(kk) - - return M, s - -def decl_output(data, kernel, M, stride, 
padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - m = 4 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - def compute_temp(tile_index, k, eps, nu, kk): - b = tile_index // bnb - bb = tile_index % bnb - temp_expr = {} - for j in range(6): - t0 = M[b][k][1][j][bb][kk] + M[b][k][2][j][bb][kk] - t1 = M[b][k][3][j][bb][kk] + M[b][k][4][j][bb][kk] - t2 = M[b][k][1][j][bb][kk] - M[b][k][2][j][bb][kk] - t3 = M[b][k][3][j][bb][kk] - M[b][k][4][j][bb][kk] - temp_expr[(0, j)] = t0 + t1 + M[b][k][0][j][bb][kk] - temp_expr[(1, j)] = t2 + t3*2.0 - temp_expr[(2, j)] = t0 + t1*4.0 - temp_expr[(3, j)] = t2 + t3*8.0 + M[b][k][5][j][bb][kk] - - now = tvm.const(0.0, "float32") - for ii in range(4): - for jj in range(6): - now = tvm.select(tvm.all(eps == ii, nu == jj), - temp_expr[(ii, jj)], - now) - return now - - temp = tvm.compute((P, K // bna, m, alpha, bna), compute_temp, name="temp") - - def compute_output(n, k, h, w, kk): - b = (n * nH * nW + (h//m) * nW + w//m) - eps = h%m - nu = w%m - output_expr = {} - for i in range(4): - t0 = temp[b][k][i][1][kk] + temp[b][k][i][2][kk] - t1 = temp[b][k][i][3][kk] + temp[b][k][i][4][kk] - t2 = temp[b][k][i][1][kk] - temp[b][k][i][2][kk] - t3 = temp[b][k][i][3][kk] - temp[b][k][i][4][kk] - output_expr[(i, 0)] = t0 + t1 + temp[b][k][i][0][kk] - output_expr[(i, 1)] = t2 + t3 * 2.0 - output_expr[(i, 2)] = t0 + t1 * 4.0 - output_expr[(i, 3)] = t2 + t3 * 8.0 + temp[b][k][i][5][kk] - - now = tvm.const(0.0, "float32") - for ii in range(4): - for jj in range(4): - now = tvm.select(tvm.all(eps == ii, nu == jj), - output_expr[(ii, jj)], - now) - return now - - output = tvm.compute((N, K // bna, H, W, bna), compute_output) - - outs = [output] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - temp = s[output].op.input_tensors[0] - - n, k, h, w, kk = s[output].op.axis - ho, hi = s[output].split(h, factor=4) - wo, wi = s[output].split(w, factor=4) - s[output].reorder(n, k, ho, wo, hi, wi, kk) - s[output].unroll(hi) - s[output].unroll(wi) - s[output].vectorize(kk) - fused = s[output].fuse(n, k, ho, wo) - s[output].parallel(fused) - - b, k, eps, nu, kk = s[temp].op.axis - s[temp].unroll(eps) - s[temp].unroll(nu) - s[temp].vectorize(kk) - s[temp].compute_at(s[output], fused) - - return output, s - -def decl_V_minimal(data_pad, P, C, alpha, bna, bnb, nH, nW, m): - def compute_temp(b, c, eps, nu, bb, cc): - temp_expr = {} - batch_index = (b*bnb+bb) // (nH*nW) - h = (b*bnb+bb) // nW % nH * m - w = (b*bnb+bb) % nW * m - for j in range(6): - t0 = data_pad[batch_index][c][h+4][w+j][cc] - data_pad[batch_index][c][h+2][w+j][cc]*4.0 - t1 = data_pad[batch_index][c][h+3][w+j][cc] - data_pad[batch_index][c][h+1][w+j][cc]*4.0 - t2 = data_pad[batch_index][c][h+4][w+j][cc] - data_pad[batch_index][c][h+2][w+j][cc] - t3 = data_pad[batch_index][c][h+3][w+j][cc] - data_pad[batch_index][c][h+1][w+j][cc] - - temp_expr[(0, j)] = data_pad[batch_index][c][h+0][w+j][cc] * 4.0 - data_pad[batch_index][c][h+2][w+j][cc] * 5.0 + data_pad[batch_index][c][h+4][w+j][cc] - temp_expr[(1, j)] = t0 + t1 - temp_expr[(2, j)] = t0 - t1 - temp_expr[(3, j)] = t2 + t3*2.0 - 
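# Editor's sketch (illustrative, not part of the original diff): the temp_expr /
# v_expr entries in decl_V_minimal() unroll the F(4x4, 3x3) input transform
# V = B^T . d . B into explicit adds and scaled adds, so no reduction over a
# constant matrix is emitted.  The coefficients written here match the 6x6 matrix
# below; a per-tile NumPy cross-check could look like this:
import numpy as np

B = np.array([[ 4,  0,  0,  0,  0,  0],
              [ 0, -4,  4, -2,  2,  4],
              [-5, -4, -4, -1, -1,  0],
              [ 0,  1, -1,  2, -2, -5],
              [ 1,  1,  1,  1,  1,  0],
              [ 0,  0,  0,  0,  0,  1]], dtype=np.float32)
d = np.random.uniform(size=(6, 6)).astype(np.float32)   # one 6x6 input tile
ref = B.T.dot(d).dot(B)   # what the unrolled expressions compute per tile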
temp_expr[(4, j)] = t2 - t3*2.0 - temp_expr[(5, j)] = data_pad[batch_index][c][h+1][w+j][cc] * 4.0 - data_pad[batch_index][c][h+3][w+j][cc] * 5.0 + data_pad[batch_index][c][h+5][w+j][cc] - now = tvm.const(0.0, "float32") - for ii in range(alpha): - for jj in range(alpha): - now = tvm.select(tvm.all(eps == ii, nu == jj), - temp_expr[(ii, jj)], - now) - return now - - temp = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), compute_temp, name="temp") - - def compute_V(b, c, eps, nu, bb, cc): - v_expr = {} - for i in range(6): - t0 = temp[b][c][i][4][bb][cc] - temp[b][c][i][2][bb][cc]*4.0 - t1 = temp[b][c][i][3][bb][cc] - temp[b][c][i][1][bb][cc]*4.0 - t2 = temp[b][c][i][4][bb][cc] - temp[b][c][i][2][bb][cc] - t3 = temp[b][c][i][3][bb][cc] - temp[b][c][i][1][bb][cc] - v_expr[(i, 0)] = temp[b][c][i][0][bb][cc] * 4.0 - temp[b][c][i][2][bb][cc] * 5.0 + temp[b][c][i][4][bb][cc] - v_expr[(i, 1)] = t0 + t1 - v_expr[(i, 2)] = t0 - t1 - v_expr[(i, 3)] = t2 + t3*2.0 - v_expr[(i, 4)] = t2 - t3*2.0 - v_expr[(i, 5)] = temp[b][c][i][1][bb][cc] * 4.0 - temp[b][c][i][3][bb][cc] * 5.0 + temp[b][c][i][5][bb][cc] - - now = tvm.const(0.0, "float32") - for ii in range(6): - for jj in range(6): - now = tvm.select(tvm.all(eps == ii, nu == jj), - v_expr[(ii, jj)], - now) - return now - - V = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), compute_V) - return V - -def decl_output_minimal(M, N, K, H, W, P, alpha, bna, bnb, nH, nW, m): - - def compute_temp(tile_index, k, eps, nu, kk): - b = tile_index // bnb - bb = tile_index % bnb - temp_expr = {} - for j in range(6): - t0 = M[b][k][1][j][bb][kk] + M[b][k][2][j][bb][kk] - t1 = M[b][k][3][j][bb][kk] + M[b][k][4][j][bb][kk] - t2 = M[b][k][1][j][bb][kk] - M[b][k][2][j][bb][kk] - t3 = M[b][k][3][j][bb][kk] - M[b][k][4][j][bb][kk] - temp_expr[(0, j)] = t0 + t1 + M[b][k][0][j][bb][kk] - temp_expr[(1, j)] = t2 + t3*2.0 - temp_expr[(2, j)] = t0 + t1*4.0 - temp_expr[(3, j)] = t2 + t3*8.0 + M[b][k][5][j][bb][kk] - - now = tvm.const(0.0, "float32") - for ii in range(4): - for jj in range(6): - now = tvm.select(tvm.all(eps == ii, nu == jj), - temp_expr[(ii, jj)], - now) - return now - - temp = tvm.compute((P, K // bna, m, alpha, bna), compute_temp, name="temp") - - def compute_output(n, k, h, w, kk): - b = (n * nH * nW + (h//m) * nW + w//m) - eps = h%m - nu = w%m - output_expr = {} - for i in range(4): - t0 = temp[b][k][i][1][kk] + temp[b][k][i][2][kk] - t1 = temp[b][k][i][3][kk] + temp[b][k][i][4][kk] - t2 = temp[b][k][i][1][kk] - temp[b][k][i][2][kk] - t3 = temp[b][k][i][3][kk] - temp[b][k][i][4][kk] - output_expr[(i, 0)] = t0 + t1 + temp[b][k][i][0][kk] - output_expr[(i, 1)] = t2 + t3 * 2.0 - output_expr[(i, 2)] = t0 + t1 * 4.0 - output_expr[(i, 3)] = t2 + t3 * 8.0 + temp[b][k][i][5][kk] - - now = tvm.const(0.0, "float32") - for ii in range(4): - for jj in range(4): - now = tvm.select(tvm.all(eps == ii, nu == jj), - output_expr[(ii, jj)], - now) - return now - - output = tvm.compute((N, K // bna, H, W, bna), compute_output) - - return output - -def decl_winograd_without_filter_transform(data, U, stride, padding, out_dtype): - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - co, ko, _, _, ci, ki = [util.get_const_int(x) for x in U.shape] - C = co * ci - K = ko * ki - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - - m = 4 - r = 3 - alpha = m + r - 1 - 
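# Editor's note (illustrative, not part of the original diff): with m = 4 and r = 3
# this is the F(4x4, 3x3) variant, so each tile reads a 6x6 input patch (alpha = 6)
# and produces a 4x4 output patch.  For the first workload tested in this file,
# (N, C, size, K) = (1, 32, 128, 16):
#   nH = nW = (128 + 3) // 4 = 32, so P = 1 * 32 * 32 = 1024 tiles,
#   and with bnb = 8 the transforms iterate over P // bnb = 128 tile blocks.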
nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - V = decl_V_minimal(data_pad, P, C, alpha, bna, bnb, nH, nW, m) - - # batch gemm - c = tvm.reduce_axis((0, C), name='c') - M = tvm.compute((P //bnb, K // bna, alpha, alpha, bnb, bna), lambda b, k, eps, nu, bb, kk: - tvm.sum(V[b][c // bna][eps][nu][bb][c % bna] * - U[c // bna][k][eps][nu][c % bna][kk], axis=c), name='M') - - # inverse transform - output = decl_output_minimal(M, N, K, H, W, P, alpha, bna, bnb, nH, nW, m) - - return output - -def schedule_winograd_without_filter_transform(outs): - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - temp_output_transform = s[output].op.input_tensors[0] - M = s[temp_output_transform].op.input_tensors[0] - V, U = s[M].op.input_tensors - temp_input_transform = s[V].op.input_tensors[0] - data_pad = s[temp_input_transform].op.input_tensors[0] - data = s[data_pad].op.input_tensors[0] - - # transform image - s[data_pad].compute_inline() - b, c, eps, nu, bb, cc = s[V].op.axis - s[V].reorder(b, c, bb, eps, nu, cc) - s[V].vectorize(cc) - _ = [s[V].unroll(x) for x in [eps, nu]] - - fused = s[V].fuse(b, c, bb) - s[V].parallel(fused) - - b, c, eps, nu, bb, cc = s[temp_input_transform].op.axis - s[temp_input_transform].reorder(b, c, bb, eps, nu, cc) - s[temp_input_transform].vectorize(cc) - _ = [s[temp_input_transform].unroll(x) for x in [eps, nu]] - s[temp_input_transform].compute_at(s[V], fused) - - # batch gemm - b, k, eps, nu, bb, kk = s[M].op.axis - c = s[M].op.reduce_axis[0] - fused = s[M].fuse(b, k) - s[M].parallel(fused) - co, ci = s[M].split(c, factor=8) - s[M].reorder(co, bb, ci, kk) - s[M].unroll(ci) - s[M].vectorize(kk) - - # inverse transform - n, k, h, w, kk = s[output].op.axis - ho, hi = s[output].split(h, factor=4) - wo, wi = s[output].split(w, factor=4) - s[output].reorder(n, k, ho, wo, hi, wi, kk) - s[output].unroll(hi) - s[output].unroll(wi) - s[output].vectorize(kk) - fused = s[output].fuse(n, k, ho, wo) - s[output].parallel(fused) - - b, k, eps, nu, kk = s[temp_output_transform].op.axis - s[temp_output_transform].unroll(eps) - s[temp_output_transform].unroll(nu) - s[temp_output_transform].vectorize(kk) - s[temp_output_transform].compute_at(s[output], fused) - - return s - -def transform_filter(w_np): - num_filter, in_channel, kernel, kernel = w_np.shape - G = np.array([ - [1 / 4.0, 0, 0], - [-1 / 6.0, -1 / 6.0, -1 / 6.0], - [-1 / 6.0, 1 / 6.0, -1 / 6.0], - [1 / 24.0, 1 / 12.0, 1 / 6.0], - [1 / 24.0, -1 / 12.0, 1 / 6.0], - [0, 0, 1] - ], dtype=np.float32) - bna = 8 - out = np.empty((in_channel // bna, num_filter // bna, 6, 6, bna, bna), w_np.dtype) - for i in range(in_channel): - for j in range(num_filter): - out[i // bna, j // bna, :, :, i % bna, j % bna] = np.dot(G, np.dot(w_np[j, i], G.transpose())) - return out - -def nchwc_to_nchw(arr): - n, c, h, w, cc = arr.shape - channels = c * cc - ret = np.zeros((n, channels, h, w)) - for i in range(channels): - ret[:, i] = arr[:, i//cc, :, :, i%cc] - return ret - -def test_components(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - m = 4 - r = 3 - alpha = m + r - 1 - K = num_filter - H = W = in_size - N = batch - C = in_channel - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - A = tvm.placeholder((batch, in_channel // bna, in_height, in_width, bna), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - U = tvm.placeholder((C // bna, K // bna, alpha, alpha, 
bna, bna), name='U') - V = tvm.placeholder((P // bnb, C // bna, alpha, alpha, bnb, bna), name='V') - M = tvm.placeholder((P // bnb, K // bna, alpha, alpha, bnb, bna), name='M') - - output = tvm.placeholder((N, K // bna, in_size, in_size, bna), name='output') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - dilation = 1 - - @memoize("topi.tests.test_topi_conv2d_nchw.wino") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - b_np = topi.testing.conv2d_nchw_python(nchwc_to_nchw(a_np), dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - u_np = transform_filter(w_np) - v_np = np.random.uniform(size=util.get_const_tuple(V.shape)).astype(dtype) - m_np = np.zeros(util.get_const_tuple(M.shape), dtype=dtype) - output_np = np.zeros(util.get_const_tuple(output.shape), dtype=dtype) - - with tvm.target.create(device): - V_out, s_V = decl_V(A, W, stride, padding, dtype) - M_out, s_M = decl_M(A, W, U, V, stride, padding, dtype) - output_out, s_output = decl_output(A, W, M, stride, padding, dtype) - - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - u = tvm.nd.array(u_np, ctx) - v = tvm.nd.array(v_np, ctx) - m = tvm.nd.array(m_np, ctx) - output_tvm = tvm.nd.array(output_np, ctx) - num_runs = 1000 - times = {} - - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func_input_transform = tvm.build(s_V, [A, V_out], device) - func_input_transform(a, v) - timer = func_input_transform.time_evaluator(func_input_transform.entry_name, ctx, number=num_runs) - times["V"] = timer(a, v).mean * 1000 - #print(tvm.lower(s_V, [A, V_out], simple_mode=True)) - - func_batch_mm = tvm.build(s_M, [U, V, M_out], device) - #print(tvm.lower(s_M, [U, V, M_out], simple_mode=True)) - func_batch_mm(u, v, m) - - timer = func_batch_mm.time_evaluator(func_batch_mm.entry_name, ctx, number=num_runs) - times["M"] = timer(u, v, m).mean * 1000 - #print(tvm.lower(s_M, [A, W, U, V, M], simple_mode=True)) - func_inverse_transform = tvm.build(s_output, [M, output_out], device) - func_inverse_transform(m, output_tvm) - timer = func_inverse_transform.time_evaluator(func_inverse_transform.entry_name, ctx, number=num_runs) - times["output"] = timer(m, output_tvm).mean * 1000 - #print(tvm.lower(s_output, [A, W, M, output], simple_mode=True)) - - np.testing.assert_allclose(nchwc_to_nchw(output_tvm.asnumpy()), b_np, rtol=1e-5) - return times - -def test_winograd_without_filter_transform(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - bna, bnb = 8, 8 - - A = tvm.placeholder((batch, in_channel // bna, in_height, in_width, bna), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - U = tvm.placeholder((in_channel // bna, num_filter // bna, 6, 6, bna, bna), name='U') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - dilation = 1 - - output = tvm.placeholder((batch, num_filter//bna, in_size, in_size, bna), name='output') - - @memoize("topi.tests.test_topi_conv2d_nchw.wino") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, 
dilation)) - b_np = topi.testing.conv2d_nchw_python(nchwc_to_nchw(a_np), dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - with tvm.target.create(device): - B = decl_winograd_without_filter_transform(A, U, stride, padding, dtype) - s = schedule_winograd_without_filter_transform([B]) - - u_np = transform_filter(w_np) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - u = tvm.nd.array(u_np, ctx) - b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx) - - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func = tvm.build(s, [A, U, B], device) - func(a, u, b) - #print(tvm.lower(s, [A, U, B], simple_mode=True)) - num_runs = 1000 - timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) - np.testing.assert_allclose(nchwc_to_nchw(b.asnumpy()), b_np, rtol=1e-5) - return timer(a, u, b).mean - - -# for copy paste as markdown -def generate_table(workloads, wino_times, direct_times): - print("| (batch,CI,size,CO) | TVM Winograd (This code) | TVM Direct |") - print("|------------- |:-------------:|:-------------:|") - for (workload, t_wino, t_direct) in zip(workloads, wino_times, direct_times): - print("|", workload, "| %.3f | %.3f |" % (t_wino, t_direct)) - -workloads1 = [(1, 32, 128, 16), - (1, 16, 128, 8), - (1, 8, 128, 16), - (1, 16, 128, 32), - (1, 32, 64, 32), - (1, 32, 64, 64), - (1, 64, 32, 64), - (1, 64, 16, 64), - (1, 128, 16, 64), - (1, 128, 32, 64), - (1, 96, 64, 32), - (1, 40, 128, 16), - (1, 16, 128, 16) - ] - -workloads2 = [(workload[0] * 10, *workload[1:]) for workload in workloads1] # 10 x 128 x 128 -workloads3 = [(workload[0], workload[1], workload[2] * 2, workload[3]) for workload in workloads1] # 1 x 256 x 256 -workloads4 = [(workload[0] * 10, workload[1], workload[2] * 2, workload[3]) for workload in workloads1] # 10 x 256 x 256 - -wino_times = [] -direct_times = [] -device = "llvm" -workloads = workloads1 - -for workload in workloads: - times = test_components(*workload, 3, 1, 1, device) - t_wino = test_winograd_without_filter_transform(*workload, 3, 1, 1, device) - wino_times.append(t_wino * 1000) - t_direct = reference_direct(*workload, 3, 1, 1, device) - direct_times.append(t_direct * 1000) - - print("Workload: ", workload) - for (k,v) in times.items(): - print("%s: %f" % (k, v)) - print("Total: %f" % np.sum(list(times.values()))) - print("Wino time: ", wino_times[-1]) - print("Direct: %f\n" % direct_times[-1]) - -generate_table(workloads, wino_times, direct_times) diff --git a/wino_test_cpu_minimal.py b/wino_test_cpu_minimal.py deleted file mode 100644 index 8fc1138..0000000 --- a/wino_test_cpu_minimal.py +++ /dev/null @@ -1,847 +0,0 @@ -import os -import numpy as np -import tvm -import topi -import topi.testing -from tvm.contrib.pickle_memoize import memoize -from topi import util -from topi.nn import pad - -def reference_direct(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - - A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d_nchw.reference_direct") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = topi.testing.conv2d_nchw_python(a_np, 
w_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - with tvm.target.create(device): - B = topi.nn.conv2d(A, W, stride, padding, layout='NCHW') - s1 = topi.generic.schedule_conv2d_nchw([B]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx) - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func = tvm.build(s1, [A, W, B], device) - #print(tvm.lower(s1, [A, W, B], simple_mode=True)) - func(a, w, b) - np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - num_runs = 100 - timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) - return timer(a, w, b).mean - -def const_array(data, name): - """ convert an const array to tvm tensor""" - row, col = data.shape - dtype = str(data.dtype) - - def select_array(i, j): - now = tvm.const(0.0, dtype) - for ii in range(row): - for jj in range(col): - now = tvm.select(tvm.all(i % row == ii, j % col == jj), - tvm.const(data[ii][jj], dtype), - now) - return now - return tvm.compute(data.shape, select_array, name=name) - -def decl_U(data, kernel, stride, padding, out_dtype): - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - G_data = np.array([ - [1, 0, 0], - [1.0/2, 1.0/2, 1.0/2], - [1.0/2, -1.0/2, 1.0/2], - [0, 0, 1], - ], out_dtype) - - m = 2 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - # transform kernel - G = const_array(G_data, 'G') - r_kh = tvm.reduce_axis((0, KH), 'r_kh') - r_kw = tvm.reduce_axis((0, KW), 'r_kw') - U = tvm.compute((C // bnb, K // bna, alpha, alpha, bna, bnb), lambda c, k, eps, nu, cc, kk: - tvm.sum(kernel[k * bna + kk][c * bnb + cc][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], - axis=[r_kh, r_kw]), name='U') - outs = [U] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - U = op.output(0) - kernel, G = s[U].op.input_tensors - s[G].compute_inline() - c, k, eps, nu, cc, kk = s[U].op.axis - r_kh, r_kw = s[U].op.reduce_axis - s[U].reorder(c, k, cc, kk, eps, nu, r_kh, r_kw) - _ = [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]] - kk = s[U].fuse(kk, cc) - s[U].vectorize(kk) - fused = s[U].fuse(k, c) - s[U].parallel(fused) - - return U, s - -def decl_V(data, kernel, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - m = 2 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - - def compute_temp(b, c, eps, nu, bb, cc): - now = tvm.const(0.0, "float32") - batch_index = (b*bnb+bb) // (nH*nW) - h = (b*bnb+bb) // nW % nH * m - w = (b*bnb+bb) % nW * m - temp_expr = {} - temp_expr[(0,0)] = data_pad[batch_index][c][h+0][0+w][cc] - 
data_pad[batch_index][c][h+2][0+w][cc] - temp_expr[(0,1)] = data_pad[batch_index][c][h+0][1+w][cc] - data_pad[batch_index][c][h+2][1+w][cc] - temp_expr[(0,2)] = data_pad[batch_index][c][h+0][2+w][cc] - data_pad[batch_index][c][h+2][2+w][cc] - temp_expr[(0,3)] = data_pad[batch_index][c][h+0][3+w][cc] - data_pad[batch_index][c][h+2][3+w][cc] - temp_expr[(1,0)] = data_pad[batch_index][c][h+1][0+w][cc] + data_pad[batch_index][c][h+2][0+w][cc] - temp_expr[(1,1)] = data_pad[batch_index][c][h+1][1+w][cc] + data_pad[batch_index][c][h+2][1+w][cc] - temp_expr[(1,2)] = data_pad[batch_index][c][h+1][2+w][cc] + data_pad[batch_index][c][h+2][2+w][cc] - temp_expr[(1,3)] = data_pad[batch_index][c][h+1][3+w][cc] + data_pad[batch_index][c][h+2][3+w][cc] - temp_expr[(2,0)] = data_pad[batch_index][c][h+2][0+w][cc] - data_pad[batch_index][c][h+1][0+w][cc] - temp_expr[(2,1)] = data_pad[batch_index][c][h+2][1+w][cc] - data_pad[batch_index][c][h+1][1+w][cc] - temp_expr[(2,2)] = data_pad[batch_index][c][h+2][2+w][cc] - data_pad[batch_index][c][h+1][2+w][cc] - temp_expr[(2,3)] = data_pad[batch_index][c][h+2][3+w][cc] - data_pad[batch_index][c][h+1][3+w][cc] - temp_expr[(3,0)] = data_pad[batch_index][c][h+1][0+w][cc] - data_pad[batch_index][c][h+3][0+w][cc] - temp_expr[(3,1)] = data_pad[batch_index][c][h+1][1+w][cc] - data_pad[batch_index][c][h+3][1+w][cc] - temp_expr[(3,2)] = data_pad[batch_index][c][h+1][2+w][cc] - data_pad[batch_index][c][h+3][2+w][cc] - temp_expr[(3,3)] = data_pad[batch_index][c][h+1][3+w][cc] - data_pad[batch_index][c][h+3][3+w][cc] - for ii in range(alpha): - for jj in range(alpha): - now = tvm.select(tvm.all(eps == ii, nu == jj), - temp_expr[(ii, jj)], - now) - return now - - temp = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), compute_temp, name="temp") - - def compute_V(b, c, eps, nu, bb, cc): - v_expr = {} - v_expr[(0, 0)] = temp[b][c][0][0][bb][cc] - temp[b][c][0][2][bb][cc] - v_expr[(0, 1)] = temp[b][c][0][1][bb][cc] + temp[b][c][0][2][bb][cc] - v_expr[(0, 2)] = temp[b][c][0][2][bb][cc] - temp[b][c][0][1][bb][cc] - v_expr[(0, 3)] = temp[b][c][0][1][bb][cc] - temp[b][c][0][3][bb][cc] - v_expr[(1, 0)] = temp[b][c][1][0][bb][cc] - temp[b][c][1][2][bb][cc] - v_expr[(1, 1)] = temp[b][c][1][1][bb][cc] + temp[b][c][1][2][bb][cc] - v_expr[(1, 2)] = temp[b][c][1][2][bb][cc] - temp[b][c][1][1][bb][cc] - v_expr[(1, 3)] = temp[b][c][1][1][bb][cc] - temp[b][c][1][3][bb][cc] - v_expr[(2, 0)] = temp[b][c][2][0][bb][cc] - temp[b][c][2][2][bb][cc] - v_expr[(2, 1)] = temp[b][c][2][1][bb][cc] + temp[b][c][2][2][bb][cc] - v_expr[(2, 2)] = temp[b][c][2][2][bb][cc] - temp[b][c][2][1][bb][cc] - v_expr[(2, 3)] = temp[b][c][2][1][bb][cc] - temp[b][c][2][3][bb][cc] - v_expr[(3, 0)] = temp[b][c][3][0][bb][cc] - temp[b][c][3][2][bb][cc] - v_expr[(3, 1)] = temp[b][c][3][1][bb][cc] + temp[b][c][3][2][bb][cc] - v_expr[(3, 2)] = temp[b][c][3][2][bb][cc] - temp[b][c][3][1][bb][cc] - v_expr[(3, 3)] = temp[b][c][3][1][bb][cc] - temp[b][c][3][3][bb][cc] - now = tvm.const(0.0, "float32") - for ii in range(4): - for jj in range(4): - now = tvm.select(tvm.all(eps == ii, nu == jj), - v_expr[(ii, jj)], - now) - return now - - V = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), compute_V) - - outs = [V] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - V = op.output(0) - temp = s[V].op.input_tensors[0] - data_pad = s[temp].op.input_tensors[0] - s[data_pad].compute_inline() - - b, c, eps, nu, bb, cc = s[V].op.axis - s[V].reorder(b, c, bb, cc, eps, nu) - s[V].vectorize(cc) - _ = 
[s[V].unroll(x) for x in [eps, nu]] - - fused = s[V].fuse(b, c, bb) - s[V].parallel(fused) - - b, c, eps, nu, bb, cc = s[temp].op.axis - s[temp].reorder(b, c, bb, cc, eps, nu) - s[temp].vectorize(cc) - _ = [s[temp].unroll(x) for x in [eps, nu]] - s[temp].compute_at(s[V], fused) -# s[data_pad].compute_at(s[V], fused) - - return V, s - -def decl_M(data, kernel, U, V, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - m = 2 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - # batch gemm - c = tvm.reduce_axis((0, C), name='c') - M = tvm.compute((P //bnb, K // bna, alpha, alpha, bnb, bna), lambda b, k, eps, nu, bb, kk: - tvm.sum(V[b][c // bna][eps][nu][bb][c % bna] * - U[c // bna][k][eps][nu][c % bna][kk], axis=c), name='M') - - outs = [M] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - M = op.output(0) - b, k, eps, nu, bb, kk = s[M].op.axis - c = s[M].op.reduce_axis[0] - - fused = s[M].fuse(b, k) - s[M].parallel(fused) - co, ci = s[M].split(c, factor=8) - s[M].reorder(co, bb, ci, kk) - s[M].unroll(ci) - s[M].vectorize(kk) - - return M, s - -def decl_output(data, kernel, M, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - - A_data = np.array([ - [1, 0], - [1, 1], - [1, -1], - [0, -1], - ], out_dtype) - - m = 2 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - # inverse transform - A = const_array(A_data, 'A') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - output = tvm.compute((N, K // bna, H, W, bna), lambda n, k, h, w, kk: - tvm.sum(M[(n * nH * nW + (h//m) * nW + w//m)//bna][k][r_eps][r_nu][(n * nH * nW + (h//m) * nW + w//m)%bna][kk] * A[r_eps][h % m] * A[r_nu][w % m], - axis=[r_eps, r_nu]), name='output') - - outs = [output] - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - _, A = s[output].op.input_tensors - s[A].compute_inline() - - n, k, h, w, kk = s[output].op.axis - r_eps, r_nu = s[output].op.reduce_axis - ho, hi = s[output].split(h, factor=2) - wo, wi = s[output].split(w, factor=2) - s[output].reorder(n, k, ho, wo, hi, wi, r_eps, r_nu, kk) - s[output].vectorize(kk) - fused = s[output].fuse(n, k, ho, wo) - s[output].parallel(fused) - - return output, s - -def decl_winograd(data, kernel, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, co, H, W, ci = [util.get_const_int(x) for x in data.shape] - K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - - B_data = 
np.array([ - [1, 0, 0, 0], - [0, 1, -1, 1], - [-1, 1, 1, 0], - [0, 0, 0, -1] - ], out_dtype) - - A_data = np.array([ - [1, 0], - [1, 1], - [1, -1], - [0, -1], - ], out_dtype) - - G_data = np.array([ - [1, 0, 0], - [1.0/2, 1.0/2, 1.0/2], - [1.0/2, -1.0/2, 1.0/2], - [0, 0, 1], - ], out_dtype) - - m = 2 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - # # transform kernel - G = const_array(G_data, 'G') - r_kh = tvm.reduce_axis((0, KH), 'r_kh') - r_kw = tvm.reduce_axis((0, KW), 'r_kw') - U = tvm.compute((C // bnb, K // bna, alpha, alpha, bna, bnb), lambda c, k, eps, nu, cc, kk: - tvm.sum(kernel[k * bna + kk][c * bnb + cc][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], - axis=[r_kh, r_kw]), name='U') - - # transform image - B = const_array(B_data, 'B') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - V = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), lambda b, c, eps, nu, bb, cc: - tvm.sum(data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + r_eps][(b*bnb+bb) % nW * m + r_nu][cc] * B[r_eps][eps] * B[r_nu][nu], - axis=[r_eps, r_nu]), name='V') - - # batch gemm - c = tvm.reduce_axis((0, C), name='c') - M = tvm.compute((P //bnb, K // bna, alpha, alpha, bnb, bna), lambda b, k, eps, nu, bb, kk: - tvm.sum(V[b][c // bna][eps][nu][bb][c % bna] * - U[c // bna][k][eps][nu][c % bna][kk], axis=c), name='M') - - - # inverse transform - A = const_array(A_data, 'A') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - output = tvm.compute((N, K // bna, H, W, bna), lambda n, k, h, w, kk: - tvm.sum(M[(n * nH * nW + (h//m) * nW + w//m)//bna][k][r_eps][r_nu][(n * nH * nW + (h//m) * nW + w//m)%bna][kk] * A[r_eps][h % m] * A[r_nu][w % m], - axis=[r_eps, r_nu]), name='output') - - return output - - -def schedule_winograd(outs): - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - M, A = s[output].op.input_tensors - V, U = s[M].op.input_tensors - kernel, G = s[U].op.input_tensors - data_pad, B = s[V].op.input_tensors - data = s[data_pad].op.input_tensors[0] - - s[data_pad].compute_inline() - - # transform kernel - s[G].compute_inline() - c, k, eps, nu, cc, kk = s[U].op.axis - r_kh, r_kw = s[U].op.reduce_axis - s[U].reorder(c, k, cc, kk, eps, nu, r_kh, r_kw) - _ = [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]] - kk = s[U].fuse(kk, cc) - s[U].vectorize(kk) - fused = s[U].fuse(k, c) - s[U].parallel(fused) - - # transform image - s[B].compute_inline() - b, c, eps, nu, bb, cc = s[V].op.axis - r_eps, r_nu = s[V].op.reduce_axis - s[V].reorder(b, c, eps, nu, r_nu, r_eps, bb, cc) - s[V].vectorize(cc) - _ = [s[V].unroll(x) for x in [eps, nu, r_eps, r_nu]] - fused = s[V].fuse(b, c) - s[V].parallel(fused) - - # batch gemm - b, k, eps, nu, bb, kk = s[M].op.axis - c = s[M].op.reduce_axis[0] - fused = s[M].fuse(b, k) - s[M].parallel(fused) - co, ci = s[M].split(c, factor=8) - s[M].reorder(co, bb, ci, kk) - s[M].unroll(ci) - s[M].vectorize(kk) - -# # inverse transform - s[A].compute_inline() - n, k, h, w, kk = s[output].op.axis - r_eps, r_nu = s[output].op.reduce_axis - ho, hi = s[output].split(h, factor=2) - wo, wi = s[output].split(w, factor=2) - s[output].reorder(n, k, ho, wo, hi, wi, r_eps, r_nu, kk) - s[output].vectorize(kk) - fused = s[output].fuse(n, k, ho, wo) - s[output].parallel(fused) - - return s - -def decl_winograd_without_filter_transform(data, U, stride, padding, out_dtype): - N, co, H, W, ci = [util.get_const_int(x) 
for x in data.shape] - co, ko, _, _, ci, ki = [util.get_const_int(x) for x in U.shape] - C = co * ci - K = ko * ki - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - - A_data = np.array([ - [1, 0], - [1, 1], - [1, -1], - [0, -1], - ], out_dtype) - - m = 2 - r = 3 - alpha = m + r - 1 - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - bna, bnb = 8, 8 - - # transform image - def compute_temp(b, c, eps, nu, bb, cc): - now = tvm.const(0.0, "float32") - batch_index = (b*bnb+bb) // (nH*nW) - h = (b*bnb+bb) // nW % nH * m - w = (b*bnb+bb) % nW * m - temp_expr = {} - temp_expr[(0,0)] = data_pad[batch_index][c][h+0][0+w][cc] - data_pad[batch_index][c][h+2][0+w][cc] - temp_expr[(0,1)] = data_pad[batch_index][c][h+0][1+w][cc] - data_pad[batch_index][c][h+2][1+w][cc] - temp_expr[(0,2)] = data_pad[batch_index][c][h+0][2+w][cc] - data_pad[batch_index][c][h+2][2+w][cc] - temp_expr[(0,3)] = data_pad[batch_index][c][h+0][3+w][cc] - data_pad[batch_index][c][h+2][3+w][cc] - temp_expr[(1,0)] = data_pad[batch_index][c][h+1][0+w][cc] + data_pad[batch_index][c][h+2][0+w][cc] - temp_expr[(1,1)] = data_pad[batch_index][c][h+1][1+w][cc] + data_pad[batch_index][c][h+2][1+w][cc] - temp_expr[(1,2)] = data_pad[batch_index][c][h+1][2+w][cc] + data_pad[batch_index][c][h+2][2+w][cc] - temp_expr[(1,3)] = data_pad[batch_index][c][h+1][3+w][cc] + data_pad[batch_index][c][h+2][3+w][cc] - temp_expr[(2,0)] = data_pad[batch_index][c][h+2][0+w][cc] - data_pad[batch_index][c][h+1][0+w][cc] - temp_expr[(2,1)] = data_pad[batch_index][c][h+2][1+w][cc] - data_pad[batch_index][c][h+1][1+w][cc] - temp_expr[(2,2)] = data_pad[batch_index][c][h+2][2+w][cc] - data_pad[batch_index][c][h+1][2+w][cc] - temp_expr[(2,3)] = data_pad[batch_index][c][h+2][3+w][cc] - data_pad[batch_index][c][h+1][3+w][cc] - temp_expr[(3,0)] = data_pad[batch_index][c][h+1][0+w][cc] - data_pad[batch_index][c][h+3][0+w][cc] - temp_expr[(3,1)] = data_pad[batch_index][c][h+1][1+w][cc] - data_pad[batch_index][c][h+3][1+w][cc] - temp_expr[(3,2)] = data_pad[batch_index][c][h+1][2+w][cc] - data_pad[batch_index][c][h+3][2+w][cc] - temp_expr[(3,3)] = data_pad[batch_index][c][h+1][3+w][cc] - data_pad[batch_index][c][h+3][3+w][cc] - for ii in range(alpha): - for jj in range(alpha): - now = tvm.select(tvm.all(eps == ii, nu == jj), - temp_expr[(ii, jj)], - now) - return now - - temp = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), compute_temp, name="temp") - - def compute_V(b, c, eps, nu, bb, cc): - v_expr = {} - v_expr[(0, 0)] = temp[b][c][0][0][bb][cc] - temp[b][c][0][2][bb][cc] - v_expr[(0, 1)] = temp[b][c][0][1][bb][cc] + temp[b][c][0][2][bb][cc] - v_expr[(0, 2)] = temp[b][c][0][2][bb][cc] - temp[b][c][0][1][bb][cc] - v_expr[(0, 3)] = temp[b][c][0][1][bb][cc] - temp[b][c][0][3][bb][cc] - v_expr[(1, 0)] = temp[b][c][1][0][bb][cc] - temp[b][c][1][2][bb][cc] - v_expr[(1, 1)] = temp[b][c][1][1][bb][cc] + temp[b][c][1][2][bb][cc] - v_expr[(1, 2)] = temp[b][c][1][2][bb][cc] - temp[b][c][1][1][bb][cc] - v_expr[(1, 3)] = temp[b][c][1][1][bb][cc] - temp[b][c][1][3][bb][cc] - v_expr[(2, 0)] = temp[b][c][2][0][bb][cc] - temp[b][c][2][2][bb][cc] - v_expr[(2, 1)] = temp[b][c][2][1][bb][cc] + temp[b][c][2][2][bb][cc] - v_expr[(2, 2)] = temp[b][c][2][2][bb][cc] - temp[b][c][2][1][bb][cc] - v_expr[(2, 3)] = temp[b][c][2][1][bb][cc] - temp[b][c][2][3][bb][cc] - v_expr[(3, 0)] = 
temp[b][c][3][0][bb][cc] - temp[b][c][3][2][bb][cc] - v_expr[(3, 1)] = temp[b][c][3][1][bb][cc] + temp[b][c][3][2][bb][cc] - v_expr[(3, 2)] = temp[b][c][3][2][bb][cc] - temp[b][c][3][1][bb][cc] - v_expr[(3, 3)] = temp[b][c][3][1][bb][cc] - temp[b][c][3][3][bb][cc] - now = tvm.const(0.0, "float32") - for ii in range(4): - for jj in range(4): - now = tvm.select(tvm.all(eps == ii, nu == jj), - v_expr[(ii, jj)], - now) - return now - - V = tvm.compute((P // bnb, C // bna, alpha, alpha, bnb, bna), compute_V) - # batch gemm - c = tvm.reduce_axis((0, C), name='c') - M = tvm.compute((P //bnb, K // bna, alpha, alpha, bnb, bna), lambda b, k, eps, nu, bb, kk: - tvm.sum(V[b][c // bna][eps][nu][bb][c % bna] * - U[c // bna][k][eps][nu][c % bna][kk], axis=c), name='M') - - # inverse transform - A = const_array(A_data, 'A') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - output = tvm.compute((N, K // bna, H, W, bna), lambda n, k, h, w, kk: - tvm.sum(M[(n * nH * nW + (h//m) * nW + w//m)//bna][k][r_eps][r_nu][(n * nH * nW + (h//m) * nW + w//m)%bna][kk] * A[r_eps][h % m] * A[r_nu][w % m], - axis=[r_eps, r_nu]), name='output') - - return output - -def schedule_winograd_without_filter_transform(outs): - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - M, A = s[output].op.input_tensors - V, U = s[M].op.input_tensors - temp = s[V].op.input_tensors[0] - data_pad = s[temp].op.input_tensors[0] - data = s[data_pad].op.input_tensors[0] - - # transform image - s[data_pad].compute_inline() - b, c, eps, nu, bb, cc = s[V].op.axis - s[V].reorder(b, c, bb, cc, eps, nu) - s[V].vectorize(cc) - _ = [s[V].unroll(x) for x in [eps, nu]] - fused = s[V].fuse(b, c, bb) - s[V].parallel(fused) - - b, c, eps, nu, bb, cc = s[temp].op.axis - s[temp].reorder(b, c, bb, cc, eps, nu) - s[temp].vectorize(cc) - _ = [s[temp].unroll(x) for x in [eps, nu]] - s[temp].compute_at(s[V], fused) -# s[data_pad].compute_at(s[V], fused) - - # batch gemm - b, k, eps, nu, bb, kk = s[M].op.axis - c = s[M].op.reduce_axis[0] - fused = s[M].fuse(b, k) - s[M].parallel(fused) - co, ci = s[M].split(c, factor=8) - s[M].reorder(co, bb, ci, kk) - s[M].unroll(ci) - s[M].vectorize(kk) - -# # inverse transform - s[A].compute_inline() - n, k, h, w, kk = s[output].op.axis - r_eps, r_nu = s[output].op.reduce_axis - ho, hi = s[output].split(h, factor=2) - wo, wi = s[output].split(w, factor=2) - s[output].reorder(n, k, ho, wo, hi, wi, r_eps, r_nu, kk) - s[output].vectorize(kk) - fused = s[output].fuse(n, k, ho, wo) - s[output].parallel(fused) - - return s - -def transform_filter(w_np): - num_filter, in_channel, kernel, kernel = w_np.shape - G = np.array([ - [1, 0, 0], - [1.0/2, 1.0/2, 1.0/2], - [1.0/2, -1.0/2, 1.0/2], - [0, 0, 1], - ], w_np.dtype) - bna = 8 - out = np.empty((in_channel // bna, num_filter // bna, 4, 4, bna, bna), w_np.dtype) - for i in range(in_channel): - for j in range(num_filter): - out[i // bna, j // bna, :, :, i % bna, j % bna] = np.dot(G, np.dot(w_np[j, i], G.transpose())) - return out - -def nchwc_to_nchw(arr): - n, c, h, w, cc = arr.shape - channels = c * cc - ret = np.zeros((n, channels, h, w)) - for i in range(channels): - ret[:, i] = arr[:, i//cc, :, :, i%cc] - return ret - -def test_components(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - m = 2 - r = 3 - alpha = m + r - 1 - K = num_filter - H = W = in_size - N = batch - C = in_channel - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - 
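# Editor's sketch (illustrative, not part of the original diff): the placeholders a
# few lines down use the blocked NCHWc layout (N, C // bna, H, W, bna); the tests
# generate inputs directly in that layout and use nchwc_to_nchw() above to build
# the NCHW reference.  A hypothetical inverse packer, assuming C is a multiple of
# the block size bna:
def nchw_to_nchwc(arr, bna=8):
    n, c, h, w = arr.shape
    assert c % bna == 0, "channel count must be divisible by the block size"
    # split channels into (c // bna, bna) blocks and move the inner block innermost
    return arr.reshape(n, c // bna, bna, h, w).transpose(0, 1, 3, 4, 2)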
bna, bnb = 8, 8 - - A = tvm.placeholder((batch, in_channel // bna, in_height, in_width, bna), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - U = tvm.placeholder((C // bna, K // bna, alpha, alpha, bna, bna), name='U') - V = tvm.placeholder((P // bnb, C // bna, alpha, alpha, bnb, bna), name='V') - M = tvm.placeholder((P // bnb, K // bna, alpha, alpha, bnb, bna), name='M') - - output = tvm.placeholder((N, K // bna, in_size, in_size, bna), name='output') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - dilation = 1 - - @memoize("topi.tests.test_topi_conv2d_nchw.wino") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - b_np = topi.testing.conv2d_nchw_python(nchwc_to_nchw(a_np), dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - u_np = np.random.uniform(size=util.get_const_tuple(U.shape)).astype(dtype) - v_np = np.random.uniform(size=util.get_const_tuple(V.shape)).astype(dtype) - m_np = np.zeros(util.get_const_tuple(M.shape), dtype=dtype) - output_np = np.zeros(util.get_const_tuple(output.shape), dtype=dtype) - - with tvm.target.create(device): - U_out, s_U = decl_U(A, W, stride, padding, dtype) - V_out, s_V = decl_V(A, W, stride, padding, dtype) - M_out, s_M = decl_M(A, W, U, V, stride, padding, dtype) - output_out, s_output = decl_output(A, W, M, stride, padding, dtype) - - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - u = tvm.nd.array(u_np, ctx) - v = tvm.nd.array(v_np, ctx) - m = tvm.nd.array(m_np, ctx) - output_tvm = tvm.nd.array(output_np, ctx) - num_runs = 100 - times = {} - - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func_kernel_transform = tvm.build(s_U, [W, U_out], device) - func_kernel_transform(w, u) - timer = func_kernel_transform.time_evaluator(func_kernel_transform.entry_name, ctx, number=num_runs) - times["U"] = timer(w, u).mean * 1000 - - func_input_transform = tvm.build(s_V, [A, V_out], device) - func_input_transform(a, v) - timer = func_input_transform.time_evaluator(func_input_transform.entry_name, ctx, number=num_runs) - times["V"] = timer(a, v).mean * 1000 - #print(tvm.lower(s_V, [A, V_out], simple_mode=True)) - - func_batch_mm = tvm.build(s_M, [U, V, M_out], device) - #print(tvm.lower(s_M, [U, V, M_out], simple_mode=True)) - func_batch_mm(u, v, m) - - timer = func_batch_mm.time_evaluator(func_batch_mm.entry_name, ctx, number=num_runs) - times["M"] = timer(u, v, m).mean * 1000 - #print(tvm.lower(s_M, [A, W, U, V, M], simple_mode=True)) - - func_inverse_transform = tvm.build(s_output, [M, output_out], device) - func_inverse_transform(m, output_tvm) - timer = func_inverse_transform.time_evaluator(func_inverse_transform.entry_name, ctx, number=num_runs) - times["output"] = timer(m, output_tvm).mean * 1000 - #print(tvm.lower(s_output, [A, W, M, output], simple_mode=True)) - - np.testing.assert_allclose(nchwc_to_nchw(output_tvm.asnumpy()), b_np, rtol=1e-5) - - return times - - -def test_winograd(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - bna, bnb = 8, 8 - - A = tvm.placeholder((batch, in_channel // bna, in_height, in_width, bna), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - 
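# Editor's sketch (illustrative, not part of the original diff): per 4x4 input tile d
# and 3x3 filter g, decl_winograd() above evaluates the classic Winograd formula
# Y = A^T [(G g G^T) * (B^T d B)] A (elementwise product, summed over input channels).
# A single-tile, single-channel NumPy reference using the same B_data, G_data and
# A_data constants as that function (relies on the numpy import at the top of this file):
def winograd_2x2_3x3_tile(d, g):
    B = np.array([[1, 0, 0, 0], [0, 1, -1, 1], [-1, 1, 1, 0], [0, 0, 0, -1]], np.float32)
    G = np.array([[1, 0, 0], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0, 0, 1]], np.float32)
    A = np.array([[1, 0], [1, 1], [1, -1], [0, -1]], np.float32)
    U = G.dot(g).dot(G.T)          # 4x4 transformed filter
    V = B.T.dot(d).dot(B)          # 4x4 transformed input tile
    return A.T.dot(U * V).dot(A)   # 2x2 output tile (valid correlation of d with g)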
- a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - dilation = 1 - - output = tvm.placeholder((batch, num_filter//bna, in_size, in_size, bna), name='output') - - @memoize("topi.tests.test_topi_conv2d_nchw.wino") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - b_np = topi.testing.conv2d_nchw_python(nchwc_to_nchw(a_np), dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - with tvm.target.create(device): - B = decl_winograd(A, W, stride, padding, dtype) - s = schedule_winograd([B]) - - output_np = np.zeros(util.get_const_tuple(output.shape), dtype=dtype) - - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx) - output_tvm = tvm.nd.array(output_np, ctx) - - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func = tvm.build(s, [A, W, B], device) - func(a, w, b) - #print(tvm.lower(s, [A, W, B], simple_mode=True)) - num_runs = 100 - timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) - np.testing.assert_allclose(nchwc_to_nchw(b.asnumpy()), b_np, rtol=1e-5) - return timer(a, w, b).mean - -def test_winograd_without_filter_transform(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - bna, bnb = 8, 8 - - A = tvm.placeholder((batch, in_channel // bna, in_height, in_width, bna), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - U = tvm.placeholder((in_channel // bna, num_filter // bna, 4, 4, bna, bna), name='U') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - dilation = 1 - - output = tvm.placeholder((batch, num_filter//bna, in_size, in_size, bna), name='output') - - @memoize("topi.tests.test_topi_conv2d_nchw.wino") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - b_np = topi.testing.conv2d_nchw_python(nchwc_to_nchw(a_np), dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - with tvm.target.create(device): - B = decl_winograd_without_filter_transform(A, U, stride, padding, dtype) - s = schedule_winograd_without_filter_transform([B]) - - u_np = transform_filter(w_np) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - u = tvm.nd.array(u_np, ctx) - b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx) - - with tvm.build_config(auto_unroll_max_step=500, - unroll_explicit=True): - func = tvm.build(s, [A, U, B], device) - func(a, u, b) - #print(tvm.lower(s, [A, W, B], simple_mode=True)) - num_runs = 100 - timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) - np.testing.assert_allclose(nchwc_to_nchw(b.asnumpy()), b_np, rtol=1e-5) - return timer(a, u, b).mean - - -# for copy paste as markdown -def generate_table(workloads, wino_times, direct_times): - print("| (batch,CI,size,CO) | TVM Winograd (This code) | TVM Direct |") - print("|------------- |:-------------:|:-------------:|") - for (workload, t_wino, t_direct) in zip(workloads, 
wino_times, direct_times): - print("|", workload, "| %.3f | %.3f |" % (t_wino, t_direct)) - -workloads1 = [(1, 32, 128, 16), - (1, 16, 128, 8), - (1, 8, 128, 16), - (1, 16, 128, 32), - (1, 32, 64, 32), - (1, 32, 64, 64), - (1, 64, 32, 64), - (1, 64, 16, 64), - (1, 64, 8, 64), - (1, 128, 16, 64), - (1, 128, 32, 64), - (1, 96, 64, 32), - (1, 40, 128, 16), - (1, 16, 128, 16) - ] - -workloads2 = [(workload[0] * 10, *workload[1:]) for workload in workloads1] # 10 x 128 x 128 -workloads3 = [(workload[0], workload[1], workload[2] * 2, workload[3]) for workload in workloads1] # 1 x 256 x 256 -workloads4 = [(workload[0] * 10, workload[1], workload[2] * 2, workload[3]) for workload in workloads1] # 10 x 256 x 256 - -wino_times = [] -direct_times = [] -device = "llvm" -workloads = workloads1 - -for workload in workloads: - times = test_components(*workload, 3, 1, 1, device) - t_wino = test_winograd_without_filter_transform(*workload, 3, 1, 1, device) - wino_times.append(t_wino * 1000) - t_direct = reference_direct(*workload, 3, 1, 1, device) - direct_times.append(t_direct * 1000) - - print("Workload: ", workload) - for (k,v) in times.items(): - print("%s: %f" % (k, v)) - print("Total: %f" % np.sum(list(times.values()))) - print("Wino time: ", wino_times[-1]) - print("Direct: %f\n" % direct_times[-1]) - - -generate_table(workloads, wino_times, direct_times) diff --git a/wino_test_cuda.py b/wino_test_cuda.py index 114a50e..0a92e5a 100644 --- a/wino_test_cuda.py +++ b/wino_test_cuda.py @@ -50,20 +50,108 @@ def get_ref_data(): timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) return timer(a, w, b).mean -def const_array(data, name): - """ convert an const array to tvm tensor""" - row, col = data.shape - dtype = str(data.dtype) - - def select_array(i, j): - now = tvm.const(0.0, dtype) - for ii in range(row): - for jj in range(col): - now = tvm.select(tvm.all(i % row == ii, j % col == jj), - tvm.const(data[ii][jj], dtype), +def decl_V_minimal(input_tile, alpha, C, P): + # transform image + def compute_temp(c, p, i, j): + now = tvm.const(0.0, "float32") + temp_expr = {} + temp_expr[(0,0)] = input_tile[c][p][0][0] - input_tile[c][p][2][0] + temp_expr[(0,1)] = input_tile[c][p][0][1] - input_tile[c][p][2][1] + temp_expr[(0,2)] = input_tile[c][p][0][2] - input_tile[c][p][2][2] + temp_expr[(0,3)] = input_tile[c][p][0][3] - input_tile[c][p][2][3] + temp_expr[(1,0)] = input_tile[c][p][1][0] + input_tile[c][p][2][0] + temp_expr[(1,1)] = input_tile[c][p][1][1] + input_tile[c][p][2][1] + temp_expr[(1,2)] = input_tile[c][p][1][2] + input_tile[c][p][2][2] + temp_expr[(1,3)] = input_tile[c][p][1][3] + input_tile[c][p][2][3] + temp_expr[(2,0)] = input_tile[c][p][2][0] - input_tile[c][p][1][0] + temp_expr[(2,1)] = input_tile[c][p][2][1] - input_tile[c][p][1][1] + temp_expr[(2,2)] = input_tile[c][p][2][2] - input_tile[c][p][1][2] + temp_expr[(2,3)] = input_tile[c][p][2][3] - input_tile[c][p][1][3] + temp_expr[(3,0)] = input_tile[c][p][1][0] - input_tile[c][p][3][0] + temp_expr[(3,1)] = input_tile[c][p][1][1] - input_tile[c][p][3][1] + temp_expr[(3,2)] = input_tile[c][p][1][2] - input_tile[c][p][3][2] + temp_expr[(3,3)] = input_tile[c][p][1][3] - input_tile[c][p][3][3] + for ii in range(alpha): + for jj in range(alpha): + now = tvm.select(tvm.all(i == ii, j == jj), + temp_expr[(ii, jj)], + now) + return now + + temp = tvm.compute((C, P, alpha, alpha), compute_temp, name="temp") + + def compute_V(i, j, c, p): + v_expr = {} + v_expr[(0, 0)] = temp[c][p][0][0] - temp[c][p][0][2] + v_expr[(0, 1)] = 
temp[c][p][0][1] + temp[c][p][0][2] + v_expr[(0, 2)] = temp[c][p][0][2] - temp[c][p][0][1] + v_expr[(0, 3)] = temp[c][p][0][1] - temp[c][p][0][3] + v_expr[(1, 0)] = temp[c][p][1][0] - temp[c][p][1][2] + v_expr[(1, 1)] = temp[c][p][1][1] + temp[c][p][1][2] + v_expr[(1, 2)] = temp[c][p][1][2] - temp[c][p][1][1] + v_expr[(1, 3)] = temp[c][p][1][1] - temp[c][p][1][3] + v_expr[(2, 0)] = temp[c][p][2][0] - temp[c][p][2][2] + v_expr[(2, 1)] = temp[c][p][2][1] + temp[c][p][2][2] + v_expr[(2, 2)] = temp[c][p][2][2] - temp[c][p][2][1] + v_expr[(2, 3)] = temp[c][p][2][1] - temp[c][p][2][3] + v_expr[(3, 0)] = temp[c][p][3][0] - temp[c][p][3][2] + v_expr[(3, 1)] = temp[c][p][3][1] + temp[c][p][3][2] + v_expr[(3, 2)] = temp[c][p][3][2] - temp[c][p][3][1] + v_expr[(3, 3)] = temp[c][p][3][1] - temp[c][p][3][3] + now = tvm.const(0.0, "float32") + for ii in range(4): + for jj in range(4): + now = tvm.select(tvm.all(i == ii, j == jj), + v_expr[(ii, jj)], now) return now - return tvm.compute(data.shape, select_array, name=name) + + V = tvm.compute((alpha, alpha, C, P), compute_V) + + return V + +def decl_output_minimal(M, N, K, H, W, P, m, nH, nW): + + def compute_temp(k, p, eps, nu): + temp_expr = {} + for j in range(4): + t0 = M[0][j][k][p] + M[1][j][k][p] + t1 = M[1][j][k][p] - M[2][j][k][p] + temp_expr[(0,j)] = t0 + M[2][j][k][p] + temp_expr[(1,j)] = t1 - M[3][j][k][p] + + now = tvm.const(0.0, "float32") + for ii in range(2): + for jj in range(4): + now = tvm.select(tvm.all(eps == ii, nu == jj), + temp_expr[(ii, jj)], + now) + return now + + temp = tvm.compute((K, P, 2,4), compute_temp, name="temp") + + def compute_output(n, k, h, w): + b = n * nH * nW + (h//m) * nW + w//m + eps = h%m + nu = w%m + output_expr = {} + for i in range(2): + t0 = temp[k][b][i][0] + temp[k][b][i][1] + t1 = temp[k][b][i][1] - temp[k][b][i][2] + output_expr[(i,0)] = t0 + temp[k][b][i][2] + output_expr[(i,1)] = t1 - temp[k][b][i][3] + + now = tvm.const(0.0, "float32") + for ii in range(2): + for jj in range(2): + now = tvm.select(tvm.all(eps == ii, nu == jj), + output_expr[(ii, jj)], + now) + return now + + output = tvm.compute((N, K, H, W), compute_output) + + return output def decl_winograd(data, U, stride, padding, out_dtype): """declare winograd fast convolution F(2x2, 3x3) for conv2d""" @@ -78,25 +166,10 @@ def decl_winograd(data, U, stride, padding, out_dtype): assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") - B_data = np.array([ - [1, 0, 0, 0], - [0, 1, -1, 1], - [-1, 1, 1, 0], - [0, 0, 0, -1] - ], out_dtype) - - A_data = np.array([ - [1, 0], - [1, 1], - [1, -1], - [0, -1], - ], out_dtype) - m = 2 r = 3 alpha = m + r - 1 K = K - nH, nW = (H + m-1) // m, (W + m-1) // m P = N * nH * nW @@ -105,13 +178,7 @@ def decl_winograd(data, U, stride, padding, out_dtype): lambda c, b, eps, nu: tvm.select(b < P, data_pad[b // (nH*nW)][c][b// nW % nH * m + eps][b % nW * m + nu], tvm.const(0, data_pad.dtype)), name='d') - # transform image - B = const_array(B_data, 'B') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - V = tvm.compute((alpha, alpha, C, P), lambda eps, nu, c, b: - tvm.sum(input_tile[c][b][r_eps][r_nu] * B[r_eps][eps] * B[r_nu][nu], - axis=[r_eps, r_nu]), name='V') + V = decl_V_minimal(input_tile, alpha, C, P) # batch gemm c = tvm.reduce_axis((0, C), name='c') @@ -120,12 +187,7 @@ def decl_winograd(data, U, stride, padding, out_dtype): V[eps][nu][c][b], axis=c), name='M') # inverse transform and unpack - A = 
const_array(A_data, 'A') - r_eps = tvm.reduce_axis((0, alpha), 'r_eps') - r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - output = tvm.compute((N, K, H, W), lambda n, k, h, w: - tvm.sum(M[r_eps][r_nu][k][n * nH * nW + (h//m) * nW + w//m] * A[r_eps][h % m] * A[r_nu][w % m], - axis=[r_eps, r_nu]), name='output') + output = decl_output_minimal(M, N, K, H, W, P, m, nH, nW) return output @@ -195,20 +257,17 @@ def schedule_winograd(outs): s = tvm.create_schedule([x.op for x in outs]) op = outs[0].op output = op.output(0) - - M, A = s[output].op.input_tensors + output_temp = s[output].op.input_tensors[0] + M = s[output_temp].op.input_tensors[0] U, V = s[M].op.input_tensors - d, B = s[V].op.input_tensors + V_temp = s[V].op.input_tensors[0] + d = s[V_temp].op.input_tensors[0] data_pad = s[d].op.input_tensors[0] - data = s[data_pad].op.input_tensors[0] s[data_pad].compute_inline() # transform image - s[B].compute_inline() - VL = s.cache_write(V, "local") eps, nu, c, p = s[V].op.axis - r_eps, r_nu = s[VL].op.reduce_axis s[V].reorder(c, p, eps, nu) co, ci = s[V].split(c, factor=16) @@ -217,17 +276,13 @@ def schedule_winograd(outs): s[V].bind(pi, tvm.thread_axis("threadIdx.x")) s[V].bind(co, tvm.thread_axis("blockIdx.y")) s[V].bind(po, tvm.thread_axis("blockIdx.x")) - - s[VL].compute_at(s[V], pi) + s[V_temp].compute_at(s[V], pi) s[d].compute_at(s[V], pi) - + schedule_batched_sgemm(s, U, V, M) # inverse transform - s[A].compute_inline() n, k, h, w = s[output].op.axis - ML = s.cache_read(M, "local", [output]) - output_L = s.cache_write(output, "local") ho, hi = s[output].split(h, factor=2) wo, wi = s[output].split(w, factor=2) s[output].reorder(k, n, ho, wo, hi, wi) @@ -235,17 +290,16 @@ def schedule_winograd(outs): hoo, hoi = s[output].split(ho, factor=16) woo, woi = s[output].split(wo, factor=16) + s[output].reorder(hoo, woo, hoi, woi, hi, wi) s[output].bind(hoi, tvm.thread_axis("threadIdx.y")) s[output].bind(woi, tvm.thread_axis("threadIdx.x")) s[output].bind(hoo, tvm.thread_axis("blockIdx.y")) s[output].bind(woo, tvm.thread_axis("blockIdx.x")) s[output].bind(k, tvm.thread_axis("blockIdx.z")) - s[output_L].compute_at(s[output], woi) - s[ML].compute_at(s[output], woi) + s[output_temp].compute_at(s[output], woi) return s - def transform_filter(w_np): num_filter, in_channel, kernel, kernel = w_np.shape G = np.array([ @@ -299,10 +353,12 @@ def get_ref_data(): unroll_explicit=(device != "cuda"), partition_const_loop=False): func = tvm.build(s, [A, U, B], device) + #print(tvm.lower(s, [A, U, B], simple_mode=True)) func(a, u, b) num_runs = 100 timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + #print(func.imported_modules[0].get_source()) return timer(a, u, b).mean # for copy paste as markdown @@ -342,7 +398,8 @@ def generate_table(workloads, wino_times, direct_times, wino_nvptx_times, direct (1, 256, 56, 256), #relu4 (1, 256, 28, 512), (1, 512, 28, 512), # relu6 - (1, 512, 14, 512)] # relu7 + (1, 512, 14, 512) # relu7 + ] wino_times = [] direct_times = [] @@ -365,8 +422,8 @@ def generate_table(workloads, wino_times, direct_times, wino_nvptx_times, direct else: t_direct_nvptx = reference_direct(*workload, 3, 1, 1, "nvptx") - t_lib = reference_direct(*workload, 3, 1, 1, "cuda -libs=cudnn") - + #t_lib = reference_direct(*workload, 3, 1, 1, "cuda -libs=cudnn") + t_lib = 0 wino_times.append(t_wino * 1000) wino_nvptx_times.append(t_wino_nvptx * 1000) lib_times.append(t_lib * 1000) diff --git a/wino_test_cuda_minimal.py 
b/wino_test_cuda_minimal.py deleted file mode 100644 index 0a92e5a..0000000 --- a/wino_test_cuda_minimal.py +++ /dev/null @@ -1,439 +0,0 @@ -import os -import numpy as np -import tvm -import topi -import topi.testing -from tvm.contrib.pickle_memoize import memoize -from topi import util -from topi.nn import pad - -def reference_direct(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - - A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - dilation = 1 - - @memoize("topi.tests.test_topi_conv2d_nchw.reference_direct") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - with tvm.target.create(device): - dW = topi.nn.dilate(W, (1, 1, dilation, dilation)) - B = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW') - s1 = topi.generic.schedule_conv2d_nchw([B]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx) - with tvm.build_config(auto_unroll_max_step=1400, - unroll_explicit=(device != "cuda")): - func = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) - #print(tvm.lower(s1, [A, W, B], simple_mode=True)) - func(a, w, b) - np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - num_runs = 100 - timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) - return timer(a, w, b).mean - -def decl_V_minimal(input_tile, alpha, C, P): - # transform image - def compute_temp(c, p, i, j): - now = tvm.const(0.0, "float32") - temp_expr = {} - temp_expr[(0,0)] = input_tile[c][p][0][0] - input_tile[c][p][2][0] - temp_expr[(0,1)] = input_tile[c][p][0][1] - input_tile[c][p][2][1] - temp_expr[(0,2)] = input_tile[c][p][0][2] - input_tile[c][p][2][2] - temp_expr[(0,3)] = input_tile[c][p][0][3] - input_tile[c][p][2][3] - temp_expr[(1,0)] = input_tile[c][p][1][0] + input_tile[c][p][2][0] - temp_expr[(1,1)] = input_tile[c][p][1][1] + input_tile[c][p][2][1] - temp_expr[(1,2)] = input_tile[c][p][1][2] + input_tile[c][p][2][2] - temp_expr[(1,3)] = input_tile[c][p][1][3] + input_tile[c][p][2][3] - temp_expr[(2,0)] = input_tile[c][p][2][0] - input_tile[c][p][1][0] - temp_expr[(2,1)] = input_tile[c][p][2][1] - input_tile[c][p][1][1] - temp_expr[(2,2)] = input_tile[c][p][2][2] - input_tile[c][p][1][2] - temp_expr[(2,3)] = input_tile[c][p][2][3] - input_tile[c][p][1][3] - temp_expr[(3,0)] = input_tile[c][p][1][0] - input_tile[c][p][3][0] - temp_expr[(3,1)] = input_tile[c][p][1][1] - input_tile[c][p][3][1] - temp_expr[(3,2)] = input_tile[c][p][1][2] - input_tile[c][p][3][2] - temp_expr[(3,3)] = input_tile[c][p][1][3] - input_tile[c][p][3][3] - for ii in range(alpha): - for jj in range(alpha): - now = tvm.select(tvm.all(i == ii, j == jj), - temp_expr[(ii, jj)], - now) - return now - - temp = tvm.compute((C, P, alpha, alpha), compute_temp, 
name="temp") - - def compute_V(i, j, c, p): - v_expr = {} - v_expr[(0, 0)] = temp[c][p][0][0] - temp[c][p][0][2] - v_expr[(0, 1)] = temp[c][p][0][1] + temp[c][p][0][2] - v_expr[(0, 2)] = temp[c][p][0][2] - temp[c][p][0][1] - v_expr[(0, 3)] = temp[c][p][0][1] - temp[c][p][0][3] - v_expr[(1, 0)] = temp[c][p][1][0] - temp[c][p][1][2] - v_expr[(1, 1)] = temp[c][p][1][1] + temp[c][p][1][2] - v_expr[(1, 2)] = temp[c][p][1][2] - temp[c][p][1][1] - v_expr[(1, 3)] = temp[c][p][1][1] - temp[c][p][1][3] - v_expr[(2, 0)] = temp[c][p][2][0] - temp[c][p][2][2] - v_expr[(2, 1)] = temp[c][p][2][1] + temp[c][p][2][2] - v_expr[(2, 2)] = temp[c][p][2][2] - temp[c][p][2][1] - v_expr[(2, 3)] = temp[c][p][2][1] - temp[c][p][2][3] - v_expr[(3, 0)] = temp[c][p][3][0] - temp[c][p][3][2] - v_expr[(3, 1)] = temp[c][p][3][1] + temp[c][p][3][2] - v_expr[(3, 2)] = temp[c][p][3][2] - temp[c][p][3][1] - v_expr[(3, 3)] = temp[c][p][3][1] - temp[c][p][3][3] - now = tvm.const(0.0, "float32") - for ii in range(4): - for jj in range(4): - now = tvm.select(tvm.all(i == ii, j == jj), - v_expr[(ii, jj)], - now) - return now - - V = tvm.compute((alpha, alpha, C, P), compute_V) - - return V - -def decl_output_minimal(M, N, K, H, W, P, m, nH, nW): - - def compute_temp(k, p, eps, nu): - temp_expr = {} - for j in range(4): - t0 = M[0][j][k][p] + M[1][j][k][p] - t1 = M[1][j][k][p] - M[2][j][k][p] - temp_expr[(0,j)] = t0 + M[2][j][k][p] - temp_expr[(1,j)] = t1 - M[3][j][k][p] - - now = tvm.const(0.0, "float32") - for ii in range(2): - for jj in range(4): - now = tvm.select(tvm.all(eps == ii, nu == jj), - temp_expr[(ii, jj)], - now) - return now - - temp = tvm.compute((K, P, 2,4), compute_temp, name="temp") - - def compute_output(n, k, h, w): - b = n * nH * nW + (h//m) * nW + w//m - eps = h%m - nu = w%m - output_expr = {} - for i in range(2): - t0 = temp[k][b][i][0] + temp[k][b][i][1] - t1 = temp[k][b][i][1] - temp[k][b][i][2] - output_expr[(i,0)] = t0 + temp[k][b][i][2] - output_expr[(i,1)] = t1 - temp[k][b][i][3] - - now = tvm.const(0.0, "float32") - for ii in range(2): - for jj in range(2): - now = tvm.select(tvm.all(eps == ii, nu == jj), - output_expr[(ii, jj)], - now) - return now - - output = tvm.compute((N, K, H, W), compute_output) - - return output - -def decl_winograd(data, U, stride, padding, out_dtype): - """declare winograd fast convolution F(2x2, 3x3) for conv2d""" - N, C, H, W = [util.get_const_int(x) for x in data.shape] - _, _, C, K = [util.get_const_int(x) for x in U.shape] - HPAD, WPAD = 1,1 - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 - data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") - - m = 2 - r = 3 - alpha = m + r - 1 - K = K - nH, nW = (H + m-1) // m, (W + m-1) // m - P = N * nH * nW - - # pack input tile - input_tile = tvm.compute((C, P, alpha, alpha), - lambda c, b, eps, nu: - tvm.select(b < P, data_pad[b // (nH*nW)][c][b// nW % nH * m + eps][b % nW * m + nu], tvm.const(0, data_pad.dtype)), name='d') - - V = decl_V_minimal(input_tile, alpha, C, P) - - # batch gemm - c = tvm.reduce_axis((0, C), name='c') - M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b: - tvm.sum(U[eps][nu][c][k] * - V[eps][nu][c][b], axis=c), name='M') - - # inverse transform and unpack - output = decl_output_minimal(M, N, K, H, W, P, m, nH, nW) - - return output - -def schedule_smem_load(s, smem, num_thread): - yi, xi, ci, ni = s[smem].op.axis - ty, ci = s[smem].split(ci, nparts=num_thread) - tx, ni = 
s[smem].split(ni, nparts=num_thread) - _, ni = s[smem].split(ni, factor=4) - s[smem].reorder(ty, tx, yi, xi, ci, ni) - s[smem].vectorize(ni) # vectorize memory load - s[smem].bind(ty, tvm.thread_axis("threadIdx.y")) - s[smem].bind(tx, tvm.thread_axis("threadIdx.x")) - -def schedule_batched_sgemm(s, U, V, M): - UU = s.cache_read(U, 'shared', [M]) - VV = s.cache_read(V, "shared", [M]) - UL = s.cache_read(UU, "local", [M]) - VL = s.cache_read(VV, "local", [M]) - ML = s.cache_write(M, "local") - - tile = 8 - num_thread = 8 - block_factor = tile * num_thread - step = 8 - vthread = 2 - - thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") - thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y") - thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx") - thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy") - - eps, nu, k, p = s[M].op.axis - ko, ki = s[M].split(k, factor=block_factor) - po, pi = s[M].split(p, factor=block_factor) - z = s[M].fuse(eps, nu) - - s[M].bind(z, tvm.thread_axis("blockIdx.z")) - s[M].bind(ko, tvm.thread_axis("blockIdx.y")) - s[M].bind(po, tvm.thread_axis("blockIdx.x")) - - tyz, kii = s[M].split(ki, nparts=vthread) # virtual thread split - txz, pii = s[M].split(pi, nparts=vthread) # virtual thread split - ty, kii = s[M].split(kii, nparts=num_thread) - tx, pii = s[M].split(pii, nparts=num_thread) - s[M].reorder(z, ko, po, tyz, txz, ty, tx, kii, pii) - - s[M].bind(tyz, thread_yz) - s[M].bind(txz, thread_xz) - s[M].bind(ty, thread_y) - s[M].bind(tx, thread_x) - - s[ML].compute_at(s[M], tx) - eps, nu, k, p = s[ML].op.axis - c = s[ML].op.reduce_axis[0] - co, ci = s[ML].split(c, factor=step) - s[ML].reorder(co, ci, k, p) - - s[UU].compute_at(s[ML], co) - s[VV].compute_at(s[ML], co) - s[UL].compute_at(s[ML], ci) - s[VL].compute_at(s[ML], ci) - - schedule_smem_load(s, UU, num_thread) - schedule_smem_load(s, VV, num_thread) - -def schedule_winograd(outs): - s = tvm.create_schedule([x.op for x in outs]) - op = outs[0].op - output = op.output(0) - output_temp = s[output].op.input_tensors[0] - M = s[output_temp].op.input_tensors[0] - U, V = s[M].op.input_tensors - V_temp = s[V].op.input_tensors[0] - d = s[V_temp].op.input_tensors[0] - data_pad = s[d].op.input_tensors[0] - - s[data_pad].compute_inline() - - # transform image - eps, nu, c, p = s[V].op.axis - s[V].reorder(c, p, eps, nu) - - co, ci = s[V].split(c, factor=16) - po, pi = s[V].split(p, factor=16) - s[V].bind(ci, tvm.thread_axis("threadIdx.y")) - s[V].bind(pi, tvm.thread_axis("threadIdx.x")) - s[V].bind(co, tvm.thread_axis("blockIdx.y")) - s[V].bind(po, tvm.thread_axis("blockIdx.x")) - s[V_temp].compute_at(s[V], pi) - s[d].compute_at(s[V], pi) - - schedule_batched_sgemm(s, U, V, M) - - # inverse transform - n, k, h, w = s[output].op.axis - ho, hi = s[output].split(h, factor=2) - wo, wi = s[output].split(w, factor=2) - s[output].reorder(k, n, ho, wo, hi, wi) - k = s[output].fuse(k, n) - - hoo, hoi = s[output].split(ho, factor=16) - woo, woi = s[output].split(wo, factor=16) - s[output].reorder(hoo, woo, hoi, woi, hi, wi) - s[output].bind(hoi, tvm.thread_axis("threadIdx.y")) - s[output].bind(woi, tvm.thread_axis("threadIdx.x")) - s[output].bind(hoo, tvm.thread_axis("blockIdx.y")) - s[output].bind(woo, tvm.thread_axis("blockIdx.x")) - s[output].bind(k, tvm.thread_axis("blockIdx.z")) - s[output_temp].compute_at(s[output], woi) - - return s - -def transform_filter(w_np): - num_filter, in_channel, kernel, kernel = w_np.shape - G = np.array([ - [1, 0, 0], - [1.0/2, 1.0/2, 1.0/2], - [1.0/2, -1.0/2, 1.0/2], 
- [0, 0, 1], - ], w_np.dtype) - - out = np.empty((4, 4, in_channel, num_filter), w_np.dtype) - for i in range(in_channel): - for j in range(num_filter): - out[:, :, i, j] = np.dot(G, np.dot(w_np[j, i], G.transpose())) - return out - - -def test_winograd(batch, in_channel, in_size, num_filter, kernel, stride, padding, device): - in_height = in_width = in_size - - A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - U = tvm.placeholder((4, 4, in_channel, num_filter), name='W') - - a_shape = util.get_const_tuple(A.shape) - w_shape = util.get_const_tuple(W.shape) - dtype = A.dtype - dilation = 1 - - @memoize("topi.tests.test_topi_conv2d_nchw.wino") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - with tvm.target.create(device): - B = decl_winograd(A, U, stride, padding, dtype) - s = schedule_winograd([B]) - - u_np = transform_filter(w_np) - - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - u = tvm.nd.array(u_np, ctx) - b = tvm.nd.array(np.zeros(util.get_const_tuple(B.shape), dtype=B.dtype), ctx) - with tvm.build_config(auto_unroll_max_step=1400, - unroll_explicit=(device != "cuda"), - partition_const_loop=False): - func = tvm.build(s, [A, U, B], device) - #print(tvm.lower(s, [A, U, B], simple_mode=True)) - func(a, u, b) - num_runs = 100 - timer = func.time_evaluator(func.entry_name, ctx, number=num_runs) - np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - #print(func.imported_modules[0].get_source()) - return timer(a, u, b).mean - -# for copy paste as markdown -def generate_table(workloads, wino_times, direct_times, wino_nvptx_times, direct_nvptx_times, lib_times, lib_name): - print("| (batch,CI,size,CO) | TVM Winograd (This code) | TVM Direct | TVM Winograd NVPTX (This code) | TVM Direct NVPTX | %s |" % lib_name) - print("|------------- |:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|") - for (workload, t_wino, t_direct, t_wino_nvptx, t_direct_nvptx, t_lib) in zip(workloads, wino_times, direct_times, wino_nvptx_times, direct_nvptx_times, lib_times): - if t_direct and t_direct_nvptx: - print("|", workload, "| %.3f | %.3f | %.3f | %.3f | %.3f" % (t_wino, t_direct, t_wino_nvptx, t_direct_nvptx, t_lib)) - elif t_direct: - print("|", workload, "| %.3f | %.3f | %.3f | N/A | %.3f" % (t_wino, t_direct, t_wino_nvptx, t_lib)) - elif t_direct_nvptx: - print("|", workload, "| %.3f | N/A | %.3f | %.3f | %.3f" % (t_wino, t_wino_nvptx, t_direct_nvptx, t_lib)) - else: - print("|", workload, "| %.3f | N/A | %.3f | N/A | %.3f" % (t_wino, t_wino_nvptx, t_lib)) - - -workloads = [(1, 128, 122, 128), - (1, 128, 128, 128), - (1, 64, 56, 64), - (1, 64, 64, 32), - (1, 64, 224, 64), - (1, 64, 112, 128), - (1, 512, 28, 512), - (1, 128, 28, 128), - (1, 256, 14, 256), - (8, 128, 122, 128), - (16, 64, 56, 64), - (32, 64, 64, 32), - (64, 128, 32, 128) - ] - -vgg_workloads = [(1, 64, 224, 64), #relu, input and output transform slow - (1, 64, 112, 128),#relu2 - (1, 128, 112, 128), - (1, 128, 56, 256), - (1, 256, 56, 256), #relu4 - (1, 256, 28, 512), - (1, 512, 28, 512), # relu6 - (1, 512, 14, 512) # relu7 - ] - -wino_times = [] -direct_times = [] 
-wino_nvptx_times = [] -direct_nvptx_times = [] -lib_times = [] -device = "cuda" - -for workload in workloads: - t_wino = test_winograd(*workload, 3, 1, 1, device) - t_wino_nvptx = test_winograd(*workload, 3, 1, 1, "nvptx") - - if workload[1] == 512 or workload[0] > 1: - t_direct = None # tvm direct conv2d cannot handle this workload - t_direct_nvptx = None - else: - t_direct = reference_direct(*workload, 3, 1, 1, device) - if workload[2] == 122: - t_direct_nvptx = None - else: - t_direct_nvptx = reference_direct(*workload, 3, 1, 1, "nvptx") - - #t_lib = reference_direct(*workload, 3, 1, 1, "cuda -libs=cudnn") - t_lib = 0 - wino_times.append(t_wino * 1000) - wino_nvptx_times.append(t_wino_nvptx * 1000) - lib_times.append(t_lib * 1000) - - if t_direct: - t_direct *= 1000 - if t_direct_nvptx: - t_direct_nvptx *= 1000 - - direct_times.append(t_direct) - direct_nvptx_times.append(t_direct_nvptx) - -generate_table(workloads, wino_times, direct_times, wino_nvptx_times, direct_nvptx_times, lib_times, "cuDNN Winograd")
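
For reference, decl_winograd implements Winograd fast convolution F(2x2, 3x3): each 4x4 input tile d and 3x3 filter g yield a 2x2 output tile Y = A^T[(G g G^T) * (B^T d B)]A, using the same B_data, G and A_data matrices that appear in the scripts above. A standalone NumPy sketch (not part of the patch) that checks one tile against direct 3x3 correlation:

import numpy as np

# Transform matrices for F(2x2, 3x3), identical to B_data, G and A_data above.
B = np.array([[1, 0, 0, 0],
              [0, 1, -1, 1],
              [-1, 1, 1, 0],
              [0, 0, 0, -1]], dtype=np.float32)
G = np.array([[1, 0, 0],
              [0.5, 0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0, 0, 1]], dtype=np.float32)
A = np.array([[1, 0],
              [1, 1],
              [1, -1],
              [0, -1]], dtype=np.float32)

d = np.random.uniform(size=(4, 4)).astype(np.float32)  # one input tile
g = np.random.uniform(size=(3, 3)).astype(np.float32)  # one 3x3 filter

U = G.dot(g).dot(G.T)        # filter transform (what transform_filter produces per channel pair)
V = B.T.dot(d).dot(B)        # input transform (what decl_V_minimal computes per tile)
Y = A.T.dot(U * V).dot(A)    # elementwise product + inverse transform (decl_output_minimal)

# Direct 3x3 correlation over the same tile, stride 1, no padding -> 2x2 output.
ref = np.array([[np.sum(d[i:i+3, j:j+3] * g) for j in range(2)] for i in range(2)],
               dtype=np.float32)
print(np.max(np.abs(Y - ref)))  # ~1e-6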
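
The decl_V_minimal / decl_output_minimal functions added by this patch replace the reduce_axis formulation (tvm.sum against const_array B / A) with fully unrolled add/subtract expressions, one per output element, chosen through nested tvm.select so the schedule can unroll them away. The unrolled expressions are simply the rows of B^T.x and A^T.x written out; a quick NumPy check of that equivalence (standalone sketch, not part of the patch):

import numpy as np

x = np.random.uniform(size=4).astype(np.float32)

# temp stage of decl_V_minimal: the rows of B^T.x written out as adds/subs.
bt_x = np.array([x[0] - x[2], x[1] + x[2], x[2] - x[1], x[1] - x[3]], np.float32)
B = np.array([[1, 0, 0, 0],
              [0, 1, -1, 1],
              [-1, 1, 1, 0],
              [0, 0, 0, -1]], np.float32)
print(np.max(np.abs(bt_x - B.T.dot(x))))  # 0

# temp stage of decl_output_minimal: the rows of A^T.x written out as adds/subs.
t0, t1 = x[0] + x[1], x[1] - x[2]
at_x = np.array([t0 + x[2], t1 - x[3]], np.float32)
A = np.array([[1, 0],
              [1, 1],
              [1, -1],
              [0, -1]], np.float32)
print(np.max(np.abs(at_x - A.T.dot(x))))  # 0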
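
Between the two transforms, M[eps][nu][k][b] = sum_c U[eps][nu][c][k] * V[eps][nu][c][b] is a batch of alpha*alpha independent matrix products, one (C x K)^T times (C x P) GEMM per (eps, nu) pair; schedule_batched_sgemm tiles exactly this, with the fused eps/nu pair bound to blockIdx.z. The same computation stated in NumPy (illustrative sizes only, not part of the patch):

import numpy as np

alpha, C, K, P = 4, 64, 64, 196      # illustrative sizes; P = N * nH * nW
U = np.random.uniform(size=(alpha, alpha, C, K)).astype(np.float32)
V = np.random.uniform(size=(alpha, alpha, C, P)).astype(np.float32)

# M[eps, nu] = U[eps, nu]^T @ V[eps, nu] for every (eps, nu) pair.
M = np.empty((alpha, alpha, K, P), dtype=np.float32)
for eps in range(alpha):
    for nu in range(alpha):
        M[eps, nu] = U[eps, nu].T.dot(V[eps, nu])

# The same reduction over the channel axis c, written as one einsum.
M2 = np.einsum("xyck,xycb->xykb", U, V)
print(np.allclose(M, M2))  # True (up to float32 rounding)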
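
Finally, the packing index arithmetic: input_tile flattens (batch, tile row, tile col) into a single index b with P = N * nH * nW, and compute_output inverts it with b = n*nH*nW + (h//m)*nW + w//m, eps = h % m, nu = w % m. The alpha x alpha input tiles overlap by r-1 rows/columns, so the round trip below only exercises the m x m core that the output stage reads back; it checks that the index algebra used in both places is consistent. A standalone sketch with hypothetical sizes:

# Index bookkeeping used by input_tile / compute_output (hypothetical sizes).
N, H, W, m = 2, 56, 56, 2
nH, nW = (H + m - 1) // m, (W + m - 1) // m
P = N * nH * nW

for n in range(N):
    for h in range(H):
        for w in range(W):
            # forward mapping, as in compute_output
            b = n * nH * nW + (h // m) * nW + w // m
            eps, nu = h % m, w % m
            # inverse mapping, same arithmetic as the input_tile expression
            n2 = b // (nH * nW)
            h2 = b // nW % nH * m + eps
            w2 = b % nW * m + nu
            assert (n2, h2, w2) == (n, h, w)
            assert b < P
print("index mapping round-trips for all", N * H * W, "positions")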