add code implementation of cross entropy #491

Open · wants to merge 5 commits into base: master
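A minimal usage sketch of the fused loss added by this PR (assumes a CUDA build of Jittor and the module path jittor.other.code_cross_entropy used in this diff; shapes and dtypes below are illustrative):

    import jittor as jt
    import jittor.other.code_cross_entropy as code_cross_entropy

    with jt.flag_scope(use_cuda=1):
        x = jt.rand((200, 2000), dtype="float32")           # logits, one row per sample
        target = jt.randint(0, x.shape[1], (x.shape[0],))   # integer class indices
        loss = code_cross_entropy.cross_entropy(x, target)  # per-sample losses, shape (200,)
        grad = jt.grad(loss.sum(), x)                        # runs the custom backward kernel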
8 changes: 6 additions & 2 deletions python/jittor/nn.py
@@ -370,6 +370,8 @@ def execute(self, x):
else:
return jt.maximum(0, x) + self.weight * jt.minimum(0, x)

import jittor.other.code_cross_entropy as code_cross_entropy

# TODO: 4-dim input causes slow execution
def cross_entropy_loss(output, target, weight=None, ignore_index=None,reduction='mean'):
target_shape = target.shape
@@ -379,7 +381,7 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None,reduction=
output = output.reshape((-1, c_dim))

target = target.reshape((-1, ))
target_weight = ((target >= 0) & (target < output.shape[1])).float32()
target_weight = ((target >= 0) & (target < output.shape[1])).astype(output.dtype)
if weight is not None:
target_weight = weight[target]
if ignore_index is not None:
@@ -394,7 +396,9 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None,reduction=

output = output - output.max([1], keepdims=True)
logsum = output.exp().sum(1).log()
loss = (logsum - (output*target).sum(1)) * target_weight
cross_entropy = (logsum - (output*target).sum(1))

loss = cross_entropy * target_weight
if reduction == 'sum':
return loss.sum()
elif reduction == 'mean':
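The switch from .float32() to .astype(output.dtype) keeps the validity mask in the same precision as the logits, which matters once fp16/bfloat16 inputs reach this path. A small illustration (not part of the diff; the shapes and dtype are made up):

    import jittor as jt

    x = jt.rand((4, 10), dtype="float16")                   # half-precision logits
    t = jt.randint(0, 10, (4,))
    mask = ((t >= 0) & (t < x.shape[1])).astype(x.dtype)    # mask now matches x's dtype
    # with the old .float32() mask, the product cross_entropy * mask would be promoted to float32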
96 changes: 96 additions & 0 deletions python/jittor/other/code_cross_entropy.py
@@ -0,0 +1,96 @@
import jittor as jt
from jittor import nn
import numpy as np


def cross_entropy(output, target):
tnum = min(512, output.shape[-1])

class CodeCrossEntropy(jt.Function):
def execute(self, x, target):
self.save_vars = [x, target]
cross_entropy = jt.code(target.shape, x.dtype, [x, target], cuda_header=f'''
#include <{jt.compile_extern.cub_home}cub/cub.cuh>
#include <type/fp16_compute.h>
#include <helper_cuda.h>
''', cuda_src=f'''
__global__ void kernel(in0_type* x, in1_type* target, out0_type* y, size_t len) {{
typedef cub::BlockReduce<float, {tnum}> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

size_t id = blockIdx.x * len;

float v1 = -1e30;
for (size_t i = threadIdx.x; i < len; i += blockDim.x)
v1 = ::max(v1, float(x[id + i]));

__shared__ float vmax;
auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max());
if (threadIdx.x == 0)
vmax = tmp;
__syncthreads();

v1 = 0;
for (size_t i = threadIdx.x; i < len; i += blockDim.x)
v1 += expf(float(float(x[id + i]) - vmax));

auto vsum = BlockReduce(temp_storage).Sum(v1);
if (threadIdx.x == 0)
y[blockIdx.x] = -float(x[id+target[blockIdx.x]]) + vmax + float(@expand_op(log,@in0_type,vsum));
}}
size_t len = in0->shape[in0->shape.size()-1];
size_t bnum = in0->numel() / len;
cudaGetLastError();
kernel<<<bnum, {tnum}>>>(in0_p, in1_p, out0_p, len);
getLastCudaError("Failed to run CodeCrossEntropy forward");
''')
return cross_entropy

def grad(self, grad):
x, target = self.save_vars
return jt.code(x.shape, x.dtype, [x, target, grad], cuda_header=f'''
#include <{jt.compile_extern.cub_home}cub/cub.cuh>
#include <type/fp16_compute.h>
#include <helper_cuda.h>
''', cuda_src=f'''
__global__ void kernel(in0_type* x, in1_type* target, in2_type* grad, out0_type* y, size_t len) {{
typedef cub::BlockReduce<float, {tnum}> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

size_t id = blockIdx.x * len;

float v1 = -1e30;
for (size_t i = threadIdx.x; i < len; i += blockDim.x)
v1 = ::max(v1, float(x[id + i]));
__shared__ float vmax;
auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max());
if (threadIdx.x == 0)
vmax = tmp;
__syncthreads();

v1 = 0;
for (size_t i = threadIdx.x; i < len; i += blockDim.x) {{
y[id + i] = expf(float(x[id + i]) - vmax);
v1 += float(y[id + i]);
}}

tmp = BlockReduce(temp_storage).Sum(v1);
__shared__ float vsum;
if (threadIdx.x == 0)
vsum = tmp;
__syncthreads();

for (size_t i = threadIdx.x; i < len; i += blockDim.x)
y[id + i] = float(y[id + i]) / vsum * float(grad[blockIdx.x]);
__syncthreads();

if (threadIdx.x == 0)
y[id + target[blockIdx.x]] -= grad[blockIdx.x];
}}
size_t len = in0->shape[in0->shape.size()-1];
size_t bnum = in0->numel() / len;
cudaGetLastError();
kernel<<<bnum, {tnum}>>>(in0_p, in1_p, in2_p, out0_p, len);
getLastCudaError("Failed to run CodeCrossEntropy backward");
''')
return CodeCrossEntropy()(output, target)
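For reference, the forward kernel above computes the numerically stable log-sum-exp form of cross entropy per row, and the backward kernel writes softmax(x) scaled by the incoming gradient and then subtracts that gradient at the target index. A NumPy sketch of the forward arithmetic (names here are illustrative, not part of the PR):

    import numpy as np

    def reference_cross_entropy(x, target):
        # x: (N, C) logits, target: (N,) integer class indices
        vmax = x.max(axis=1, keepdims=True)            # row max, subtracted for stability
        logsum = np.log(np.exp(x - vmax).sum(axis=1))  # log-sum-exp of the shifted logits
        picked = x[np.arange(x.shape[0]), target]      # logit of the target class
        return -picked + vmax[:, 0] + logsum           # matches y[blockIdx.x] in the kernel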
21 changes: 12 additions & 9 deletions python/jittor/other/code_softmax.py
@@ -1,24 +1,25 @@
import jittor as jt
from jittor import nn
import numpy as np

def can_softmax_v1(a, dim):
if not jt.flags.use_cuda:
return False
if dim != -1 and dim != len(a.shape)-1:
return False
if a.shape[len(a.shape)-1] > 10000:
if a.shape[-1] > 10000 and np.prod(a.shape[:-1]) < 64:
return False
return True

def softmax_v1(a, log=False):
assert can_softmax_v1(a, -1)
length = a.shape[-1]
# tnum = 1024
tnum = 500 if length % 500 == 0 else 512
tnum = 125 if length % 125 == 0 else 128
# tnum = 125
# tnum = 1000 if length % 1000 == 0 else 1024
# tnum = 250

if length < 65536:
tnum = 250 if length % 250 == 0 else 256
else:
tnum = 125 if length % 125 == 0 else 128

per_thread = (length-1) // tnum + 1
ILP = 1
for ilp in [8,4,2]:
@@ -38,6 +39,7 @@ def execute(self, x):
self.save_vars = jt.code(x.shape, x.dtype, [x], cuda_header=f'''
#include <{jt.compile_extern.cub_home}cub/cub.cuh>
#include <type/fp16_compute.h>
#include <helper_cuda.h>
''', cuda_src=f'''
__global__ void kernel(in0_type* x, out0_type* y, int len) {{
typedef cub::BlockReduce<float, {tnum}> BlockReduce;
@@ -95,7 +97,7 @@ def execute(self, x):
int bnum = in0->numel() / len;
cudaGetLastError();
kernel<<<bnum, {tnum}>>>(in0_p, out0_p, len);
CHECK(0 == cudaGetLastError());
getLastCudaError("Failed to run CodeSoftmax forward");
''')
return self.save_vars

@@ -104,6 +106,7 @@ def grad(self, grad_x):
return jt.code(x.shape, x.dtype, [x, grad_x], cuda_header=f'''
#include <{jt.compile_extern.cub_home}cub/cub.cuh>
#include <type/fp16_compute.h>
#include <helper_cuda.h>
''',
cuda_src=f"""
__global__ void kernel(in0_type* x, in1_type* y, out0_type* z, int len) {{
@@ -144,6 +147,6 @@ def grad(self, grad_x):
int bnum = in0->numel() / len;
cudaGetLastError();
kernel<<<bnum, {tnum}>>>(in0_p, in1_p, out0_p, len);
CHECK(0 == cudaGetLastError());
getLastCudaError("Failed to run CodeSoftmax backward");
""")
return CodeSoftmax()(a)
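For context, the launch-parameter change in softmax_v1 now prefers a block size that evenly divides the row length and only falls back to 125/128 threads for very long rows. A sketch of just that selection (a paraphrase of the hunk above, no new behaviour):

    def pick_tnum(length):
        # prefer a divisor of the row length so per-thread work stays uniform
        if length < 65536:
            return 250 if length % 250 == 0 else 256
        return 125 if length % 125 == 0 else 128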
53 changes: 52 additions & 1 deletion python/jittor/test/test_misc_op.py
@@ -291,7 +291,7 @@ def softmax(x, dim = None, log=False):
x = (x - x.max()).exp()
ret = x / x.sum()
else:
x = (x-x.max(dim, keepdims=True)).exp()
x = (x-x.max(dim, keepdims=True)[0]).exp()
ret = x / x.sum(dim, keepdims=True)
if log: return ret.log()
return ret
@@ -323,6 +323,57 @@ def softmax(x, dim = None, log=False):
else:
assert err.item() < 1e-5, (err.item())

def test_code_cross_entropy_loss(self):
if not jt.has_cuda: return

def naive_cross_entropy_loss(output, target, weight=None, ignore_index=None, reduction='mean'):
target_shape = target.shape
if len(output.shape) == 4:
c_dim = output.shape[1]
output = output.transpose((0, 2, 3, 1))
output = output.reshape((-1, c_dim))

target = target.reshape((-1, ))
target_weight = ((target >= 0) & (target < output.shape[1])).float32()
if weight is not None:
target_weight = weight[target]
if ignore_index is not None:
target_weight = jt.ternary(
target==ignore_index,
jt.array(0).broadcast(target_weight),
target_weight
)

import jittor.other.code_cross_entropy as code_cross_entropy
cross_entropy = code_cross_entropy.cross_entropy(output, target)

loss = cross_entropy * target_weight
if reduction == 'sum':
return loss.sum()
elif reduction == 'mean':
return loss.mean() / target_weight.mean()
else:
return loss.reshape(target_shape)


jt.set_global_seed(42)

with jt.flag_scope(use_cuda = 1):
for dtype in ["float16", "bfloat16", "float32"]:
for shape in [(3, 3), (200, 2000), (200, 2049), (16380, 65000)]:
print(shape)
x = jt.rand(shape, dtype=dtype)
target = jt.randint(0, x.shape[1], (x.shape[0],))
b = naive_cross_entropy_loss(x, target)
d1 = jt.grad(b, x)
bb = jt.nn.cross_entropy_loss(x, target)
d2 = jt.grad(bb, x)
jt.sync_all(True)

np.testing.assert_allclose(bb.astype(jt.float32).data, b.astype(jt.float32).data, rtol=1e-3, atol=1e-2)
                    np.testing.assert_allclose(d2.astype(jt.float32).data, d1.astype(jt.float32).data, rtol=1e-3, atol=1e-2)  # compare the gradients as well, not just the loss


def test_nan(self):
a = np.array([1.0,0.0,1.0,-1.0], "float32") / np.array([1.0,0.0,0.0,0.0], "float32")
np.testing.assert_allclose(jt.isnan(jt.array(a)).data, [0,1,0,0])
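One way to run only the new test locally (assuming a CUDA build of Jittor and that pytest is installed; the exact command is not part of this PR):

    python -m pytest python/jittor/test/test_misc_op.py -k test_code_cross_entropy_loss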