From 2db7b94b7fabae08e21dd467862ba5cac5e890ff Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Mon, 8 Feb 2021 18:50:05 +0800 Subject: [PATCH 01/36] fix bicubic,add fold. --- python/jittor/linalg.py | 55 +- python/jittor/nn.py | 1001 +++++++++++++++++------------ python/jittor/test/test_linalg.py | 1 - 3 files changed, 652 insertions(+), 405 deletions(-) diff --git a/python/jittor/linalg.py b/python/jittor/linalg.py index 99789806..7c869218 100644 --- a/python/jittor/linalg.py +++ b/python/jittor/linalg.py @@ -11,6 +11,7 @@ import jittor as jt from functools import partial + #TODO:full_matrices=1 def svd(x): @@ -81,6 +82,7 @@ def T(x): ) return u, s, v + def eigh(x): def forward_code(np, data): @@ -122,6 +124,7 @@ def T(x): ) return w, v + def inv(x): def forward_code(np, data): @@ -151,6 +154,7 @@ def T(x): mx = lmx[0] return mx + def pinv(x): def forward_code(np, data): @@ -185,6 +189,7 @@ def T(x): mx = lmx[0] return mx + def det(x): def forward_code(np, data): @@ -220,6 +225,7 @@ def T(x): det = l_det[0] return det + def slogdet(x): def forward_code(np, data): a = data["inputs"][0] @@ -256,6 +262,7 @@ def T(x): ) return sign, mx + def cholesky(x): def forward_code(np, data): @@ -291,6 +298,7 @@ def conjugate_solve(L, X): L = lL[0] return L + def solve(a,b): def forward_code(np, data): @@ -323,4 +331,49 @@ def backward_code2(np, data): [backward_code1, backward_code2], ) ans = l_ans[0] - return ans \ No newline at end of file + return ans + + +def qr(x): + def forward_code(np, data): + a = data["inputs"][0] + q, r = data["outputs"] + Q, R = np.linalg.qr(a) + np.copyto(q,Q) + np.copyto(r,R) + + def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') + _harmard = partial(np.einsum, '...ij,...ij->...ij') + dout = data["dout"] + out = data["outputs"][0] + q, r = data["f_outputs"] + out_index = data["out_index"] + #pl = np.tril(np.ones((inp.shape[-1],inp.shape[-1])))-diags + if out_index == 0: # Q_TERM + q_t = _dot(T(q),dout) + rhs_solve = q_t - T(q_t) + rhs_solve = T(np.tril(rhs_solve,-1)) + qsolve = np.linalg.solve(r,rhs_solve) + qsolve = T(qsolve) + tq = _dot(q,qsolve) + np.copyto(out,tq) + else: #R_TERM + r_t = _dot(r ,T(dout)) + rhs_solve = r_t - T(r_t) + rhs_solve = np.tril(rhs_solve,-1) + rhs_solve = T(rhs_solve) + r_solve = np.linalg.solve(r,rhs_solve) + tr = _dot(q,(T(r_solve) + dout)) + np.copyto(out,tr) + + q, r = jt.numpy_code( + [x.shape,x.shape], + [x.dtype,x.dtype], + [x], + forward_code, + [backward_code], + ) + return q, r diff --git a/python/jittor/nn.py b/python/jittor/nn.py index a5a2b71b..20576dff 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -28,15 +28,15 @@ def matmul_transpose(a, b): ''' assert len(a.shape) >= 2 and len(b.shape) == 2 assert a.shape[-1] == b.shape[-1] - if len(a.shape)>2: + if len(a.shape) > 2: aa = a.reshape((-1, a.shape[-1])) cc = matmul_transpose(aa, b) - return cc.reshape(a.shape[:-1]+(-1,)) + return cc.reshape(a.shape[:-1] + (-1,)) shape = list(a.shape)[:-1] + list(b.shape) - a = a.broadcast(shape, [len(shape)-2]) + a = a.broadcast(shape, [len(shape) - 2]) b = b.broadcast(shape) - return (a*b).sum(len(shape)-1) + return (a * b).sum(len(shape) - 1) def bmm_transpose(a, b): @@ -70,6 +70,7 @@ def bmm(a, b): assert len(a.shape) > 2 and len(b.shape) > 2 return matmul(a, b) + def matmul(a, b): ''' matrix multiply, @@ -109,11 +110,11 @@ def matmul(a, b): len_b = len(b.shape) if len_b == 1: # a: [n, m], b:[m], c:[n] - return (a*b).sum(-1) + return (a * 
b).sum(-1) if len_a == 1: # a: [n], b:[n,k], c:[k] return (a.broadcast(b, [-1]) * b).sum(0) - if len_a>=3 and len_a==len_b: + if len_a >= 3 and len_a == len_b: # bmm # a: [..., n, m], b: [..., m, k], c:[..., n, k] if jt.flags.use_cuda: @@ -127,52 +128,67 @@ def matmul(a, b): # cc:[..., n, m, k] # --> # 012 - if len_b == 2 and len_a>2: + if len_b == 2 and len_a > 2: # TODO:ugly implementation for tuner aa = a.reshape((-1, m)) cc = matmul(aa, b) # print(a.shape, b.shape, cc.shape) return cc.reshape(a.shape[:-1] + [k]) - for i in range(len_c-2): - ai = len_a-(len_c-i) - bi = len_b-(len_c-i) - an = a.shape[ai] if ai>=0 else 1 - bn = b.shape[bi] if bi>=0 else 1 - if an!=1 and bn!=1: + for i in range(len_c - 2): + ai = len_a - (len_c - i) + bi = len_b - (len_c - i) + an = a.shape[ai] if ai >= 0 else 1 + bn = b.shape[bi] if bi >= 0 else 1 + if an != 1 and bn != 1: assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{a.shape}" cn = max(an, bn) shape.append(cn) shape.extend([n, m, k]) a = a.broadcast(shape, [-1]) b = b.broadcast(shape, [-3]) - return (a*b).sum(-2) + return (a * b).sum(-2) + + jt.Var.matmul = jt.Var.__matmul__ = matmul -jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b)) +jt.Var.__imatmul__ = lambda a, b: a.assign(matmul(a, b)) + def get_init_var_rand(shape, dtype): return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32)) -def relu(x): return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x)) -def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale) + +def relu(x): return jt.ternary((x > 0.0), x, jt.broadcast_var(0.0, x)) + + +def leaky_relu(x, scale=0.01): return jt.ternary(x > 0, x, x * scale) + + def relu6(x): return jt.minimum(jt.maximum(x, 0.0), 6.0) -def elu(x,alpha=1.0):return jt.ternary(x>0,x,alpha*(x.exp()-1)) + + +def elu(x, alpha=1.0): return jt.ternary(x > 0, x, alpha * (x.exp() - 1)) + + def sign(x): one = jt.ones(x.shape) - x = jt.ternary(x>0, one, x) - return jt.ternary(x<0, -one, x) + x = jt.ternary(x > 0, one, x) + return jt.ternary(x < 0, -one, x) + def gelu(x): _sqrt2 = 1.4142135623730951 - erf = jt.erf(x/_sqrt2)+1 - r = erf*x*.5 + erf = jt.erf(x / _sqrt2) + 1 + r = erf * x * .5 return r + class ELU(Module): - def __init__(self,alpha=1.0): - self.alpha=alpha - - def execute(self,x): - return elu(x,self.alpha) + def __init__(self, alpha=1.0): + self.alpha = alpha + + def execute(self, x): + return elu(x, self.alpha) + class PReLU(Module): def __init__(self, num_parameters=1, init_=0.25): @@ -182,51 +198,55 @@ def __init__(self, num_parameters=1, init_=0.25): def execute(self, x): if self.num_parameters != 1: assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU" - return jt.maximum(0, x) + self.a.broadcast(x, [0,2,3]) * jt.minimum(0, x) + return jt.maximum(0, x) + self.a.broadcast(x, [0, 2, 3]) * jt.minimum(0, x) else: return jt.maximum(0, x) + self.a * jt.minimum(0, x) -#TODO dims is 4 will cause slowly execution + +# TODO dims is 4 will cause slowly execution def cross_entropy_loss(output, target, ignore_index=None): if len(output.shape) == 4: c_dim = output.shape[1] output = output.transpose((0, 2, 3, 1)) output = output.reshape((-1, c_dim)) if ignore_index is not None: - target = jt.ternary(target==ignore_index, - jt.array(-1).broadcast(target), target) + target = jt.ternary(target == ignore_index, + jt.array(-1).broadcast(target), target) mask = jt.logical_and(target >= 0, target < output.shape[1]) - target = target.reshape((-1, )) + target = target.reshape((-1,)) target = 
target.broadcast(output, [1]) target = target.index(1) == target - + output = output - output.max([1], keepdims=True) loss = output.exp().sum(1).log() - loss = loss - (output*target).sum(1) + loss = loss - (output * target).sum(1) if ignore_index is None: return loss.mean() else: return loss.sum() / jt.maximum(mask.int().sum(), 1) + def mse_loss(output, target): - return (output-target).sqr().mean() + return (output - target).sqr().mean() + def bce_loss(output, target, weight=None, size_average=True): loss = - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))) if weight is not None: loss *= weight - + if size_average: return loss.mean() else: return loss.sum() + def l1_loss(output, target): - return (output-target).abs().mean() + return (output - target).abs().mean() -def smooth_l1_loss(y_true, y_pred,reduction="mean"): +def smooth_l1_loss(y_true, y_pred, reduction="mean"): """Implements Smooth-L1 loss. y_true and y_pred are typically: [N, 4], but could be any shape. @@ -236,82 +256,92 @@ def smooth_l1_loss(y_true, y_pred,reduction="mean"): reduction - the mode of cal loss which must be in ['mean','sum','none'] """ diff = jt.abs(y_true - y_pred) - less_than_one = (diff<1.0).float32() + less_than_one = (diff < 1.0).float32() loss = (less_than_one * 0.5 * diff.sqr()) + (1 - less_than_one) * (diff - 0.5) - if reduction=="mean": + if reduction == "mean": return loss.mean() - elif reduction=="sum": + elif reduction == "sum": return loss.sum() - elif reduction=="none": + elif reduction == "none": return loss else: raise ValueError(f'not support {reduction}') -def nll_loss(output,target,weight=None,ignore_index=-100,reduction='mean'): - assert output.ndim<=2 and output.ndim>0 and target.ndim==1 + +def nll_loss(output, target, weight=None, ignore_index=-100, reduction='mean'): + assert output.ndim <= 2 and output.ndim > 0 and target.ndim == 1 n_classes = output.shape[-1] - assert weight is None or weight.numel()==n_classes - assert ignore_index<0 or ignore_index0: - weight[ignore_index]=0 - if output.ndim==2: - index = jt.index((output.shape[0],),dim=0) - loss = -output[index,target]*weight[target] + if ignore_index > 0: + weight[ignore_index] = 0 + if output.ndim == 2: + index = jt.index((output.shape[0],), dim=0) + loss = -output[index, target] * weight[target] else: - loss = -output[target[0]]*weight[target[0]] - if reduction=="mean": - total_weight = weight[target].sum() if output.ndim==2 else weight[target[0]].sum() - return loss.sum()/total_weight - elif reduction=="sum": + loss = -output[target[0]] * weight[target[0]] + if reduction == "mean": + total_weight = weight[target].sum() if output.ndim == 2 else weight[target[0]].sum() + return loss.sum() / total_weight + elif reduction == "sum": return loss.sum() - elif reduction=="none": + elif reduction == "none": return loss else: raise ValueError(f'not support {reduction}') - + + class CrossEntropyLoss(Module): - def __init__(self,ignore_index=None): + def __init__(self, ignore_index=None): self.ignore_index = ignore_index - + def execute(self, output, target): - return cross_entropy_loss(output, target,self.ignore_index) + return cross_entropy_loss(output, target, self.ignore_index) + class MSELoss(Module): def __init__(self): pass + def execute(self, output, target): return mse_loss(output, target) + class BCELoss(Module): def __init__(self, weight=None, size_average=True): self.weight = weight self.size_average = size_average + def execute(self, output, target): return bce_loss(output, 
target, self.weight, self.size_average) + class L1Loss(Module): def __init__(self): pass + def execute(self, output, target): return l1_loss(output, target) + def binary_cross_entropy_with_logits(output, target, weight=None, pos_weight=None, size_average=True): - max_val = jt.clamp(-output,min_v=0) + max_val = jt.clamp(-output, min_v=0) if pos_weight is not None: - log_weight = (pos_weight-1)*target + 1 - loss = (1-target)*output+(log_weight*(((-max_val).exp()+(-output - max_val).exp()).log()+max_val)) + log_weight = (pos_weight - 1) * target + 1 + loss = (1 - target) * output + (log_weight * (((-max_val).exp() + (-output - max_val).exp()).log() + max_val)) else: - loss = (1-target)*output+max_val+((-max_val).exp()+(-output -max_val).exp()).log() + loss = (1 - target) * output + max_val + ((-max_val).exp() + (-output - max_val).exp()).log() if weight is not None: - loss *=weight + loss *= weight if size_average: return loss.mean() else: return loss.sum() + class BCEWithLogitsLoss(Module): def __init__(self, weight=None, pos_weight=None, size_average=True): self.pos_weight = pos_weight @@ -319,21 +349,24 @@ def __init__(self, weight=None, pos_weight=None, size_average=True): self.size_average = size_average def execute(self, output, target): - return binary_cross_entropy_with_logits(output,target,self.weight,self.pos_weight,self.size_average) + return binary_cross_entropy_with_logits(output, target, self.weight, self.pos_weight, self.size_average) -def softmax(x, dim = None): + +def softmax(x, dim=None): if dim is None: x = (x - x.max()).exp() ret = x / x.sum() else: - x = (x-x.max(dim, keepdims=True)).exp() + x = (x - x.max(dim, keepdims=True)).exp() ret = x / x.sum(dim, keepdims=True) return ret -def log_softmax(x,dim=None): - x = softmax(x,dim=dim) + +def log_softmax(x, dim=None): + x = softmax(x, dim=dim) return jt.log(x) + def log_sigmoid(x): return jt.log(jt.sigmoid(x)) @@ -345,12 +378,14 @@ def __init__(self, *args, **kwargs): def execute(self, input): return input + class Dropout(Module): def __init__(self, p=0.5, is_train=False): assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p) self.p = p self.is_train = is_train - #TODO: test model.train() to change self.is_train + # TODO: test model.train() to change self.is_train + def execute(self, input): output = input if self.p > 0 and self.is_train: @@ -360,19 +395,21 @@ def execute(self, input): else: noise = jt.random(input.shape) noise = (noise > self.p).int() - output = output * noise / (1.0 - self.p) # div keep prob + output = output * noise / (1.0 - self.p) # div keep prob return output -def dropout(x,p=0.5,is_train=False): - return Dropout(p=p,is_train=is_train)(x) + +def dropout(x, p=0.5, is_train=False): + return Dropout(p=p, is_train=is_train)(x) + class Linear(Module): def __init__(self, in_features, out_features, bias=True): self.in_features = in_features self.out_features = out_features self.weight = init.invariant_uniform((out_features, in_features), "float32") - bound = 1.0/math.sqrt(in_features) - self.bias = init.uniform((out_features,), "float32",-bound,bound) if bias else None + bound = 1.0 / math.sqrt(in_features) + self.bias = init.uniform((out_features,), "float32", -bound, bound) if bias else None def execute(self, x): x = matmul_transpose(x, self.weight) @@ -380,6 +417,7 @@ def execute(self, x): return x + self.bias return x + class BatchNorm(Module): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True): self.sync = sync @@ 
-394,32 +432,34 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=T self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad() def execute(self, x): - dims = [0]+list(range(2,x.ndim)) + dims = [0] + list(range(2, x.ndim)) if self.is_train: xmean = jt.mean(x, dims=dims) - x2mean = jt.mean(x*x, dims=dims) + x2mean = jt.mean(x * x, dims=dims) if self.sync and jt.in_mpi: xmean = xmean.mpi_all_reduce("mean") x2mean = x2mean.mpi_all_reduce("mean") - xvar = (x2mean-xmean*xmean).maximum(0.0) - w = self.weight / jt.sqrt(xvar+self.eps) + xvar = (x2mean - xmean * xmean).maximum(0.0) + w = self.weight / jt.sqrt(xvar + self.eps) b = self.bias - xmean * w norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims) self.running_mean.update(self.running_mean + - (xmean.reshape((-1,)) - self.running_mean) * self.momentum) + (xmean.reshape((-1,)) - self.running_mean) * self.momentum) self.running_var.update(self.running_var + - (xvar.reshape((-1,))-self.running_var)*self.momentum) + (xvar.reshape((-1,)) - self.running_var) * self.momentum) return norm_x else: - w = self.weight / jt.sqrt(self.running_var+self.eps) + w = self.weight / jt.sqrt(self.running_var + self.eps) b = self.bias - self.running_mean * w norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims) return norm_x + BatchNorm2d = BatchNorm1d = BatchNorm + class InstanceNorm(Module): def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, is_train=True, sync=True): self.sync = sync @@ -433,17 +473,19 @@ def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, is_train= self.bias = init.constant((num_features,), "float32", 0.0) if affine else 0.0 def execute(self, x): - dims = list(range(2,x.ndim)) + dims = list(range(2, x.ndim)) xmean = jt.mean(x, dims=dims) - x2mean = jt.mean(x*x, dims=dims) + x2mean = jt.mean(x * x, dims=dims) - xvar = (x2mean-xmean*xmean).maximum(0.0) - w = self.weight / jt.sqrt(xvar+self.eps) + xvar = (x2mean - xmean * xmean).maximum(0.0) + w = self.weight / jt.sqrt(xvar + self.eps) b = self.bias - xmean * w return x * w.broadcast(x, dims) + b.broadcast(x, dims) + InstanceNorm2d = InstanceNorm1d = InstanceNorm + class LayerNorm(Module): def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True) -> None: if isinstance(normalized_shape, int): @@ -457,16 +499,17 @@ def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool def execute(self, x): dims = [-i for i in range(len(self.normalized_shape), 0, -1)] xmean = jt.mean(x, dims=dims, keepdims=1) - x2mean = jt.mean(x*x, dims=dims, keepdims=1) + x2mean = jt.mean(x * x, dims=dims, keepdims=1) - xvar = (x2mean-xmean*xmean).maximum(0.0) - w = self.weight / jt.sqrt(xvar+self.eps) + xvar = (x2mean - xmean * xmean).maximum(0.0) + w = self.weight / jt.sqrt(xvar + self.eps) b = self.bias - xmean * w return x * w + b LayerNorm2d = LayerNorm1d = LayerNorm + class GroupNorm(Module): def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, is_train=True): self.num_groups = num_groups @@ -480,15 +523,15 @@ def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, is_train=Tr def execute(self, x): N = x.shape[0] C = self.num_channels - output_shape = (N,-1) - # TODO: 3d group norm - if x.ndim==4: + output_shape = (N, -1) + # TODO: 3d group norm + if x.ndim == 4: output_shape = x.shape assert C % self.num_groups == 0 - x = x.reshape((N, self.num_groups, C//self.num_groups, -1)) - xmean = jt.mean(x, dims=[2,3]).reshape((N, self.num_groups, 1)) - 
x2mean = jt.mean(x*x, dims=[2,3]).reshape((N, self.num_groups, 1)) - xvar = (x2mean-xmean*xmean).maximum(0.0) + x = x.reshape((N, self.num_groups, C // self.num_groups, -1)) + xmean = jt.mean(x, dims=[2, 3]).reshape((N, self.num_groups, 1)) + x2mean = jt.mean(x * x, dims=[2, 3]).reshape((N, self.num_groups, 1)) + xvar = (x2mean - xmean * xmean).maximum(0.0) if self.affine: w = self.weight.reshape((1, self.num_groups, -1)) @@ -496,11 +539,12 @@ def execute(self, x): else: w = 1 b = 0 - w = w / jt.sqrt(xvar+self.eps) + w = w / jt.sqrt(xvar + self.eps) b = b - xmean * w x = x * w.broadcast(x, [3]) + b.broadcast(x, [3]) return x.reshape(output_shape) + Relu = jt.make_module(relu) ReLU = Relu Leaky_relu = jt.make_module(leaky_relu, 2) @@ -511,6 +555,7 @@ def execute(self, x): from jittor.depthwise_conv import DepthwiseConv + class Conv(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): self.in_channels = in_channels @@ -531,9 +576,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, assert out_channels % groups == 0, 'out_channels must be divisible by groups' # self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out") - self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float") + self.weight = init.invariant_uniform([out_channels, in_channels // groups, Kh, Kw], dtype="float") if bias: - fan=1 + fan = 1 for i in self.weight.shape[1:]: fan *= i bound = 1 / math.sqrt(fan) @@ -545,65 +590,67 @@ def execute(self, x): if self.is_depthwise_conv and jt.flags.use_cuda: y = self.depthwise_conv(x, self.weight) if self.bias is not None: - b = self.bias.broadcast(y.shape, [0,2,3]) + b = self.bias.broadcast(y.shape, [0, 2, 3]) y = y + b return y elif self.groups == 1: - N,C,H,W = x.shape + N, C, H, W = x.shape Kh, Kw = self.kernel_size - assert C==self.in_channels - oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 - ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 - xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [ - 'i0', # Nid - 'i2', # Cid - f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid - f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid + assert C == self.in_channels + oh = (H + self.padding[0] * 2 - Kh * self.dilation[0] + self.dilation[0] - 1) // self.stride[0] + 1 + ow = (W + self.padding[1] * 2 - Kw * self.dilation[1] + self.dilation[1] - 1) // self.stride[1] + 1 + xx = x.reindex([N, self.out_channels, C, oh, ow, Kh, Kw], [ + 'i0', # Nid + 'i2', # Cid + f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid + f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid ]) - ww = self.weight.broadcast(xx.shape, [0,3,4]) - yy = xx*ww - y = yy.sum([2,5,6]) # Kc, Kh, Kw + ww = self.weight.broadcast(xx.shape, [0, 3, 4]) + yy = xx * ww + y = yy.sum([2, 5, 6]) # Kc, Kh, Kw if self.bias is not None: - b = self.bias.broadcast(y.shape, [0,2,3]) + b = self.bias.broadcast(y.shape, [0, 2, 3]) y = y + b return y else: - N,C,H,W = x.shape + N, C, H, W = x.shape Kh, Kw = self.kernel_size G = self.groups - CpG = C // G # channels per group - assert C==self.in_channels + CpG = C // G # channels per group + assert C == self.in_channels oc = self.out_channels - oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 - ow = 
(W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 - xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [ - 'i0', # Nid - f'i1*{CpG}+i3', # Gid - f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid - f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid + oh = (H + self.padding[0] * 2 - Kh * self.dilation[0] + self.dilation[0] - 1) // self.stride[0] + 1 + ow = (W + self.padding[1] * 2 - Kw * self.dilation[1] + self.dilation[1] - 1) // self.stride[1] + 1 + xx = x.reindex([N, G, oc // G, CpG, oh, ow, Kh, Kw], [ + 'i0', # Nid + f'i1*{CpG}+i3', # Gid + f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid + f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid ]) # w: [oc, CpG, Kh, Kw] - ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [ - f'i1*{oc//G}+i2', + ww = self.weight.reindex([N, G, oc // G, CpG, oh, ow, Kh, Kw], [ + f'i1*{oc // G}+i2', 'i3', 'i6', 'i7' ]) - ww.compile_options = xx.compile_options = {"G":G,"C":C} - yy = xx*ww + ww.compile_options = xx.compile_options = {"G": G, "C": C} + yy = xx * ww y = yy.reindex_reduce('add', [N, oc, oh, ow], [ 'i0', - f'i1*{oc//G}+i2', + f'i1*{oc // G}+i2', 'i4', 'i5' ]) if self.bias is not None: - b = self.bias.broadcast(y.shape, [0,2,3]) + b = self.bias.broadcast(y.shape, [0, 2, 3]) y = y + b - return y + return y + Conv2d = Conv + class Conv1d(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): self.in_channels = in_channels @@ -616,11 +663,12 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.bias = bias assert in_channels % groups == 0, 'in_channels must be divisible by groups' assert out_channels % groups == 0, 'out_channels must be divisible by groups' - self.conv = Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias) + self.conv = Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, + self.dilation, self.groups, self.bias) def execute(self, x): - N,C,D = x.shape - assert C==self.in_channels + N, C, D = x.shape + assert C == self.in_channels x = x.unsqueeze(-1) x = self.conv(x) y = x.squeeze(-1) @@ -634,56 +682,57 @@ def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): out_channels = weight.shape[0] if groups == 1: - N,C,H,W = x.shape + N, C, H, W = x.shape Kh, Kw = weight.shape[-2:] - oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1 - ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1 - xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [ - 'i0', # Nid - 'i2', # Cid - f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid - f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid - ]) - ww = weight.broadcast(xx.shape, [0,3,4]) - yy = xx*ww - y = yy.sum([2,5,6]) # Kc, Kh, Kw + oh = (H + padding[0] * 2 - Kh * dilation[0] + dilation[0] - 1) // stride[0] + 1 + ow = (W + padding[1] * 2 - Kw * dilation[1] + dilation[1] - 1) // stride[1] + 1 + xx = x.reindex([N, out_channels, C, oh, ow, Kh, Kw], [ + 'i0', # Nid + 'i2', # Cid + f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid + f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid + ]) + ww = weight.broadcast(xx.shape, [0, 3, 4]) + yy = xx * ww + y = yy.sum([2, 5, 6]) # Kc, Kh, Kw if bias is not None: - b = bias.broadcast(y.shape, [0,2,3]) + b = bias.broadcast(y.shape, [0, 2, 3]) y = y + b return y 
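    # Grouped path below: the C input channels are split into G groups of CpG channels,
    # each group is paired with oc//G output channels, and the per-group products are
    # folded back into [N, oc, oh, ow] with reindex_reduce("add"), mirroring the grouped
    # branch of Conv.execute above.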
else: - N,C,H,W = x.shape + N, C, H, W = x.shape Kh, Kw = weight.shape[-2:] G = groups - CpG = C // G # channels per group + CpG = C // G # channels per group oc = out_channels - oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1 - ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1 - xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [ - 'i0', # Nid - f'i1*{CpG}+i3', # Gid - f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid - f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid - ]) - xx.compile_options = {"G":G} + oh = (H + padding[0] * 2 - Kh * dilation[0] + dilation[0] - 1) // stride[0] + 1 + ow = (W + padding[1] * 2 - Kw * dilation[1] + dilation[1] - 1) // stride[1] + 1 + xx = x.reindex([N, G, oc // G, CpG, oh, ow, Kh, Kw], [ + 'i0', # Nid + f'i1*{CpG}+i3', # Gid + f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid + f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid + ]) + xx.compile_options = {"G": G} # w: [oc, CpG, Kh, Kw] - ww = weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [ - f'i1*{oc//G}+i2', - 'i3', - 'i6', - 'i7' - ]) - yy = xx*ww + ww = weight.reindex([N, G, oc // G, CpG, oh, ow, Kh, Kw], [ + f'i1*{oc // G}+i2', + 'i3', + 'i6', + 'i7' + ]) + yy = xx * ww y = yy.reindex_reduce('add', [N, oc, oh, ow], [ - 'i0', - f'i1*{oc//G}+i2', - 'i4', - 'i5' - ]) + 'i0', + f'i1*{oc // G}+i2', + 'i4', + 'i5' + ]) if bias is not None: - b = bias.broadcast(y.shape, [0,2,3]) + b = bias.broadcast(y.shape, [0, 2, 3]) y = y + b - return y + return y + class ConvTranspose(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ @@ -694,7 +743,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ # added self.dilation = dilation self.group = groups - assert groups==1, "Group conv not supported yet." + assert groups == 1, "Group conv not supported yet." 
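+        # Output size follows the usual transposed-convolution relation,
+        #   out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1,
+        # which is exactly what execute() computes below for h_out and w_out.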
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) self.stride = stride if isinstance(stride, tuple) else (stride, stride) @@ -702,15 +751,15 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ # added self.padding = padding if isinstance(padding, tuple) else (padding, padding) self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], - self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) - self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) + self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) + self.output_padding = output_padding if isinstance(output_padding, tuple) else (output_padding, output_padding) assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \ - self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ + self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ "output padding must be smaller than max(stride, dilation)" self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float") if bias: - fan=1 + fan = 1 for i in self.weight.shape[1:]: fan *= i bound = 1 / math.sqrt(fan) @@ -719,92 +768,96 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ self.bias = None def execute(self, x): - N,C,H,W = x.shape - i,o,h,w = self.weight.shape - assert C==i + N, C, H, W = x.shape + i, o, h, w = self.weight.shape + assert C == i stride_h, stride_w = self.stride padding_h, padding_w = self.padding dilation_h, dilation_w = self.dilation - h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h - w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w + h_out = (H - 1) * stride_h + self.output_padding[0] - 2 * padding_h + 1 + (h - 1) * dilation_h + w_out = (W - 1) * stride_w + self.output_padding[1] - 2 * padding_w + 1 + (w - 1) * dilation_w out_shape = (N, o, h_out, w_out) shape = (N, i, o, H, W, h, w) - xx = x.broadcast(shape, (2, 5, 6)) # i,h,w - ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W - y = (ww*xx).reindex_reduce("add", out_shape, [ - 'i0', # N - 'i2', # o - f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid - f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid + xx = x.broadcast(shape, (2, 5, 6)) # i,h,w + ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W + y = (ww * xx).reindex_reduce("add", out_shape, [ + 'i0', # N + 'i2', # o + f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid + f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid ]) if self.bias is not None: - b = self.bias.broadcast(y.shape, [0,2,3]) + b = self.bias.broadcast(y.shape, [0, 2, 3]) y = y + b return y + def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): x = input - N,C,H,W = x.shape - i,o,h,w = weight.shape - assert C==i - assert groups==1, "Group conv not supported yet." + N, C, H, W = x.shape + i, o, h, w = weight.shape + assert C == i + assert groups == 1, "Group conv not supported yet." 
stride = stride if isinstance(stride, tuple) else (stride, stride) dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) # added padding = padding if isinstance(padding, tuple) else (padding, padding) - output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) + output_padding = output_padding if isinstance(output_padding, tuple) else (output_padding, output_padding) assert output_padding[0] < max(stride[0], dilation[0]) and \ - output_padding[1] < max(stride[1], dilation[1]), \ + output_padding[1] < max(stride[1], dilation[1]), \ "output padding must be smaller than max(stride, dilation)" stride_h, stride_w = stride padding_h, padding_w = padding dilation_h, dilation_w = dilation - h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h - w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w + h_out = (H - 1) * stride_h + output_padding[0] - 2 * padding_h + 1 + (h - 1) * dilation_h + w_out = (W - 1) * stride_w + output_padding[1] - 2 * padding_w + 1 + (w - 1) * dilation_w out_shape = (N, o, h_out, w_out) shape = (N, i, o, H, W, h, w) - xx = x.broadcast(shape, (2, 5, 6)) # i,h,w - ww = weight.broadcast(shape, (0, 3, 4)) # N,H,W - y = (ww*xx).reindex_reduce("add", out_shape, [ - 'i0', # N - 'i2', # o - f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid - f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid + xx = x.broadcast(shape, (2, 5, 6)) # i,h,w + ww = weight.broadcast(shape, (0, 3, 4)) # N,H,W + y = (ww * xx).reindex_reduce("add", out_shape, [ + 'i0', # N + 'i2', # o + f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid + f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid ]) if isinstance(bias, jt.Var): - b = bias.broadcast(y.shape, [0,2,3]) + b = bias.broadcast(y.shape, [0, 2, 3]) y = y + b else: assert not bias, "Bias should be none or jittor var" return y + conv_transpose2d = conv_transpose -def pad(x,padding, mode='constant', value=0): - assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad' - assert len(padding)%2==0 and len(padding)//2<=x.ndim + +def pad(x, padding, mode='constant', value=0): + assert mode in ['constant', 'replicate', 'reflect', + 'circular'], 'only support constant,replicate,reflect,circular pad' + assert len(padding) % 2 == 0 and len(padding) // 2 <= x.ndim padding = list(padding) - left = [0]*(x.ndim-len(padding)//2)+padding[::2][::-1] - right = [0]*(x.ndim-len(padding)//2)+padding[1::2][::-1] + left = [0] * (x.ndim - len(padding) // 2) + padding[::2][::-1] + right = [0] * (x.ndim - len(padding) // 2) + padding[1::2][::-1] out_dims = [] out_shape = [] - for i,n,l,r in zip(range(x.ndim),x.shape,left,right): - out_shape.append(n+l+r) + for i, n, l, r in zip(range(x.ndim), x.shape, left, right): + out_shape.append(n + l + r) if mode == 'constant': out_dims.append(f'i{i}-{l}') elif mode == 'replicate': - out_dims.append(f"i{i}<{l} ? 0 : i{i} > {n+l-1} ? {n-1} : i{i}-{l}") + out_dims.append(f"i{i}<{l} ? 0 : i{i} > {n + l - 1} ? {n - 1} : i{i}-{l}") elif mode == 'reflect': - out_dims.append(f"i{i}<{l} ? {l}-i{i} : i{i} > {n+l-1} ? {2*(n-1)+l}-i{i} : i{i}-{l}") + out_dims.append(f"i{i}<{l} ? {l}-i{i} : i{i} > {n + l - 1} ? {2 * (n - 1) + l}-i{i} : i{i}-{l}") elif mode == 'circular': - out_dims.append(f"i{i}<{l} ? {n-l}+i{i} : i{i} > {n+l-1} ? i{i}-{n+l} : i{i}-{l}") + out_dims.append(f"i{i}<{l} ? {n - l}+i{i} : i{i} > {n + l - 1} ? 
i{i}-{n + l} : i{i}-{l}") - return x.reindex(out_shape,out_dims,overflow_value=value) + return x.reindex(out_shape, out_dims, overflow_value=value) class ReflectionPad2d(Module): @@ -821,19 +874,20 @@ def __init__(self, padding): raise TypeError(f"ReflectionPad2d padding just support int or tuple, but found {type(padding)}") def execute(self, x): - n,c,h,w = x.shape + n, c, h, w = x.shape assert (self.pl < w and self.pr < w), f"padding_left and padding_right should be smaller than input width" assert (self.pt < h and self.pb < h), f"padding_top and padding_bottom should be smaller than input height" - oh=h+self.pt+self.pb - ow=w+self.pl+self.pr + oh = h + self.pt + self.pb + ow = w + self.pl + self.pr l = self.pl r = self.pl + w - 1 t = self.pt b = self.pt + h - 1 - return x.reindex([n,c,oh,ow], ["i0","i1", - f"i2<{t} ? {t}-i2 : i2 > {b} ? {h-1+b}-i2 : i2-{t}", - f"i3<{l} ? {l}-i3 : i3 > {r} ? {w-1+r}-i3 : i3-{l}", - ]) + return x.reindex([n, c, oh, ow], ["i0", "i1", + f"i2<{t} ? {t}-i2 : i2 > {b} ? {h - 1 + b}-i2 : i2-{t}", + f"i3<{l} ? {l}-i3 : i3 > {r} ? {w - 1 + r}-i3 : i3-{l}", + ]) + class ZeroPad2d(Module): def __init__(self, padding): @@ -849,8 +903,10 @@ def __init__(self, padding): raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}") def execute(self, x): - n,c,h,w = x.shape - return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"]) + n, c, h, w = x.shape + return x.reindex([n, c, h + self.pt + self.pb, w + self.pl + self.pr], + ["i0", "i1", f"i2-{self.pt}", f"i3-{self.pl}"]) + class ConstantPad2d(Module): def __init__(self, padding, value): @@ -869,14 +925,15 @@ def __init__(self, padding, value): def execute(self, x): assert len(x.shape) >= 2 shape = x.shape - tar_shape = shape[0:-2] + [shape[-2]+self.pt+self.pb,shape[-1]+self.pl+self.pr] + tar_shape = shape[0:-2] + [shape[-2] + self.pt + self.pb, shape[-1] + self.pl + self.pr] tar_dims = [] - for i in range(len(shape)-2): + for i in range(len(shape) - 2): tar_dims.append(f"i{i}") - tar_dims.append(f"i{i+1}-{self.pt}") - tar_dims.append(f"i{i+2}-{self.pl}") + tar_dims.append(f"i{i + 1}-{self.pt}") + tar_dims.append(f"i{i + 2}-{self.pl}") return x.reindex(tar_shape, tar_dims, overflow_value=self.value) + class ReplicationPad2d(Module): def __init__(self, padding): self.padding = padding @@ -891,61 +948,69 @@ def __init__(self, padding): raise TypeError(f"ReplicationPad2d padding just support int or tuple, but found {type(padding)}") def execute(self, x): - n,c,h,w = x.shape - oh=h+self.pt+self.pb - ow=w+self.pl+self.pr + n, c, h, w = x.shape + oh = h + self.pt + self.pb + ow = w + self.pl + self.pr l = self.pl r = self.pl + w - 1 t = self.pt b = self.pt + h - 1 - return x.reindex([n,c,oh,ow], ["i0","i1", - f"i2<{t} ? 0 : i2 > {b} ? {h-1} : i2-{t}", - f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}" - ]) + return x.reindex([n, c, oh, ow], ["i0", "i1", + f"i2<{t} ? 0 : i2 > {b} ? {h - 1} : i2-{t}", + f"i3<{l} ? 0 : i3 > {r} ? 
{w - 1} : i3-{l}" + ]) + class Embedding(Module): def __init__(self, num, dim): self.num = num self.dim = dim - self.weight = jt.init.gauss([num,dim],'float32').stop_grad() + self.weight = jt.init.gauss([num, dim], 'float32').stop_grad() def execute(self, x): - res = self.weight[x].reshape([x.shape[0],self.dim]) + res = self.weight[x].reshape([x.shape[0], self.dim]) return res + class PixelShuffle(Module): def __init__(self, upscale_factor): self.upscale_factor = upscale_factor def execute(self, x): - n,c,h,w = x.shape + n, c, h, w = x.shape r = self.upscale_factor - assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle" - return x.reindex([n,int(c/r**2),h*r,w*r], [ + assert c % (r * r) == 0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle" + return x.reindex([n, int(c / r ** 2), h * r, w * r], [ "i0", - f"i1*{r*r}+i2%{r}*{r}+i3%{r}", + f"i1*{r * r}+i2%{r}*{r}+i3%{r}", f"i2/{r}", f"i3/{r}" ]) + class Tanh(Module): def __init__(self): super().__init__() - def execute(self, x) : + + def execute(self, x): return x.tanh() + class Sigmoid(Module): def __init__(self): super().__init__() - def execute(self, x) : + + def execute(self, x): return x.sigmoid() -def softplus(x,beta=1.0,threshold=20.0): + +def softplus(x, beta=1.0, threshold=20.0): return 1 / beta * jt.log(1 + (beta * x).minimum(threshold).exp()) + \ - (x - threshold/beta).maximum(0.0) + (x - threshold / beta).maximum(0.0) + -def hardtanh(x,min_val=-1,max_val=1): - return jt.clamp(x,min_v=min_val,max_v=max_val) +def hardtanh(x, min_val=-1, max_val=1): + return jt.clamp(x, min_v=min_val, max_v=max_val) class Softplus(Module): @@ -953,11 +1018,12 @@ class Softplus(Module): SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive. Args: - + [in] beta (float): the beta value for the Softplus formulation. Default: 1. [in] threshold (float): values above this revert to a linear function. Default: 20. ''' + def __init__(self, beta=1, threshold=20): self.beta = beta self.threshold = threshold @@ -965,68 +1031,126 @@ def __init__(self, beta=1, threshold=20): def execute(self, x): return softplus(x, self.beta, self.threshold) + class Resize(Module): def __init__(self, size, mode="nearest", align_corners=False): super().__init__() self.size = size self.mode = mode self.align_corners = align_corners + def execute(self, x): return resize(x, self.size, self.mode, self.align_corners) + +def _bicubic(x, a): + ''' + c1 = jt.where(jt.abs(x)<=1) + x[:,:,c1[0],c1[1]] = (a+2)*jt.abs(x[:,:,c1[0],c1[1]])**3-(a+3)*x[:,:,c1[0],c1[1]]**2+1 + c2 = jt.where(jt.abs(x)>1) + x[:,:,c2[0],c2[1]] = a*jt.abs(x[:,:,c2[0],c2[1]])**3-5*a*x[:,:,c2[0],c2[1]]**2+8*a*jt.abs(x[:,:,c2[0],c2[1]])-4*a + c3 = jt.where(jt.abs(x)>=2) + x[:,:,c3[0],c3[1]] = 0 + return x + ''' + # normal ver + if jt.abs(x) <= 1: + return (a + 2) * (jt.abs(x) ** 3) - (a + 3) * (x ** 2) + 1 + if jt.abs(x) < 2: + return a * (jt.abs(x) ** 3) - 5 * a * (x ** 2) + 8 * a * (jt.abs(x)) - 4 * a + return 0 + + ''' + # pytorch ver. 
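+    # This quoted-out variant precomputes the four cubic-convolution weights for the
+    # sample offsets (-1, 0, 1, 2) around the fractional coordinate x, i.e. W(x+1), W(x),
+    # W(1-x) and W(2-x), using the same piecewise kernel as the branches above (a is
+    # typically -0.75, which is what the bicubic path passes in).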
+ coe = [] + coe.append(((a*(x+1)-5*a)*(x+1)+8*a)*(x+1)-4*a) + coe.append(((a+2)*x-(a+3))*x*x+1) + coe.append(((a+2)*(1-x)-(a+3))*(1-x)*(1-x)+1) + coe.append(((a*(2-x)-5*a)*(2-x)+8*a)*(2-x)-4*a) + return coe; + ''' + + def _interpolate(img, x, y, ids, mode): - if mode=="nearest": + if mode == "nearest": return img.reindex([*ids, x.floor(), y.floor()]) - if mode=="bilinear": + if mode == "bilinear": fx, fy = x.floor(), y.floor() - cx, cy = fx+1, fy+1 - dx, dy = x-fx, y-fy + cx, cy = fx + 1, fy + 1 + dx, dy = x - fx, y - fy a = img.reindex_var([*ids, fx, fy]) b = img.reindex_var([*ids, cx, fy]) c = img.reindex_var([*ids, fx, cy]) d = img.reindex_var([*ids, cx, cy]) - dnx, dny = 1-dx, 1-dy - ab = dx*b + dnx*a - cd = dx*d + dnx*c - o = ab*dny + cd*dy + dnx, dny = 1 - dx, 1 - dy + ab = dx * b + dnx * a + cd = dx * d + dnx * c + o = ab * dny + cd * dy return o - raise(f"Not support interpolation mode: {mode}") + if mode == "bicubic": # ugly ver. + n, c, h, w = img.shape + print(img) + fx, fy = x.floor(), y.floor() + dx, dy = x - fx, y - fy + outputs = jt.zeros((n, c, x.shape[-2], x.shape[-1])) + for dn in range(n): + for dc in range(c): + for i in range(x.shape[-2]): + for j in range(x.shape[-1]): + for a in range(-1, 3): + for b in range(-1, 3): + nx = max(min(fx[dn, dc, i, j] + a, h - 1), 0) + ny = max(min(fy[dn, dc, i, j] + b, w - 1), 0) + # print(_bicubic(dx[dn,dc,i,j]-a,-0.75)) + # print(_bicubic(dy[dn,dc,i,j]-b,-0.75)) + outputs[dn, dc, i, j] += img[dn, dc, nx, ny] * _bicubic(dx[dn, dc, i, j] - a, + -0.75) * _bicubic( + dy[dn, dc, i, j] - b, -0.75) + return outputs + raise (f"Not support interpolation mode: {mode}") + def resize(img, size, mode="nearest", align_corners=False): - n,c,h,w = img.shape - H,W = size - nid, cid, hid, wid = jt.index((n,c,H,W)) + n, c, h, w = img.shape + H, W = size + nid, cid, hid, wid = jt.index((n, c, H, W)) if align_corners: - x = hid * ((h-1) / max(1, H-1)) - y = wid * ((w-1) / max(1, W-1)) + x = hid * ((h - 1) / max(1, H - 1)) + y = wid * ((w - 1) / max(1, W - 1)) else: - x = hid * (h / H) + (h/H*0.5 - 0.5) - if H>h: x = x.clamp(0, h-1) - y = wid * (w / W) + (w/W*0.5 - 0.5) - if W>w: y = y.clamp(0, w-1) - return _interpolate(img, x, y, (nid,cid), mode) + x = hid * (h / H) + (h / H * 0.5 - 0.5) + if H > h: x = x.clamp(0, h - 1) + y = wid * (w / W) + (w / W * 0.5 - 0.5) + if W > w: y = y.clamp(0, w - 1) + return _interpolate(img, x, y, (nid, cid), mode) + def upsample(img, size, mode="nearest", align_corners=False): - n,c,h,w = img.shape - H,W = size - nid, cid, hid, wid = jt.index((n,c,H,W)) + n, c, h, w = img.shape + H, W = size + nid, cid, hid, wid = jt.index((n, c, H, W)) if align_corners: - x = hid * ((h-1) / max(1, H-1)) - y = wid * ((w-1) / max(1, W-1)) + x = hid * ((h - 1) / max(1, H - 1)) + y = wid * ((w - 1) / max(1, W - 1)) + elif mode == "bicubic": + x = (hid + 0.5) * (h / H) - 0.5 + y = (wid + 0.5) * (w / W) - 0.5 else: x = hid * (h / H) y = wid * (w / W) - return _interpolate(img, x, y, (nid,cid), mode) + return _interpolate(img, x, y, (nid, cid), mode) + -def interpolate(X,size=None,scale_factor=None,mode='bilinear',align_corners=False): +def interpolate(X, size=None, scale_factor=None, mode='bilinear', align_corners=False): if scale_factor is not None: - size = [X.shape[-2]*scale_factor,X.shape[-1]*scale_factor] - if isinstance(size,int): - size = (size,size) - if scale_factor is not None and scale_factor>1: - return upsample(X,size,mode,align_corners) + size = [X.shape[-2] * scale_factor, X.shape[-1] * scale_factor] + if isinstance(size, int): 
+ size = (size, size) + if scale_factor is not None and scale_factor > 1: + return upsample(X, size, mode, align_corners) else: - return resize(X,size,mode,align_corners) + return resize(X, size, mode, align_corners) + def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'): r''' @@ -1043,7 +1167,7 @@ def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'): [in] mode (string): the interpolate way, default: bilinear. [in] padding_mode (string): the padding way, default: zeros. - + [out] output (var): the output var, whose shape is (N, C, Ho, Wo) Example: @@ -1067,188 +1191,194 @@ def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'): assert Ni == No assert len(input.shape) == 4 and len(grid.shape) - nid, cid, hid, wid = jt.index((Ni,Ci,Ho,Wo)) - x = ((grid[:,:,:,1].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Hi - 1) - y = ((grid[:,:,:,0].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Wi - 1) - return _interpolate(input, x, y, (nid,cid), mode) + nid, cid, hid, wid = jt.index((Ni, Ci, Ho, Wo)) + x = ((grid[:, :, :, 1].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Hi - 1) + y = ((grid[:, :, :, 0].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Wi - 1) + return _interpolate(input, x, y, (nid, cid), mode) -def linspace_from_neg_one(grid,num_steps,align_corners): - if num_steps <= 1: - return jt.array([],dtype=grid.dtype) +def linspace_from_neg_one(grid, num_steps, align_corners): + if num_steps <= 1: + return jt.array([], dtype=grid.dtype) # TODO: use jt.index - ra = np.linspace(-1,1,num_steps) + ra = np.linspace(-1, 1, num_steps) if not align_corners: - ra = ra*(num_steps-1)/num_steps - return jt.array(ra,dtype=grid.dtype) + ra = ra * (num_steps - 1) / num_steps + return jt.array(ra, dtype=grid.dtype) -def make_base_grid_4D(theta,N,C,H,W,align_corners): + +def make_base_grid_4D(theta, N, C, H, W, align_corners): base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype); - base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners) - base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1) - base_grid[...,-1] = 1 + base_grid[..., 0] = linspace_from_neg_one(theta, W, align_corners) + base_grid[..., 1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners), -1) + base_grid[..., -1] = 1 return base_grid -def make_base_grid_5D(theta,N,C,D,H,W,align_corners): + +def make_base_grid_5D(theta, N, C, D, H, W, align_corners): base_grid = jt.zeros((N, D, H, W, 4), dtype=theta.dtype) - base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners) - base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1) - base_grid[...,2] = jt.unsqueeze(jt.unsqueeze(linspace_from_neg_one(theta, D, align_corners),-1),-1) - base_grid[...,-1] = 1 + base_grid[..., 0] = linspace_from_neg_one(theta, W, align_corners) + base_grid[..., 1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners), -1) + base_grid[..., 2] = jt.unsqueeze(jt.unsqueeze(linspace_from_neg_one(theta, D, align_corners), -1), -1) + base_grid[..., -1] = 1 return base_grid -def affine_grid_generator_4D(theta,N,C,H,W,align_corners): - base_grid = make_base_grid_4D(theta, N, C, H, W, align_corners) - grid = jt.nn.bmm(base_grid.reshape(N, H * W, 3),theta.transpose(0,2,1)) - return grid.reshape(N, H, W, 2) -def affine_grid_generator_5D(theta,N,C,D,H,W,align_corners): +def affine_grid_generator_4D(theta, N, C, H, W, align_corners): + base_grid = make_base_grid_4D(theta, N, C, H, W, align_corners) + grid = jt.nn.bmm(base_grid.reshape(N, H * W, 3), 
theta.transpose(0, 2, 1)) + return grid.reshape(N, H, W, 2) + + +def affine_grid_generator_5D(theta, N, C, D, H, W, align_corners): base_grid = make_base_grid_5D(theta, N, C, D, H, W, align_corners) - grid = jt.nn.bmm(base_grid.reshape(N, D * H * W, 4),theta.transpose(0,2,1)) + grid = jt.nn.bmm(base_grid.reshape(N, D * H * W, 4), theta.transpose(0, 2, 1)) return grid.reshape(N, D, H, W, 3) + def affine_grid(theta, size, align_corners=False): - assert str(theta.dtype) in ['float','float32','float64'] - assert min(size)>0 - assert len(size) in [4,5] - if len(size)== 4: + assert str(theta.dtype) in ['float', 'float32', 'float64'] + assert min(size) > 0 + assert len(size) in [4, 5] + if len(size) == 4: assert theta.ndim == 3 and theta.shape[-2] == 2 and theta.shape[-1] == 3 return affine_grid_generator_4D(theta, size[0], size[1], size[2], size[3], align_corners) - elif len(size)==5: + elif len(size) == 5: assert theta.ndim == 3 and theta.shape[-2] == 3 and theta.shape[-1] == 4 return affine_grid_generator_5D(theta, size[0], size[1], size[2], size[3], size[4], align_corners) -def grid_sampler_unnormalize(coord,size,align_corners): +def grid_sampler_unnormalize(coord, size, align_corners): if align_corners: - #unnormalize coord from [-1, 1] to [0, size - 1] + # unnormalize coord from [-1, 1] to [0, size - 1] return ((coord + 1) / 2) * (size - 1) else: - #unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + # unnormalize coord from [-1, 1] to [-0.5, size - 0.5] return ((coord + 1) * size - 1) / 2 -def clip_coordinates(x,clip_limit): - return jt.clamp(x,min_v=0,max_v=clip_limit-1) +def clip_coordinates(x, clip_limit): + return jt.clamp(x, min_v=0, max_v=clip_limit - 1) + -def reflect_coordinates(x,twice_low,twice_high): +def reflect_coordinates(x, twice_low, twice_high): if twice_low == twice_high: return jt.zeros_like(x) m = twice_low / 2 span = (twice_high - twice_low) / 2 x = (x - m).abs() - #`fmod` returns same sign as `in`, which is positive after the `fabs` above. + # `fmod` returns same sign as `in`, which is positive after the `fabs` above. 
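+    # Reflection works on the doubled interval: `extra` is the offset within one
+    # half-period of length `span`, and `flips` counts how many borders were crossed;
+    # an even flip count keeps the direction (m + extra), an odd one reverses it
+    # (m + span - extra), which is how result1/result2 are combined below.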
extra = x.mod(span) flips = (x / span).floor() - result1 = extra+m - result2 = span-extra+m - con = flips%2==0 - not_con = flips%2!=0 - result1[not_con]=0.0 - result2[con]=0.0 - return result1+result2 + result1 = extra + m + result2 = span - extra + m + con = flips % 2 == 0 + not_con = flips % 2 != 0 + result1[not_con] = 0.0 + result2[con] = 0.0 + return result1 + result2 -def grid_sampler_compute_source_index(coord,size,padding_mode,align_corners): +def grid_sampler_compute_source_index(coord, size, padding_mode, align_corners): coord = grid_sampler_unnormalize(coord, size, align_corners) if padding_mode == 'border': - #clip coordinates to image borders + # clip coordinates to image borders coord = clip_coordinates(coord, size) elif padding_mode == 'reflection': - #reflect coordinates by image borders + # reflect coordinates by image borders if align_corners: - coord = reflect_coordinates(coord, 0, 2*(size - 1)) + coord = reflect_coordinates(coord, 0, 2 * (size - 1)) else: - coord = reflect_coordinates(coord, -1, 2*size - 1) - #clip coordinates to image borders + coord = reflect_coordinates(coord, -1, 2 * size - 1) + # clip coordinates to image borders coord = clip_coordinates(coord, size) return coord - -def grid_sampler_3d(X,grid,mode,padding_mode,align_corners): +def grid_sampler_3d(X, grid, mode, padding_mode, align_corners): N = X.shape[0] C = X.shape[1] inp_D = X.shape[2] inp_H = X.shape[3] inp_W = X.shape[4] - D = grid.shape[1] + D = grid.shape[1] H = grid.shape[2] W = grid.shape[3] - x = grid[:,:,:,:,0] - y = grid[:,:,:,:,1] - z = grid[:,:,:,:,2] - shape = [N,C,D,H,W] + x = grid[:, :, :, :, 0] + y = grid[:, :, :, :, 1] + z = grid[:, :, :, :, 2] + shape = [N, C, D, H, W] cid = jt.index(shape, dim=1) nid = jt.index(shape, dim=0) - x = grid_sampler_compute_source_index(x,inp_W,padding_mode,align_corners) - y = grid_sampler_compute_source_index(y,inp_H,padding_mode,align_corners) - z = grid_sampler_compute_source_index(z,inp_D,padding_mode,align_corners) - xid = x.reindex(shape,['i0','i2','i3','i4']) - yid = y.reindex(shape,['i0','i2','i3','i4']) - zid = z.reindex(shape,['i0','i2','i3','i4']) - - if mode=='nearest': - return X.reindex([nid,cid,zid.round(),yid.round(),xid.round()]) - elif mode=='bilinear': - fx,fy,fz = xid.floor(),yid.floor(),zid.floor() - cx,cy,cz = fx+1,fy+1,fz+1 - dx,dy,dz = xid-fx,yid-fy,zid-fz - dnx,dny,dnz = cx-xid,cy-yid,cz-zid - a = X.reindex([nid,cid,fz,fy,fx]) - b = X.reindex([nid,cid,cz,fy,fx]) - c = X.reindex([nid,cid,fz,cy,fx]) - d = X.reindex([nid,cid,fz,fy,cx]) - e = X.reindex([nid,cid,fz,cy,cx]) - f = X.reindex([nid,cid,cz,fy,cx]) - g = X.reindex([nid,cid,cz,cy,fx]) - h = X.reindex([nid,cid,cz,cy,cx]) - o = a*dnx*dny*dnz+b*dnx*dny*dz+c*dnx*dy*dnz+d*dx*dny*dnz+e*dx*dy*dnz+f*dx*dny*dz+g*dnx*dy*dz+h*dx*dy*dz + x = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners) + y = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners) + z = grid_sampler_compute_source_index(z, inp_D, padding_mode, align_corners) + xid = x.reindex(shape, ['i0', 'i2', 'i3', 'i4']) + yid = y.reindex(shape, ['i0', 'i2', 'i3', 'i4']) + zid = z.reindex(shape, ['i0', 'i2', 'i3', 'i4']) + + if mode == 'nearest': + return X.reindex([nid, cid, zid.round(), yid.round(), xid.round()]) + elif mode == 'bilinear': + fx, fy, fz = xid.floor(), yid.floor(), zid.floor() + cx, cy, cz = fx + 1, fy + 1, fz + 1 + dx, dy, dz = xid - fx, yid - fy, zid - fz + dnx, dny, dnz = cx - xid, cy - yid, cz - zid + a = X.reindex([nid, cid, fz, fy, fx]) + b = X.reindex([nid, cid, 
cz, fy, fx]) + c = X.reindex([nid, cid, fz, cy, fx]) + d = X.reindex([nid, cid, fz, fy, cx]) + e = X.reindex([nid, cid, fz, cy, cx]) + f = X.reindex([nid, cid, cz, fy, cx]) + g = X.reindex([nid, cid, cz, cy, fx]) + h = X.reindex([nid, cid, cz, cy, cx]) + o = a * dnx * dny * dnz + b * dnx * dny * dz + c * dnx * dy * dnz + d * dx * dny * dnz + e * dx * dy * dnz + f * dx * dny * dz + g * dnx * dy * dz + h * dx * dy * dz return o -def grid_sampler_2d(X,grid,mode,padding_mode,align_corners): + +def grid_sampler_2d(X, grid, mode, padding_mode, align_corners): N = X.shape[0] C = X.shape[1] inp_H = X.shape[2] inp_W = X.shape[3] - H = grid.shape[1] + H = grid.shape[1] W = grid.shape[2] - x = grid[:,:,:,0] - y = grid[:,:,:,1] - shape = [N,C,H,W] + x = grid[:, :, :, 0] + y = grid[:, :, :, 1] + shape = [N, C, H, W] cid = jt.index(shape, dim=1) nid = jt.index(shape, dim=0) - x = grid_sampler_compute_source_index(x,inp_W,padding_mode,align_corners) - y = grid_sampler_compute_source_index(y,inp_H,padding_mode,align_corners) - xid = x.reindex(shape,['i0','i2','i3']) - yid = y.reindex(shape,['i0','i2','i3']) - - if mode=='nearest': - return X.reindex([nid,cid,yid.round(),xid.round()]) - elif mode=='bilinear': - #xid,yid = (xid+0.00001),(yid+0.00001) - fx,fy = (xid).floor(),(yid).floor() - cx,cy = fx+1,fy+1 - dx,dy = xid-fx,yid-fy - dnx,dny = cx-xid,cy-yid - - a = X.reindex([nid,cid,fy,fx],overflow_value=0.0) - b = X.reindex([nid,cid,cy,fx],overflow_value=0.0) - c = X.reindex([nid,cid,fy,cx],overflow_value=0.0) - d = X.reindex([nid,cid,cy,cx],overflow_value=0.0) - o = a*dnx*dny+b*dnx*dy+c*dx*dny+d*dx*dy + x = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners) + y = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners) + xid = x.reindex(shape, ['i0', 'i2', 'i3']) + yid = y.reindex(shape, ['i0', 'i2', 'i3']) + + if mode == 'nearest': + return X.reindex([nid, cid, yid.round(), xid.round()]) + elif mode == 'bilinear': + # xid,yid = (xid+0.00001),(yid+0.00001) + fx, fy = (xid).floor(), (yid).floor() + cx, cy = fx + 1, fy + 1 + dx, dy = xid - fx, yid - fy + dnx, dny = cx - xid, cy - yid + + a = X.reindex([nid, cid, fy, fx], overflow_value=0.0) + b = X.reindex([nid, cid, cy, fx], overflow_value=0.0) + c = X.reindex([nid, cid, fy, cx], overflow_value=0.0) + d = X.reindex([nid, cid, cy, cx], overflow_value=0.0) + o = a * dnx * dny + b * dnx * dy + c * dx * dny + d * dx * dy return o def grid_sampler(X, grid, mode, padding_mode, align_corners): - assert X.dtype==grid.dtype - assert ((X.ndim==4 or X.ndim==5) and X.ndim==grid.ndim) - assert X.shape[0]==grid.shape[0] and grid.shape[-1]==X.ndim-2 - assert X.numel()>0 + assert X.dtype == grid.dtype + assert ((X.ndim == 4 or X.ndim == 5) and X.ndim == grid.ndim) + assert X.shape[0] == grid.shape[0] and grid.shape[-1] == X.ndim - 2 + assert X.numel() > 0 if X.ndim == 4: return grid_sampler_2d(X, grid, mode, padding_mode, align_corners) else: @@ -1256,8 +1386,8 @@ def grid_sampler(X, grid, mode, padding_mode, align_corners): def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False): - assert mode in ['bilinear','nearest'] - assert padding_mode in ['zeros','border','reflection'] + assert mode in ['bilinear', 'nearest'] + assert padding_mode in ['zeros', 'border', 'reflection'] return grid_sampler(input, grid, mode, padding_mode, align_corners) @@ -1265,13 +1395,14 @@ class Upsample(Module): def __init__(self, scale_factor=None, mode='nearest'): self.scale_factor = scale_factor if isinstance(scale_factor, 
tuple) else (scale_factor, scale_factor) self.mode = mode - + def execute(self, x): return upsample(x, - size=( - int(x.shape[2]*self.scale_factor[0]), - int(x.shape[3]*self.scale_factor[1])), - mode=self.mode) + size=( + int(x.shape[2] * self.scale_factor[0]), + int(x.shape[3] * self.scale_factor[1])), + mode=self.mode) + class Sequential(Module): def __init__(self, *args): @@ -1280,49 +1411,113 @@ def __init__(self, *args): if isinstance(mod, collections.OrderedDict): for k, m in mod.items(): self.add_module(k, m) - elif isinstance(mod,list): + elif isinstance(mod, list): for m in mod: self.append(m) else: self.append(mod) + def __getitem__(self, idx): if idx not in self.layers: return list(self.layers.values())[idx] return self.layers[idx] + def __iter__(self): return self.layers.values().__iter__() + def keys(self): return self.layers.keys() + def values(self): return self.layers.values() + def items(self): return self.layers.items() + def execute(self, x): for k, layer in self.layers.items(): x = layer(x) return x + def dfs(self, parents, k, callback, callback_leave): n_children = len(self.layers) ret = callback(parents, k, self, n_children) if ret == False: return parents.append(self) - for k,v in self.layers.items(): + for k, v in self.layers.items(): v.dfs(parents, k, callback, callback_leave) parents.pop() if callback_leave: callback_leave(parents, k, self, n_children) + def append(self, mod): assert callable(mod), f"Module <{type(mod)}> is not callable" assert not isinstance(mod, type), f"Module is not a type" - self.layers[len(self.layers)]=mod + self.layers[len(self.layers)] = mod + def add_module(self, name, mod): assert callable(mod), f"Module <{type(mod)}> is not callable" assert not isinstance(mod, type), f"Module is not a type" - self.layers[name]=mod + self.layers[name] = mod def __len__(self): return len(self.layers) + +def unfold(X, kernel_size, dilation=1, padding=0, stride=1): + assert X.ndim == 4 + if not isinstance(kernel_size, tuple): + kernel_size = (kernel_size, kernel_size) + if not isinstance(dilation, tuple): + dilation = (dilation, dilation) + if not isinstance(padding, tuple): + padding = (padding, padding) + if not isinstance(stride, tuple): + stride = (stride, stride) + n, c, h, w = X.shape + shape = X.shape + area = kernel_size[0] * kernel_size[1] + block_nums = [] + for i in range(2, 4): + block_nums.append( + (shape[i] + 2 * padding[i - 2] - dilation[i - 2] * (kernel_size[i - 2] - 1) - 1) // stride[i - 2] + 1) + if padding[0] != 0 or padding[1] != 0: + X = X.reindex([n, c, h + padding[0] * 2, w + padding[1] * 2], + ["i0", "i1", f"i2-{padding[0]}", f"i3-{padding[1]}"]) + output = X.reindex([n, c * area, block_nums[0] * block_nums[1]], ["i0", f"i1/{area}", + f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}", + f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"]) + return output + + +def fold(X, output_size, kernel_size, dilation=1, padding=0, stride=1): # this may be implemented in C for speed? 
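+    # fold is the inverse of unfold: it takes blocks of shape (N, C*kh*kw, L), where L is
+    # the number of sliding-window positions, and scatter-adds them back into an
+    # (N, C, H, W) map of size output_size, summing overlapping locations. This is a
+    # direct Python-loop sketch, hence the note about a faster C implementation.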
+ assert X.ndim == 3 + if not isinstance(kernel_size, tuple): + kernel_size = (kernel_size, kernel_size) + if not isinstance(dilation, tuple): + dilation = (dilation, dilation) + if not isinstance(padding, tuple): + padding = (padding, padding) + if not isinstance(stride, tuple): + stride = (stride, stride) + n, cl, num = X.shape + area = kernel_size[0] * kernel_size[1] + block_nums = [] + output = jt.zeros((n, cl // area, output_size[0], output_size[1])) + for i in range(2, 4): + block_nums.append( + (output_size[i - 2] + 2 * padding[i - 2] - dilation[i - 2] * (kernel_size[i - 2] - 1) - 1) // stride[ + i - 2] + 1) + for dn in range(n): + for c in range(cl // area): + for i in range(num): + for j in range(area): + output[dn, c, i // block_nums[1] * stride[0] + (j % area) / kernel_size[1] * dilation[0], i % + block_nums[1] * stride[1] + j % area % kernel_size[1] * dilation[1]] += X[ + dn, c * area + j, i] + return output + + ModuleList = Sequential diff --git a/python/jittor/test/test_linalg.py b/python/jittor/test/test_linalg.py index 6d6bbcdc..50cbc02e 100644 --- a/python/jittor/test/test_linalg.py +++ b/python/jittor/test/test_linalg.py @@ -257,4 +257,3 @@ def check_det(a): if __name__ == "__main__": unittest.main() - From 40a915de8702536f863c91e09bfaa23ee308b549 Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Thu, 11 Feb 2021 16:51:18 +0800 Subject: [PATCH 02/36] add eye. --- python/jittor/misc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 74ceab64..bd9d3922 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -998,6 +998,10 @@ def grad(self, grad_a): dim+=len(a.shape) return func(a, dim) +def eye(n): + x = jt.ones((n)) + return x.reindex([n,n],["i0"],overflow_conditions=[f'i0!=i1']) + def linspace(start, end, steps): res = jt.index((steps,))[0] res = res*(end-start)/float(steps-1)+start From fa2f911d1ebcb4a6e0247727e130aaf20ced2603 Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Sat, 13 Feb 2021 15:56:50 +0800 Subject: [PATCH 03/36] add test for qr,bicubic,fold,unfold. --- python/jittor/nn.py | 40 ++++++++++++++---------------- python/jittor/test/test_bicubic.py | 38 ++++++++++++++++++++++++++++ python/jittor/test/test_fold.py | 34 +++++++++++++++++++++++++ python/jittor/test/test_linalg.py | 38 +++++++++++++++++++++++++++- 4 files changed, 127 insertions(+), 23 deletions(-) create mode 100644 python/jittor/test/test_bicubic.py create mode 100644 python/jittor/test/test_fold.py diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 20576dff..bb71640f 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1089,7 +1089,6 @@ def _interpolate(img, x, y, ids, mode): return o if mode == "bicubic": # ugly ver. n, c, h, w = img.shape - print(img) fx, fy = x.floor(), y.floor() dx, dy = x - fx, y - fy outputs = jt.zeros((n, c, x.shape[-2], x.shape[-1])) @@ -1492,32 +1491,29 @@ def unfold(X, kernel_size, dilation=1, padding=0, stride=1): return output -def fold(X, output_size, kernel_size, dilation=1, padding=0, stride=1): # this may be implemented in C for speed? 
- assert X.ndim == 3 - if not isinstance(kernel_size, tuple): - kernel_size = (kernel_size, kernel_size) - if not isinstance(dilation, tuple): - dilation = (dilation, dilation) - if not isinstance(padding, tuple): - padding = (padding, padding) - if not isinstance(stride, tuple): - stride = (stride, stride) - n, cl, num = X.shape +def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1):# this may be implemented in C for speed? + assert X.ndim==3 + if not isinstance(kernel_size,tuple): + kernel_size = (kernel_size,kernel_size) + if not isinstance(dilation,tuple): + dilation = (dilation,dilation) + if not isinstance(padding,tuple): + padding = (padding,padding) + if not isinstance(stride,tuple): + stride = (stride,stride) + n,cl,num = X.shape area = kernel_size[0] * kernel_size[1] block_nums = [] - output = jt.zeros((n, cl // area, output_size[0], output_size[1])) - for i in range(2, 4): - block_nums.append( - (output_size[i - 2] + 2 * padding[i - 2] - dilation[i - 2] * (kernel_size[i - 2] - 1) - 1) // stride[ - i - 2] + 1) + output = jt.zeros((n,cl // area,output_size[0]+2*padding[0],output_size[1]+2*padding[1])) + for i in range(2,4): + block_nums.append((output_size[i-2]+2*padding[i-2]-dilation[i-2]*(kernel_size[i-2]-1)-1) // stride[i-2]+1) for dn in range(n): for c in range(cl // area): for i in range(num): for j in range(area): - output[dn, c, i // block_nums[1] * stride[0] + (j % area) / kernel_size[1] * dilation[0], i % - block_nums[1] * stride[1] + j % area % kernel_size[1] * dilation[1]] += X[ - dn, c * area + j, i] - return output - + h = i//block_nums[1]*stride[0]+(j%area)/kernel_size[1]*dilation[0] + w = i%block_nums[1]*stride[1]+j%area%kernel_size[1]*dilation[1] + output[dn,c,i//block_nums[1]*stride[0]+(j%area)//kernel_size[1]*dilation[0],i%block_nums[1]*stride[1]+j%area%kernel_size[1]*dilation[1]]+=X[dn,c*area+j,i] + return output[:,:,padding[0]:padding[0]+output_size[0],padding[1]:padding[1]+output_size[1]] ModuleList = Sequential diff --git a/python/jittor/test/test_bicubic.py b/python/jittor/test/test_bicubic.py new file mode 100644 index 00000000..a07ed706 --- /dev/null +++ b/python/jittor/test/test_bicubic.py @@ -0,0 +1,38 @@ +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. 
+# *************************************************************** +import unittest +import jittor as jt +import torch +from torch.nn import functional as F +import numpy as np + + +class TestBicubicInterpolate(unittest.TestCase): + # this is for testing bicubic interpolate + def test_bicubic(self): + for _ in range(20): + try: + tn = np.random.randn(1,1,5,5).astype('float32') + ja = jt.array(tn) + ta = torch.from_numpy(tn) + # test upsample + ju = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic') + tu = F.interpolate(ta,scale_factor=2,mode='bicubic') + assert np.allclose(ju.data,tu.numpy(),rtol=1e-03,atol=1e-06) + # test fold + je = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic',align_corners=True) + te = F.interpolate(ta,scale_factor=2,mode='bicubic',align_corners=True) + assert np.allclose(je.data,te.numpy(),rtol=1e-03,atol=1e-06) + except AssertionError: + print(ju,tu) + print(je,te) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_fold.py b/python/jittor/test/test_fold.py new file mode 100644 index 00000000..44a305bc --- /dev/null +++ b/python/jittor/test/test_fold.py @@ -0,0 +1,34 @@ +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import torch +from torch.nn import functional as F +import numpy as np + + +class TestFoldOp(unittest.TestCase): + def test_fold(self): + # test unfold first and the test fold. + for _ in range(100): + # test unfold + tn = np.random.randn(1,3,4,4).astype('float32') + ja = jt.array(tn) + ta = torch.from_numpy(tn) + juf = jt.nn.unfold(ja,kernel_size=2,stride=2,dilation=2,padding=2) + tuf = F.unfold(ta,kernel_size=2,stride=2,dilation=2,padding=2) + assert np.allclose(juf.data,tuf.numpy()) + # test fold + jf = jt.nn.fold(juf,output_size=(4,4),kernel_size=2,stride=2,dilation=2,padding=2) + tf = F.fold(tuf,output_size=(4,4),kernel_size=2,stride=2,dilation=2,padding=2) + assert np.allclose(jf.data,tf.numpy()) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_linalg.py b/python/jittor/test/test_linalg.py index 50cbc02e..2034f4d6 100644 --- a/python/jittor/test/test_linalg.py +++ b/python/jittor/test/test_linalg.py @@ -8,11 +8,12 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. 
# *************************************************************** +import torch +from torch.autograd import Variable import jittor as jt import numpy as np import unittest - try: import autograd.numpy as anp from autograd import jacobian @@ -254,6 +255,41 @@ def check_det(a): gx = np.sum(gx, 2) assert np.allclose(gx, jx.data) + def test_qr(self): + for i in range(50): + tn = np.random.randn(3, 3).astype('float32') + while np.allclose(np.linalg.det(tn), 0): + tn = np.random.randn((3, 3)).astype('float32') + x = jt.array(tn) + # x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) + t_x = torch.from_numpy(tn) + t_x = Variable(t_x, requires_grad=True) + jq, jr = qr(x) + tq, tr = torch.qr(t_x) + try: + assert np.allclose(jq.data, tq.detach().numpy()) + assert np.allclose(jr.data, tr.detach().numpy()) + except AssertionError: + print("ours' qr results:") + print(jq) + print(jr) + print("pytorch's qr results:") + print(tq) + print(tr) + gq = jt.grad(jq, x).data + gr = jt.grad(jr, x).data + tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True) + tgr = torch.autograd.grad(tr, t_x, torch.ones_like(tr), retain_graph=True) + try: + assert np.allclose(gq, tgq[0].numpy()) + assert np.allclose(gr, tgr[0].numpy()) + except AssertionError: + print("ours' qr grad results:") + print(gq) + print(gr) + print("pytorch's qr grad result") + print(tgq[0]) + print(tgr[0]) if __name__ == "__main__": unittest.main() From f015999cc6eed248e763e7123b1abbba0c368ac4 Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Tue, 23 Feb 2021 13:28:25 +0800 Subject: [PATCH 04/36] fix bicubic and fold to code_op ver.add grad test. --- python/jittor/nn.py | 75 ++++++++++-------------------- python/jittor/test/test_bicubic.py | 16 +++++-- python/jittor/test/test_fold.py | 23 +++++---- 3 files changed, 49 insertions(+), 65 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index bb71640f..62bccddd 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -1043,33 +1043,14 @@ def execute(self, x): return resize(x, self.size, self.mode, self.align_corners) -def _bicubic(x, a): - ''' - c1 = jt.where(jt.abs(x)<=1) - x[:,:,c1[0],c1[1]] = (a+2)*jt.abs(x[:,:,c1[0],c1[1]])**3-(a+3)*x[:,:,c1[0],c1[1]]**2+1 - c2 = jt.where(jt.abs(x)>1) - x[:,:,c2[0],c2[1]] = a*jt.abs(x[:,:,c2[0],c2[1]])**3-5*a*x[:,:,c2[0],c2[1]]**2+8*a*jt.abs(x[:,:,c2[0],c2[1]])-4*a - c3 = jt.where(jt.abs(x)>=2) - x[:,:,c3[0],c3[1]] = 0 - return x - ''' +def _bicubic(x, a, func): # normal ver - if jt.abs(x) <= 1: - return (a + 2) * (jt.abs(x) ** 3) - (a + 3) * (x ** 2) + 1 - if jt.abs(x) < 2: - return a * (jt.abs(x) ** 3) - 5 * a * (x ** 2) + 8 * a * (jt.abs(x)) - 4 * a + if func == 1: + return (a+2)*(jt.abs(x)**3)-(a+3)*(x**2)+1 + if func == 2: + return a*(jt.abs(x)**3)-5*a*(x**2)+8*a*(jt.abs(x))-4*a return 0 - ''' - # pytorch ver. - coe = [] - coe.append(((a*(x+1)-5*a)*(x+1)+8*a)*(x+1)-4*a) - coe.append(((a+2)*x-(a+3))*x*x+1) - coe.append(((a+2)*(1-x)-(a+3))*(1-x)*(1-x)+1) - coe.append(((a*(2-x)-5*a)*(2-x)+8*a)*(2-x)-4*a) - return coe; - ''' - def _interpolate(img, x, y, ids, mode): if mode == "nearest": @@ -1087,25 +1068,24 @@ def _interpolate(img, x, y, ids, mode): cd = dx * d + dnx * c o = ab * dny + cd * dy return o - if mode == "bicubic": # ugly ver. - n, c, h, w = img.shape + if mode=="bicubic": # ugly ver. 
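        # Editor's note (hedged): the two branches of _bicubic() above form the
        # cubic convolution kernel with a = -0.75 (the value PyTorch and OpenCV
        # use):
        #     W(t) = (a+2)|t|^3 - (a+3)|t|^2 + 1       for |t| <= 1    (func == 1)
        #     W(t) = a|t|^3 - 5a|t|^2 + 8a|t| - 4a     for 1 < |t| < 2 (func == 2)
        # Each output pixel below is a 4x4 weighted sum: ax, bx, cx, dx are W()
        # evaluated at distances dix+1, dix, 1-dix, 2-dix from the sample point,
        # and ay..dy are the same along the other axis.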
+ n,c,h,w = img.shape fx, fy = x.floor(), y.floor() - dx, dy = x - fx, y - fy - outputs = jt.zeros((n, c, x.shape[-2], x.shape[-1])) - for dn in range(n): - for dc in range(c): - for i in range(x.shape[-2]): - for j in range(x.shape[-1]): - for a in range(-1, 3): - for b in range(-1, 3): - nx = max(min(fx[dn, dc, i, j] + a, h - 1), 0) - ny = max(min(fy[dn, dc, i, j] + b, w - 1), 0) - # print(_bicubic(dx[dn,dc,i,j]-a,-0.75)) - # print(_bicubic(dy[dn,dc,i,j]-b,-0.75)) - outputs[dn, dc, i, j] += img[dn, dc, nx, ny] * _bicubic(dx[dn, dc, i, j] - a, - -0.75) * _bicubic( - dy[dn, dc, i, j] - b, -0.75) - return outputs + dix, diy = x - fx, y - fy + ax, ay = _bicubic(dix+1,-0.75,2), _bicubic(diy+1,-0.75,2) + bx, by = _bicubic(dix,-0.75,1), _bicubic(diy,-0.75,1) + cx, cy = _bicubic(1-dix,-0.75,1), _bicubic(1-diy,-0.75,1) + dx, dy = _bicubic(2-dix,-0.75,2), _bicubic(2-diy,-0.75,2) + afx, afy = jt.maximum(jt.minimum(fx-1,h-1),0), jt.maximum(jt.minimum(fy-1,w-1),0) + bfx, bfy = jt.maximum(jt.minimum(fx,h-1),0), jt.maximum(jt.minimum(fy,w-1),0) + cfx, cfy = jt.maximum(jt.minimum(fx+1,h-1),0), jt.maximum(jt.minimum(fy+1,w-1),0) + dfx, dfy = jt.maximum(jt.minimum(fx+2,h-1),0), jt.maximum(jt.minimum(fy+2,w-1),0) + a = ax*(img.reindex_var([*ids,afx,afy])*ay+img.reindex_var([*ids,afx,bfy])*by+img.reindex_var([*ids,afx,cfy])*cy+img.reindex_var([*ids,afx,dfy])*dy) + b = bx*(img.reindex_var([*ids,bfx,afy])*ay+img.reindex_var([*ids,bfx,bfy])*by+img.reindex_var([*ids,bfx,cfy])*cy+img.reindex_var([*ids,bfx,dfy])*dy) + c = cx*(img.reindex_var([*ids,cfx,afy])*ay+img.reindex_var([*ids,cfx,bfy])*by+img.reindex_var([*ids,cfx,cfy])*cy+img.reindex_var([*ids,cfx,dfy])*dy) + d = dx*(img.reindex_var([*ids,dfx,afy])*ay+img.reindex_var([*ids,dfx,bfy])*by+img.reindex_var([*ids,dfx,cfy])*cy+img.reindex_var([*ids,dfx,dfy])*dy) + o = a + b + c + d + return o raise (f"Not support interpolation mode: {mode}") @@ -1491,7 +1471,7 @@ def unfold(X, kernel_size, dilation=1, padding=0, stride=1): return output -def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1):# this may be implemented in C for speed? 
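# Editor's note (hedged): the rewrite below drops the per-element Python loops
# in favour of a single reindex_reduce("add", ...) scatter into a zero output of
# shape (N, C, H + 2*pad_h, W + 2*pad_w); the padded border is then cropped off
# on return, and gradients can flow through Jittor's reindex machinery instead
# of element-wise item assignment.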
+def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1): assert X.ndim==3 if not isinstance(kernel_size,tuple): kernel_size = (kernel_size,kernel_size) @@ -1504,16 +1484,9 @@ def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1):# this may be n,cl,num = X.shape area = kernel_size[0] * kernel_size[1] block_nums = [] - output = jt.zeros((n,cl // area,output_size[0]+2*padding[0],output_size[1]+2*padding[1])) for i in range(2,4): block_nums.append((output_size[i-2]+2*padding[i-2]-dilation[i-2]*(kernel_size[i-2]-1)-1) // stride[i-2]+1) - for dn in range(n): - for c in range(cl // area): - for i in range(num): - for j in range(area): - h = i//block_nums[1]*stride[0]+(j%area)/kernel_size[1]*dilation[0] - w = i%block_nums[1]*stride[1]+j%area%kernel_size[1]*dilation[1] - output[dn,c,i//block_nums[1]*stride[0]+(j%area)//kernel_size[1]*dilation[0],i%block_nums[1]*stride[1]+j%area%kernel_size[1]*dilation[1]]+=X[dn,c*area+j,i] + output = X.reindex_reduce("add",[n,cl // area,output_size[0]+2*padding[0],output_size[1]+2*padding[1]],["i0",f"i1/{area}",f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}",f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"]) return output[:,:,padding[0]:padding[0]+output_size[0],padding[1]:padding[1]+output_size[1]] ModuleList = Sequential diff --git a/python/jittor/test/test_bicubic.py b/python/jittor/test/test_bicubic.py index a07ed706..b99e908a 100644 --- a/python/jittor/test/test_bicubic.py +++ b/python/jittor/test/test_bicubic.py @@ -21,18 +21,24 @@ def test_bicubic(self): try: tn = np.random.randn(1,1,5,5).astype('float32') ja = jt.array(tn) - ta = torch.from_numpy(tn) + ta = torch.autograd.Variable(torch.from_numpy(tn),requires_grad=True) # test upsample ju = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic') tu = F.interpolate(ta,scale_factor=2,mode='bicubic') - assert np.allclose(ju.data,tu.numpy(),rtol=1e-03,atol=1e-06) - # test fold + assert np.allclose(ju.data,tu.detach().numpy(),rtol=1e-03,atol=1e-06) + gju = jt.grad(ju,ja) + gtu = torch.autograd.grad(tu,ta,torch.ones_like(tu),retain_graph=True)[0] + assert np.allclose(gju.data,gtu.detach().numpy(),rtol=1e-03,atol=1e-06) + # test align je = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic',align_corners=True) te = F.interpolate(ta,scale_factor=2,mode='bicubic',align_corners=True) - assert np.allclose(je.data,te.numpy(),rtol=1e-03,atol=1e-06) + assert np.allclose(je.data,te.detach().numpy(),rtol=1e-03,atol=1e-06) + gje = jt.grad(je,ja) + gte = torch.autograd.grad(te,ta,torch.ones_like(tu),retain_graph=True)[0] + assert np.allclose(gje.data,gte.detach().numpy(),rtol=1e-03,atol=1e-06) except AssertionError: print(ju,tu) print(je,te) if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_fold.py b/python/jittor/test/test_fold.py index 44a305bc..f1000348 100644 --- a/python/jittor/test/test_fold.py +++ b/python/jittor/test/test_fold.py @@ -17,18 +17,23 @@ class TestFoldOp(unittest.TestCase): def test_fold(self): # test unfold first and the test fold. 
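        # Editor's sketch (hedged, not part of the patch): with kernel_size=2,
        # stride=2, dilation=2, padding=2 a (1, 3, 4, 4) input gives
        #     blocks per spatial dim = (4 + 2*2 - 2*(2-1) - 1) // 2 + 1 = 3
        # so unfold() returns shape (1, 3*2*2, 3*3) = (1, 12, 9), and
        # fold(..., output_size=(4, 4)) returns shape (1, 3, 4, 4) again
        # (overlapping patches are summed, so it is not an exact inverse).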
- for _ in range(100): - # test unfold - tn = np.random.randn(1,3,4,4).astype('float32') + for i in range(4,10): + tn = np.random.randn(1,3,i,i).astype('float32') ja = jt.array(tn) - ta = torch.from_numpy(tn) + ta = torch.autograd.Variable(torch.from_numpy(tn),requires_grad=True) juf = jt.nn.unfold(ja,kernel_size=2,stride=2,dilation=2,padding=2) tuf = F.unfold(ta,kernel_size=2,stride=2,dilation=2,padding=2) - assert np.allclose(juf.data,tuf.numpy()) + assert np.allclose(juf.data,tuf.detach().numpy()) + gjuf = jt.grad(juf,ja) + gtuf = torch.autograd.grad(tuf,ta,torch.ones_like(tuf),retain_graph=True)[0] + assert np.allclose(gjuf.data,gtuf.detach().numpy()) # test fold - jf = jt.nn.fold(juf,output_size=(4,4),kernel_size=2,stride=2,dilation=2,padding=2) - tf = F.fold(tuf,output_size=(4,4),kernel_size=2,stride=2,dilation=2,padding=2) - assert np.allclose(jf.data,tf.numpy()) + jf = jt.nn.fold(juf,output_size=(i,i),kernel_size=2,stride=2,dilation=2,padding=2) + tf = F.fold(tuf,output_size=(i,i),kernel_size=2,stride=2,dilation=2,padding=2) + assert np.allclose(jf.data,tf.detach().numpy()) + gjf = jt.grad(jf,juf) + gtf = torch.autograd.grad(tf,tuf,torch.ones_like(tf),retain_graph=True)[0] + assert np.allclose(gjf.data,gtf.detach().numpy()) if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file From cc7c0bbeabffa454cf7003c91c4533101be21b4b Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Tue, 23 Feb 2021 20:44:54 +0800 Subject: [PATCH 05/36] add docs.update pinv to support (..,M,N) shape --- doc/source/jittor.linalg.md | 47 +++++++++ python/jittor/linalg.py | 71 ++++++++++++-- python/jittor/test/test_linalg.py | 154 +++++++++++++++--------------- 3 files changed, 188 insertions(+), 84 deletions(-) diff --git a/doc/source/jittor.linalg.md b/doc/source/jittor.linalg.md index a61952d5..cce65a78 100644 --- a/doc/source/jittor.linalg.md +++ b/doc/source/jittor.linalg.md @@ -3,8 +3,55 @@ jittor.linalg 这里是Jittor的线性代数函数的API文档,您可以通过`from jittor import linalg`来获取该模块。 +## 基本函数简介 +#### 基本线性代数运算API +- linalg.inv(a) + + 对a进行求逆运算 + +- linalg.pinv(a) + + 对a进行广义求逆运算。该运算不要求原矩阵a可逆。 + +- linalg.slogdet(a) + + 对a求取slogdet。会返回值以及符号。 + +- linalg.det(a) + + 对a求行列式。 + +- linalg.solve(a,b) + + 求解线性方程Ax=b的解。 + +#### 分解API +- linalg.cholesky(a) + + 对a进行cholesky分解。 + +- linalg.qr(a) + + 对a进行qr分解。 + +- linalg.svd + + 对a进行奇异值分解。 +#### 特征值API +- linalg.eig(a) + + 求取a的特征值以及特征向量。 + +- linalg.eigh(a) + + 针对埃尔米特矩阵或者对称矩阵求特征值以及特征向量。 + + +目前的linalg库支持 + ```eval_rst .. automodule:: jittor.linalg :members: :undoc-members: ``` + diff --git a/python/jittor/linalg.py b/python/jittor/linalg.py index 7c869218..15d44662 100644 --- a/python/jittor/linalg.py +++ b/python/jittor/linalg.py @@ -14,7 +14,18 @@ #TODO:full_matrices=1 def svd(x): - + r''' + calculate the Singular Value Decomposition of x.It follows the below fomula: + x = usv* + only support full matrices == False ver now, which means: + x's shape (...,M,K) + u's shape (...,M,K) + s's shape (...,K) + v's shape (...,K,N) + where K is min(M,N). + :param x: + :return:u,s,v. + ''' def forward_code(np, data): a = data["inputs"][0] u, s, v = data["outputs"] @@ -84,7 +95,13 @@ def T(x): def eigh(x): - + r""" + calculate the eigenvalues and eigenvectors of x. + :param x (...,M,M): + :return:w, v. + w (...,M) : the eigenvalues. + v (...,M,M) : normalized eigenvectors. + """ def forward_code(np, data): a = data["inputs"][0] w, v = data["outputs"] @@ -126,7 +143,11 @@ def T(x): def inv(x): - + r""" + calculate the inverse of x. 
+ :param x (...,M,M): + :return:x^-1 (...,M,M). + """ def forward_code(np, data): a = data["inputs"][0] m_a = data["outputs"][0] @@ -156,7 +177,11 @@ def T(x): def pinv(x): - + r""" + calculate the pseudo-inverse of a x. + :param x (...,M,N) + :return: x's pinv (...N,M) + """ def forward_code(np, data): a = data["inputs"][0] m_a = data["outputs"][0] @@ -178,9 +203,9 @@ def T(x): + _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout), T(mx)), mx) ) np.copyto(out, t) - + sw = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]] lmx = jt.numpy_code( - [x.shape], + [sw], [x.dtype], [x], forward_code, @@ -191,7 +216,11 @@ def T(x): def det(x): - + r""" + calculate the determinant of x. + :param x (...,M,M): + :return:|x| (...,1) + """ def forward_code(np, data): a = data["inputs"][0] L = data["outputs"][0] @@ -227,6 +256,13 @@ def T(x): def slogdet(x): + r""" + calculate the sign and log of the determinant of x. + :param x (...,M,M): + :return sign, x's logdet. + sign array decides the sign of determinant and their values can be -1,0,1.Only Real number now.0 means det is 0 and logdet is -inf. + logdet in shape (...,1). + """ def forward_code(np, data): a = data["inputs"][0] sign, m_a = data["outputs"] @@ -264,7 +300,13 @@ def T(x): def cholesky(x): - + r""" + do Cholesky decomposition of x in the form of below formula: + x = LL^T + x must be a Hermite and positive-definite matrix. L is a lower-triangular matrix. + :param x (...,M,M): + :return: L (...,M,M). + """ def forward_code(np, data): a = data["inputs"][0] L = data["outputs"][0] @@ -300,7 +342,12 @@ def conjugate_solve(L, X): def solve(a,b): - + r""" + Solve a linear matrix equation Ax = B.This is done by calculating x = A^-1B.So A must not be singular. + :param a:(...,M,M) + :param b:(...,M) + :return:solution of Ax = b formula.x in the shape of (...M) + """ def forward_code(np, data): a, b = data["inputs"] L = data["outputs"][0] @@ -335,6 +382,12 @@ def backward_code2(np, data): def qr(x): + r""" + do the qr factorization of x in the below formula: + x = QR where Q is orthogonal matrix and R is upper-triangle matrix. + :param x (...,M,M): + :return:q,r as the result of qr factorization.They are both in the shape of (...,M,M). + """ def forward_code(np, data): a = data["inputs"][0] q, r = data["outputs"] diff --git a/python/jittor/test/test_linalg.py b/python/jittor/test/test_linalg.py index 2034f4d6..97081b37 100644 --- a/python/jittor/test/test_linalg.py +++ b/python/jittor/test/test_linalg.py @@ -17,41 +17,43 @@ try: import autograd.numpy as anp from autograd import jacobian + has_autograd = True except: has_autograd = False + @unittest.skipIf(not has_autograd, "No autograd found.") class TestCodeOp(unittest.TestCase): def test_svd(self): def check_svd(a): - u,s,v = anp.linalg.svd(a, full_matrices=0) - return u,s,v + u, s, v = anp.linalg.svd(a, full_matrices=0) + return u, s, v def check_u(a): - u,s,v = anp.linalg.svd(a, full_matrices=0) + u, s, v = anp.linalg.svd(a, full_matrices=0) return u def check_s(a): - u,s,v = anp.linalg.svd(a, full_matrices=0) + u, s, v = anp.linalg.svd(a, full_matrices=0) return s def check_v(a): - u,s,v = anp.linalg.svd(a, full_matrices=0) + u, s, v = anp.linalg.svd(a, full_matrices=0) return v for i in range(50): - #not for full-matrices! - a = jt.random((2,2,5,4)) + # not for full-matrices! 
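            # Editor's note (hedged): for this (2, 2, 5, 4) input the reduced
            # SVD of the new docstring gives K = min(5, 4) = 4, so u is
            # (2, 2, 5, 4), s is (2, 2, 4) and v is (2, 2, 4, 4).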
+ a = jt.random((2, 2, 5, 4)) c_a = anp.array(a.data) - u,s,v = jt.linalg.svd(a) - tu,ts,tv = check_svd(c_a) - assert np.allclose(tu,u.data) - assert np.allclose(ts,s.data) - assert np.allclose(tv,v.data) - ju = jt.grad(u,a) - js = jt.grad(s,a) - jv = jt.grad(v,a) + u, s, v = jt.linalg.svd(a) + tu, ts, tv = check_svd(c_a) + assert np.allclose(tu, u.data) + assert np.allclose(ts, s.data) + assert np.allclose(tv, v.data) + ju = jt.grad(u, a) + js = jt.grad(s, a) + jv = jt.grad(v, a) grad_u = jacobian(check_u) gu = grad_u(c_a) gu = np.sum(gu, 4) @@ -70,56 +72,56 @@ def check_v(a): gv = np.sum(gv, 2) gv = np.sum(gv, 2) try: - assert np.allclose(ju.data,gu,atol=1e-5) + assert np.allclose(ju.data, gu, atol=1e-5) except AssertionError: print(ju.data) print(gu) try: - assert np.allclose(js.data,gs,atol=1e-5) + assert np.allclose(js.data, gs, atol=1e-5) except AssertionError: print(js.data) print(gs) try: - assert np.allclose(jv.data,gv,atol=1e-5) + assert np.allclose(jv.data, gv, atol=1e-5) except AssertionError: print(jv.data) print(gv) def test_eigh(self): - def check_eigh(a,UPLO='L'): - w, v = anp.linalg.eigh(a,UPLO) + def check_eigh(a, UPLO='L'): + w, v = anp.linalg.eigh(a, UPLO) return w, v - def check_w(a,UPLO='L'): - w, v = anp.linalg.eigh(a,UPLO) + def check_w(a, UPLO='L'): + w, v = anp.linalg.eigh(a, UPLO) return w - def check_v(a,UPLO='L'): - w, v = anp.linalg.eigh(a,UPLO) + def check_v(a, UPLO='L'): + w, v = anp.linalg.eigh(a, UPLO) return v for i in range(50): - a = jt.random((2,2,3,3)) + a = jt.random((2, 2, 3, 3)) c_a = a.data w, v = jt.linalg.eigh(a) tw, tv = check_eigh(c_a) - assert np.allclose(w.data,tw) - assert np.allclose(v.data,tv) + assert np.allclose(w.data, tw) + assert np.allclose(v.data, tv) jw = jt.grad(w, a) jv = jt.grad(v, a) check_gw = jacobian(check_w) check_gv = jacobian(check_v) gw = check_gw(c_a) - gw = np.sum(gw,4) - gw = np.sum(gw,2) - gw = np.sum(gw,2) - assert np.allclose(gw,jw.data,rtol = 1,atol = 5e-8) + gw = np.sum(gw, 4) + gw = np.sum(gw, 2) + gw = np.sum(gw, 2) + assert np.allclose(gw, jw.data, rtol=1, atol=5e-8) gv = check_gv(c_a) - gv = np.sum(gv,4) - gv = np.sum(gv,4) - gv = np.sum(gv,2) - gv = np.sum(gv,2) - assert np.allclose(gv,jv.data,rtol = 1,atol = 5e-8) + gv = np.sum(gv, 4) + gv = np.sum(gv, 4) + gv = np.sum(gv, 2) + gv = np.sum(gv, 2) + assert np.allclose(gv, jv.data, rtol=1, atol=5e-8) def test_pinv(self): def check_pinv(a): @@ -127,34 +129,35 @@ def check_pinv(a): return w for i in range(50): - x = jt.random((2,2,4,4)) + x = jt.random((2, 2, 4, 3)) c_a = x.data mx = jt.linalg.pinv(x) tx = check_pinv(c_a) - np.allclose(mx.data,tx) - jx = jt.grad(mx,x) + np.allclose(mx.data, tx) + jx = jt.grad(mx, x) check_grad = jacobian(check_pinv) gx = check_grad(c_a) - np.allclose(gx,jx.data) + np.allclose(gx, jx.data) def test_inv(self): def check_inv(a): w = anp.linalg.inv(a) return w + for i in range(50): - tn = np.random.randn(4,4).astype('float32')*5 - while np.allclose(np.linalg.det(tn),0): - tn = np.random.randn((4,4)).astype('float32')*5 + tn = np.random.randn(4, 4).astype('float32') * 5 + while np.allclose(np.linalg.det(tn), 0): + tn = np.random.randn((4, 4)).astype('float32') * 5 x = jt.array(tn) - x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"]) + x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) c_a = x.data mx = jt.linalg.inv(x) tx = check_inv(c_a) - np.allclose(mx.data,tx) - jx = jt.grad(mx,x) + np.allclose(mx.data, tx) + jx = jt.grad(mx, x) check_grad = jacobian(check_inv) gx = check_grad(c_a) - np.allclose(gx,jx.data) + 
np.allclose(gx, jx.data) def test_slogdet(self): def check_ans(a): @@ -166,11 +169,11 @@ def check_slogdet(a): return w for i in range(50): - tn = np.random.randn(4,4).astype('float32')*10 - while np.allclose(np.linalg.det(tn),0): - tn = np.random.randn((4,4)).astype('float32')*10 + tn = np.random.randn(4, 4).astype('float32') * 10 + while np.allclose(np.linalg.det(tn), 0): + tn = np.random.randn((4, 4)).astype('float32') * 10 x = jt.array(tn) - x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"]) + x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) s = list(x.shape) det_s = s[:-2] if len(det_s) == 0: @@ -179,12 +182,12 @@ def check_slogdet(a): ts, ta = check_ans(x.data) assert np.allclose(sign.data, ts) assert np.allclose(mx.data, ta) - jx = jt.grad(mx,x) + jx = jt.grad(mx, x) check_sgrad = jacobian(check_slogdet) gx = check_sgrad(x.data) - gx = np.sum(gx,2) - gx = np.sum(gx,2) - assert np.allclose(gx,jx.data) + gx = np.sum(gx, 2) + gx = np.sum(gx, 2) + assert np.allclose(gx, jx.data) def test_cholesky(self): def check_cholesky(a): @@ -193,39 +196,39 @@ def check_cholesky(a): for i in range(50): x = jt.array(np.diag((np.random.rand(3) + 1) * 2)) - x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"]) + x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) tx = x.data L = jt.linalg.cholesky(x) tL = check_cholesky(tx) - assert np.allclose(tL,L.data) - jx = jt.grad(L,x) + assert np.allclose(tL, L.data) + jx = jt.grad(L, x) check_grad = jacobian(check_cholesky) gx = check_grad(tx) gx = np.sum(gx, 0) gx = np.sum(gx, 0) gx = np.sum(gx, 0) gx = np.sum(gx, 0) - assert np.allclose(jx.data,gx) + assert np.allclose(jx.data, gx) def test_solve(self): - def check_solve(a,b): - ans = anp.linalg.solve(a,b) + def check_solve(a, b): + ans = anp.linalg.solve(a, b) return ans for i in range(50): - a = jt.random((2,2,3,3)) - b = jt.random((2,2,3)) - ans = jt.linalg.solve(a,b) - ta = check_solve(a.data,b.data) + a = jt.random((2, 2, 3, 3)) + b = jt.random((2, 2, 3)) + ans = jt.linalg.solve(a, b) + ta = check_solve(a.data, b.data) assert np.allclose(ans.data, ta) jx = jt.grad(ans, a) check_sgrad = jacobian(check_solve) - gx = check_sgrad(a.data,b.data) - gx = np.sum(gx,0) - gx = np.sum(gx,0) - gx = np.sum(gx,0) + gx = check_sgrad(a.data, b.data) + gx = np.sum(gx, 0) + gx = np.sum(gx, 0) + gx = np.sum(gx, 0) try: - assert np.allclose(gx, jx.data,rtol=1) + assert np.allclose(gx, jx.data, rtol=1) except AssertionError: print(gx) print(jx.data) @@ -264,11 +267,11 @@ def test_qr(self): # x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) t_x = torch.from_numpy(tn) t_x = Variable(t_x, requires_grad=True) - jq, jr = qr(x) + jq, jr = jt.linalg.qr(x) tq, tr = torch.qr(t_x) try: - assert np.allclose(jq.data, tq.detach().numpy()) - assert np.allclose(jr.data, tr.detach().numpy()) + assert np.allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6) + assert np.allclose(jr.data, tr.detach().numpy(), rtol=1e-4, atol=1e-6) except AssertionError: print("ours' qr results:") print(jq) @@ -281,8 +284,8 @@ def test_qr(self): tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True) tgr = torch.autograd.grad(tr, t_x, torch.ones_like(tr), retain_graph=True) try: - assert np.allclose(gq, tgq[0].numpy()) - assert np.allclose(gr, tgr[0].numpy()) + assert np.allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6) + assert np.allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6) except AssertionError: print("ours' qr grad results:") print(gq) @@ -291,5 +294,6 @@ def test_qr(self): print(tgq[0]) 
print(tgr[0]) + if __name__ == "__main__": unittest.main() From 9c722a902cae133e0514eefe7f98b645112668a7 Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Fri, 26 Feb 2021 09:09:46 +0800 Subject: [PATCH 06/36] edit maintainer and testfunc's name. --- python/jittor/misc.py | 4 ---- python/jittor/test/test_bicubic.py | 1 + python/jittor/test/test_fold.py | 1 + python/jittor/test/test_linalg.py | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 3571b70c..f88f4d6c 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -993,10 +993,6 @@ def grad(self, grad_a): dim+=len(a.shape) return func(a, dim) -def eye(n): - x = jt.ones((n)) - return x.reindex([n,n],["i0"],overflow_conditions=[f'i0!=i1']) - def linspace(start, end, steps): res = jt.index((steps,))[0] res = res*(end-start)/float(steps-1)+start diff --git a/python/jittor/test/test_bicubic.py b/python/jittor/test/test_bicubic.py index b99e908a..69441274 100644 --- a/python/jittor/test/test_bicubic.py +++ b/python/jittor/test/test_bicubic.py @@ -1,6 +1,7 @@ # *************************************************************** # Copyright (c) 2021 Jittor. All Rights Reserved. # Maintainers: +# Haoyang Peng <2247838039@qq.com> # Guoye Yang <498731903@qq.com> # Dun Liang . # diff --git a/python/jittor/test/test_fold.py b/python/jittor/test/test_fold.py index f1000348..bc394e47 100644 --- a/python/jittor/test/test_fold.py +++ b/python/jittor/test/test_fold.py @@ -1,6 +1,7 @@ # *************************************************************** # Copyright (c) 2021 Jittor. All Rights Reserved. # Maintainers: +# Haoyang Peng <2247838039@qq.com> # Guoye Yang <498731903@qq.com> # Dun Liang . # diff --git a/python/jittor/test/test_linalg.py b/python/jittor/test/test_linalg.py index 97081b37..bdbb7a54 100644 --- a/python/jittor/test/test_linalg.py +++ b/python/jittor/test/test_linalg.py @@ -24,7 +24,7 @@ @unittest.skipIf(not has_autograd, "No autograd found.") -class TestCodeOp(unittest.TestCase): +class TestLinalgOp(unittest.TestCase): def test_svd(self): def check_svd(a): u, s, v = anp.linalg.svd(a, full_matrices=0) From cb705f7eb6219f7352c69362d30ffe2661e36ddd Mon Sep 17 00:00:00 2001 From: Gword <471184555@qq.com> Date: Tue, 2 Mar 2021 16:34:17 +0800 Subject: [PATCH 07/36] fix nn --- python/jittor/nn.py | 801 +++++++++++++++++++------------------------- 1 file changed, 351 insertions(+), 450 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 4573042b..41ef869b 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -31,12 +31,12 @@ def matmul_transpose(a, b): if len(a.shape)>2: aa = a.reshape((-1, a.shape[-1])) cc = matmul_transpose(aa, b) - return cc.reshape(a.shape[:-1] + (-1,)) + return cc.reshape(a.shape[:-1]+(-1,)) shape = list(a.shape)[:-1] + list(b.shape) - a = a.broadcast(shape, [len(shape) - 2]) + a = a.broadcast(shape, [len(shape)-2]) b = b.broadcast(shape) - return (a * b).sum(len(shape) - 1) + return (a*b).sum(len(shape)-1) def bmm_transpose(a, b): @@ -55,14 +55,10 @@ def bmm(a, b): shape of input a is [batch, n, m], shape of input b is [batch, m, k], return shape is [batch, n, k] - Example:: - import jittor as jt from jittor import nn - batch, n, m, k = 100, 5, 6, 7 - a = jt.random((batch, n, m)) b = jt.random((batch, m, k)) c = nn.bmm(a, b) @@ -70,37 +66,29 @@ def bmm(a, b): assert len(a.shape) > 2 and len(b.shape) > 2 return matmul(a, b) - def matmul(a, b): ''' matrix multiply, - Example:: - a = jt.random([3]) b = 
jt.random([3]) c = jt.matmul(a, b) assert c.shape == [1] - a = jt.random([3, 4]) b = jt.random([4]) c = jt.matmul(a, b) assert c.shape == [3] - a = jt.random([10, 3, 4]) b = jt.random([4]) c = jt.matmul(a, b) assert c.shape == [10, 3] - a = jt.random([10, 3, 4]) b = jt.random([4, 5]) c = jt.matmul(a, b) assert c.shape == [10, 3, 5] - a = jt.random([10, 3, 4]) b = jt.random([10, 4, 5]) c = jt.matmul(a, b) assert c.shape == [10, 3, 5] - a = jt.random([8, 1, 3, 4]) b = jt.random([10, 4, 5]) c = jt.matmul(a, b) @@ -110,11 +98,11 @@ def matmul(a, b): len_b = len(b.shape) if len_b == 1: # a: [n, m], b:[m], c:[n] - return (a * b).sum(-1) + return (a*b).sum(-1) if len_a == 1: # a: [n], b:[n,k], c:[k] return (a.broadcast(b, [-1]) * b).sum(0) - if len_a >= 3 and len_a == len_b: + if len_a>=3 and len_a==len_b: # bmm # a: [..., n, m], b: [..., m, k], c:[..., n, k] if jt.flags.use_cuda: @@ -128,67 +116,52 @@ def matmul(a, b): # cc:[..., n, m, k] # --> # 012 - if len_b == 2 and len_a > 2: + if len_b == 2 and len_a>2: # TODO:ugly implementation for tuner aa = a.reshape((-1, m)) cc = matmul(aa, b) # print(a.shape, b.shape, cc.shape) return cc.reshape(a.shape[:-1] + [k]) - for i in range(len_c - 2): - ai = len_a - (len_c - i) - bi = len_b - (len_c - i) - an = a.shape[ai] if ai >= 0 else 1 - bn = b.shape[bi] if bi >= 0 else 1 - if an != 1 and bn != 1: + for i in range(len_c-2): + ai = len_a-(len_c-i) + bi = len_b-(len_c-i) + an = a.shape[ai] if ai>=0 else 1 + bn = b.shape[bi] if bi>=0 else 1 + if an!=1 and bn!=1: assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{a.shape}" cn = max(an, bn) shape.append(cn) shape.extend([n, m, k]) a = a.broadcast(shape, [-1]) b = b.broadcast(shape, [-3]) - return (a * b).sum(-2) - - + return (a*b).sum(-2) jt.Var.matmul = jt.Var.__matmul__ = matmul -jt.Var.__imatmul__ = lambda a, b: a.assign(matmul(a, b)) - +jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b)) def get_init_var_rand(shape, dtype): return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32)) - -def relu(x): return jt.ternary((x > 0.0), x, jt.broadcast_var(0.0, x)) - - -def leaky_relu(x, scale=0.01): return jt.ternary(x > 0, x, x * scale) - - +def relu(x): return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x)) +def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale) def relu6(x): return jt.minimum(jt.maximum(x, 0.0), 6.0) - - -def elu(x, alpha=1.0): return jt.ternary(x > 0, x, alpha * (x.exp() - 1)) - - +def elu(x,alpha=1.0):return jt.ternary(x>0,x,alpha*(x.exp()-1)) def sign(x): one = jt.ones(x.shape) - x = jt.ternary(x > 0, one, x) - return jt.ternary(x < 0, -one, x) - + x = jt.ternary(x>0, one, x) + return jt.ternary(x<0, -one, x) def gelu(x): _sqrt2 = 1.4142135623730951 - erf = jt.erf(x / _sqrt2) + 1 - r = erf * x * .5 + erf = jt.erf(x/_sqrt2)+1 + r = erf*x*.5 return r - class ELU(Module): - def __init__(self, alpha=1.0): - self.alpha = alpha - - def execute(self, x): - return elu(x, self.alpha) - + def __init__(self,alpha=1.0): + self.alpha=alpha + + def execute(self,x): + return elu(x,self.alpha) class PReLU(Module): def __init__(self, num_parameters=1, init_=0.25): @@ -198,150 +171,135 @@ def __init__(self, num_parameters=1, init_=0.25): def execute(self, x): if self.num_parameters != 1: assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU" - return jt.maximum(0, x) + self.a.broadcast(x, [0, 2, 3]) * jt.minimum(0, x) + return jt.maximum(0, x) + self.a.broadcast(x, [0,2,3]) * jt.minimum(0, x) else: return jt.maximum(0, x) + 
self.a * jt.minimum(0, x) - -# TODO dims is 4 will cause slowly execution +#TODO dims is 4 will cause slowly execution def cross_entropy_loss(output, target, ignore_index=None): if len(output.shape) == 4: c_dim = output.shape[1] output = output.transpose((0, 2, 3, 1)) output = output.reshape((-1, c_dim)) if ignore_index is not None: - target = jt.ternary(target == ignore_index, - jt.array(-1).broadcast(target), target) + target = jt.ternary(target==ignore_index, + jt.array(-1).broadcast(target), target) mask = jt.logical_and(target >= 0, target < output.shape[1]) - target = target.reshape((-1,)) + target = target.reshape((-1, )) target = target.broadcast(output, [1]) target = target.index(1) == target - + output = output - output.max([1], keepdims=True) loss = output.exp().sum(1).log() - loss = loss - (output * target).sum(1) + loss = loss - (output*target).sum(1) if ignore_index is None: return loss.mean() else: return loss.sum() / jt.maximum(mask.int().sum(), 1) - def mse_loss(output, target): - return (output - target).sqr().mean() - + return (output-target).sqr().mean() def bce_loss(output, target, weight=None, size_average=True): loss = - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))) if weight is not None: loss *= weight - + if size_average: return loss.mean() else: return loss.sum() - def l1_loss(output, target): - return (output - target).abs().mean() + return (output-target).abs().mean() -def smooth_l1_loss(y_true, y_pred, reduction="mean"): +def smooth_l1_loss(y_true, y_pred,reduction="mean"): """Implements Smooth-L1 loss. y_true and y_pred are typically: [N, 4], but could be any shape. - Args: y_true - ground truth y_pred - predictions reduction - the mode of cal loss which must be in ['mean','sum','none'] """ diff = jt.abs(y_true - y_pred) - less_than_one = (diff < 1.0).float32() + less_than_one = (diff<1.0).float32() loss = (less_than_one * 0.5 * diff.sqr()) + (1 - less_than_one) * (diff - 0.5) - if reduction == "mean": + if reduction=="mean": return loss.mean() - elif reduction == "sum": + elif reduction=="sum": return loss.sum() - elif reduction == "none": + elif reduction=="none": return loss else: raise ValueError(f'not support {reduction}') - -def nll_loss(output, target, weight=None, ignore_index=-100, reduction='mean'): - assert output.ndim <= 2 and output.ndim > 0 and target.ndim == 1 +def nll_loss(output,target,weight=None,ignore_index=-100,reduction='mean'): + assert output.ndim<=2 and output.ndim>0 and target.ndim==1 n_classes = output.shape[-1] - assert weight is None or weight.numel() == n_classes - assert ignore_index < 0 or ignore_index < n_classes + assert weight is None or weight.numel()==n_classes + assert ignore_index<0 or ignore_index 0: - weight[ignore_index] = 0 - if output.ndim == 2: - index = jt.index((output.shape[0],), dim=0) - loss = -output[index, target] * weight[target] + if ignore_index>0: + weight[ignore_index]=0 + if output.ndim==2: + index = jt.index((output.shape[0],),dim=0) + loss = -output[index,target]*weight[target] else: - loss = -output[target[0]] * weight[target[0]] - if reduction == "mean": - total_weight = weight[target].sum() if output.ndim == 2 else weight[target[0]].sum() - return loss.sum() / total_weight - elif reduction == "sum": + loss = -output[target[0]]*weight[target[0]] + if reduction=="mean": + total_weight = weight[target].sum() if output.ndim==2 else weight[target[0]].sum() + return loss.sum()/total_weight + elif reduction=="sum": return loss.sum() - elif reduction == 
"none": + elif reduction=="none": return loss else: raise ValueError(f'not support {reduction}') - - + class CrossEntropyLoss(Module): - def __init__(self, ignore_index=None): + def __init__(self,ignore_index=None): self.ignore_index = ignore_index - + def execute(self, output, target): - return cross_entropy_loss(output, target, self.ignore_index) - + return cross_entropy_loss(output, target,self.ignore_index) class MSELoss(Module): def __init__(self): pass - def execute(self, output, target): return mse_loss(output, target) - class BCELoss(Module): def __init__(self, weight=None, size_average=True): self.weight = weight self.size_average = size_average - def execute(self, output, target): return bce_loss(output, target, self.weight, self.size_average) - class L1Loss(Module): def __init__(self): pass - def execute(self, output, target): return l1_loss(output, target) - def binary_cross_entropy_with_logits(output, target, weight=None, pos_weight=None, size_average=True): - max_val = jt.clamp(-output, min_v=0) + max_val = jt.clamp(-output,min_v=0) if pos_weight is not None: - log_weight = (pos_weight - 1) * target + 1 - loss = (1 - target) * output + (log_weight * (((-max_val).exp() + (-output - max_val).exp()).log() + max_val)) + log_weight = (pos_weight-1)*target + 1 + loss = (1-target)*output+(log_weight*(((-max_val).exp()+(-output - max_val).exp()).log()+max_val)) else: - loss = (1 - target) * output + max_val + ((-max_val).exp() + (-output - max_val).exp()).log() + loss = (1-target)*output+max_val+((-max_val).exp()+(-output -max_val).exp()).log() if weight is not None: - loss *= weight + loss *=weight if size_average: return loss.mean() else: return loss.sum() - class BCEWithLogitsLoss(Module): def __init__(self, weight=None, pos_weight=None, size_average=True): self.pos_weight = pos_weight @@ -349,24 +307,21 @@ def __init__(self, weight=None, pos_weight=None, size_average=True): self.size_average = size_average def execute(self, output, target): - return binary_cross_entropy_with_logits(output, target, self.weight, self.pos_weight, self.size_average) + return binary_cross_entropy_with_logits(output,target,self.weight,self.pos_weight,self.size_average) - -def softmax(x, dim=None): +def softmax(x, dim = None): if dim is None: x = (x - x.max()).exp() ret = x / x.sum() else: - x = (x - x.max(dim, keepdims=True)).exp() + x = (x-x.max(dim, keepdims=True)).exp() ret = x / x.sum(dim, keepdims=True) return ret - -def log_softmax(x, dim=None): - x = softmax(x, dim=dim) +def log_softmax(x,dim=None): + x = softmax(x,dim=dim) return jt.log(x) - def log_sigmoid(x): return jt.log(jt.sigmoid(x)) @@ -378,14 +333,12 @@ def __init__(self, *args, **kwargs): def execute(self, input): return input - class Dropout(Module): def __init__(self, p=0.5, is_train=False): assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p) self.p = p self.is_train = is_train - # TODO: test model.train() to change self.is_train - + #TODO: test model.train() to change self.is_train def execute(self, input): output = input if self.p > 0 and self.is_train: @@ -395,21 +348,19 @@ def execute(self, input): else: noise = jt.random(input.shape) noise = (noise > self.p).int() - output = output * noise / (1.0 - self.p) # div keep prob + output = output * noise / (1.0 - self.p) # div keep prob return output - -def dropout(x, p=0.5, is_train=False): - return Dropout(p=p, is_train=is_train)(x) - +def dropout(x,p=0.5,is_train=False): + return Dropout(p=p,is_train=is_train)(x) class Linear(Module): def 
__init__(self, in_features, out_features, bias=True): self.in_features = in_features self.out_features = out_features self.weight = init.invariant_uniform((out_features, in_features), "float32") - bound = 1.0 / math.sqrt(in_features) - self.bias = init.uniform((out_features,), "float32", -bound, bound) if bias else None + bound = 1.0/math.sqrt(in_features) + self.bias = init.uniform((out_features,), "float32",-bound,bound) if bias else None def execute(self, x): x = matmul_transpose(x, self.weight) @@ -417,7 +368,6 @@ def execute(self, x): return x + self.bias return x - class BatchNorm(Module): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True): self.sync = sync @@ -432,34 +382,32 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=T self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad() def execute(self, x): - dims = [0] + list(range(2, x.ndim)) + dims = [0]+list(range(2,x.ndim)) if self.is_train: xmean = jt.mean(x, dims=dims) - x2mean = jt.mean(x * x, dims=dims) + x2mean = jt.mean(x*x, dims=dims) if self.sync and jt.in_mpi: xmean = xmean.mpi_all_reduce("mean") x2mean = x2mean.mpi_all_reduce("mean") - xvar = (x2mean - xmean * xmean).maximum(0.0) - w = self.weight / jt.sqrt(xvar + self.eps) + xvar = (x2mean-xmean*xmean).maximum(0.0) + w = self.weight / jt.sqrt(xvar+self.eps) b = self.bias - xmean * w norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims) self.running_mean.update(self.running_mean + - (xmean.reshape((-1,)) - self.running_mean) * self.momentum) + (xmean.reshape((-1,)) - self.running_mean) * self.momentum) self.running_var.update(self.running_var + - (xvar.reshape((-1,)) - self.running_var) * self.momentum) + (xvar.reshape((-1,))-self.running_var)*self.momentum) return norm_x else: - w = self.weight / jt.sqrt(self.running_var + self.eps) + w = self.weight / jt.sqrt(self.running_var+self.eps) b = self.bias - self.running_mean * w norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims) return norm_x - BatchNorm2d = BatchNorm1d = BatchNorm - class InstanceNorm(Module): def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, is_train=True, sync=True): self.sync = sync @@ -473,19 +421,17 @@ def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, is_train= self.bias = init.constant((num_features,), "float32", 0.0) if affine else 0.0 def execute(self, x): - dims = list(range(2, x.ndim)) + dims = list(range(2,x.ndim)) xmean = jt.mean(x, dims=dims) - x2mean = jt.mean(x * x, dims=dims) + x2mean = jt.mean(x*x, dims=dims) - xvar = (x2mean - xmean * xmean).maximum(0.0) - w = self.weight / jt.sqrt(xvar + self.eps) + xvar = (x2mean-xmean*xmean).maximum(0.0) + w = self.weight / jt.sqrt(xvar+self.eps) b = self.bias - xmean * w return x * w.broadcast(x, dims) + b.broadcast(x, dims) - InstanceNorm2d = InstanceNorm1d = InstanceNorm - class LayerNorm(Module): def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True) -> None: if isinstance(normalized_shape, int): @@ -499,17 +445,16 @@ def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool def execute(self, x): dims = [-i for i in range(len(self.normalized_shape), 0, -1)] xmean = jt.mean(x, dims=dims, keepdims=1) - x2mean = jt.mean(x * x, dims=dims, keepdims=1) + x2mean = jt.mean(x*x, dims=dims, keepdims=1) - xvar = (x2mean - xmean * xmean).maximum(0.0) - w = self.weight / jt.sqrt(xvar + self.eps) + xvar = (x2mean-xmean*xmean).maximum(0.0) + w = self.weight / 
jt.sqrt(xvar+self.eps) b = self.bias - xmean * w return x * w + b LayerNorm2d = LayerNorm1d = LayerNorm - class GroupNorm(Module): def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, is_train=True): self.num_groups = num_groups @@ -523,15 +468,15 @@ def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, is_train=Tr def execute(self, x): N = x.shape[0] C = self.num_channels - output_shape = (N, -1) - # TODO: 3d group norm - if x.ndim == 4: + output_shape = (N,-1) + # TODO: 3d group norm + if x.ndim==4: output_shape = x.shape assert C % self.num_groups == 0 - x = x.reshape((N, self.num_groups, C // self.num_groups, -1)) - xmean = jt.mean(x, dims=[2, 3]).reshape((N, self.num_groups, 1)) - x2mean = jt.mean(x * x, dims=[2, 3]).reshape((N, self.num_groups, 1)) - xvar = (x2mean - xmean * xmean).maximum(0.0) + x = x.reshape((N, self.num_groups, C//self.num_groups, -1)) + xmean = jt.mean(x, dims=[2,3]).reshape((N, self.num_groups, 1)) + x2mean = jt.mean(x*x, dims=[2,3]).reshape((N, self.num_groups, 1)) + xvar = (x2mean-xmean*xmean).maximum(0.0) if self.affine: w = self.weight.reshape((1, self.num_groups, -1)) @@ -539,12 +484,11 @@ def execute(self, x): else: w = 1 b = 0 - w = w / jt.sqrt(xvar + self.eps) + w = w / jt.sqrt(xvar+self.eps) b = b - xmean * w x = x * w.broadcast(x, [3]) + b.broadcast(x, [3]) return x.reshape(output_shape) - Relu = jt.make_module(relu) ReLU = Relu Leaky_relu = jt.make_module(leaky_relu, 2) @@ -555,7 +499,6 @@ def execute(self, x): from jittor.depthwise_conv import DepthwiseConv - class Conv(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): self.in_channels = in_channels @@ -576,9 +519,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, assert out_channels % groups == 0, 'out_channels must be divisible by groups' # self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out") - self.weight = init.invariant_uniform([out_channels, in_channels // groups, Kh, Kw], dtype="float") + self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float") if bias: - fan = 1 + fan=1 for i in self.weight.shape[1:]: fan *= i bound = 1 / math.sqrt(fan) @@ -590,11 +533,11 @@ def execute(self, x): if self.is_depthwise_conv and jt.flags.use_cuda: y = self.depthwise_conv(x, self.weight) if self.bias is not None: - b = self.bias.broadcast(y.shape, [0, 2, 3]) + b = self.bias.broadcast(y.shape, [0,2,3]) y = y + b return y elif self.groups == 1: - N, C, H, W = x.shape + N,C,H,W = x.shape Kh, Kw = self.kernel_size assert C==self.in_channels oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 @@ -606,19 +549,19 @@ def execute(self, x): f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid ]) - ww = self.weight.broadcast(xx.shape, [0, 3, 4]) - yy = xx * ww - y = yy.sum([2, 5, 6]) # Kc, Kh, Kw + ww = self.weight.broadcast(xx.shape, [0,3,4]) + yy = xx*ww + y = yy.sum([2,5,6]) # Kc, Kh, Kw if self.bias is not None: - b = self.bias.broadcast(y.shape, [0, 2, 3]) + b = self.bias.broadcast(y.shape, [0,2,3]) y = y + b return y else: - N, C, H, W = x.shape + N,C,H,W = x.shape Kh, Kw = self.kernel_size G = self.groups - CpG = C // G # channels per group - assert C == self.in_channels + CpG = C // G # channels per group + assert C==self.in_channels oc = 
self.out_channels oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 @@ -628,32 +571,29 @@ def execute(self, x): f'i1*{CpG}+i3', # Gid f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid - ]) # w: [oc, CpG, Kh, Kw] - ww = self.weight.reindex([N, G, oc // G, CpG, oh, ow, Kh, Kw], [ - f'i1*{oc // G}+i2', + ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [ + f'i1*{oc//G}+i2', 'i3', 'i6', 'i7' ]) - ww.compile_options = xx.compile_options = {"G": G, "C": C} - yy = xx * ww + ww.compile_options = xx.compile_options = {"G":G,"C":C} + yy = xx*ww y = yy.reindex_reduce('add', [N, oc, oh, ow], [ 'i0', - f'i1*{oc // G}+i2', + f'i1*{oc//G}+i2', 'i4', 'i5' ]) if self.bias is not None: - b = self.bias.broadcast(y.shape, [0, 2, 3]) + b = self.bias.broadcast(y.shape, [0,2,3]) y = y + b - return y - + return y Conv2d = Conv - class Conv1d(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): self.in_channels = in_channels @@ -666,12 +606,11 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.bias = bias assert in_channels % groups == 0, 'in_channels must be divisible by groups' assert out_channels % groups == 0, 'out_channels must be divisible by groups' - self.conv = Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, - self.dilation, self.groups, self.bias) + self.conv = Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias) def execute(self, x): - N, C, D = x.shape - assert C == self.in_channels + N,C,D = x.shape + assert C==self.in_channels x = x.unsqueeze(-1) x = self.conv(x) y = x.squeeze(-1) @@ -685,57 +624,56 @@ def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): out_channels = weight.shape[0] if groups == 1: - N, C, H, W = x.shape + N,C,H,W = x.shape Kh, Kw = weight.shape[-2:] - oh = (H + padding[0] * 2 - Kh * dilation[0] + dilation[0] - 1) // stride[0] + 1 - ow = (W + padding[1] * 2 - Kw * dilation[1] + dilation[1] - 1) // stride[1] + 1 - xx = x.reindex([N, out_channels, C, oh, ow, Kh, Kw], [ - 'i0', # Nid - 'i2', # Cid - f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid - f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid - ]) - ww = weight.broadcast(xx.shape, [0, 3, 4]) - yy = xx * ww - y = yy.sum([2, 5, 6]) # Kc, Kh, Kw + oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1 + ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1 + xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [ + 'i0', # Nid + 'i2', # Cid + f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid + f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid + ]) + ww = weight.broadcast(xx.shape, [0,3,4]) + yy = xx*ww + y = yy.sum([2,5,6]) # Kc, Kh, Kw if bias is not None: - b = bias.broadcast(y.shape, [0, 2, 3]) + b = bias.broadcast(y.shape, [0,2,3]) y = y + b return y else: - N, C, H, W = x.shape + N,C,H,W = x.shape Kh, Kw = weight.shape[-2:] G = groups - CpG = C // G # channels per group + CpG = C // G # channels per group oc = out_channels - oh = (H + padding[0] * 2 - Kh * dilation[0] + dilation[0] - 1) // stride[0] + 1 - ow = (W + padding[1] * 2 - Kw * dilation[1] + dilation[1] - 1) // stride[1] + 1 - xx = x.reindex([N, G, oc 
// G, CpG, oh, ow, Kh, Kw], [ - 'i0', # Nid - f'i1*{CpG}+i3', # Gid - f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid - f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid - ]) - xx.compile_options = {"G": G} + oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1 + ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1 + xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [ + 'i0', # Nid + f'i1*{CpG}+i3', # Gid + f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid + f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid + ]) + xx.compile_options = {"G":G} # w: [oc, CpG, Kh, Kw] - ww = weight.reindex([N, G, oc // G, CpG, oh, ow, Kh, Kw], [ - f'i1*{oc // G}+i2', - 'i3', - 'i6', - 'i7' - ]) - yy = xx * ww + ww = weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [ + f'i1*{oc//G}+i2', + 'i3', + 'i6', + 'i7' + ]) + yy = xx*ww y = yy.reindex_reduce('add', [N, oc, oh, ow], [ - 'i0', - f'i1*{oc // G}+i2', - 'i4', - 'i5' - ]) + 'i0', + f'i1*{oc//G}+i2', + 'i4', + 'i5' + ]) if bias is not None: - b = bias.broadcast(y.shape, [0, 2, 3]) + b = bias.broadcast(y.shape, [0,2,3]) y = y + b - return y - + return y class ConvTranspose(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ @@ -746,7 +684,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ # added self.dilation = dilation self.group = groups - assert groups == 1, "Group conv not supported yet." + assert groups==1, "Group conv not supported yet." self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) self.stride = stride if isinstance(stride, tuple) else (stride, stride) @@ -754,15 +692,15 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ # added self.padding = padding if isinstance(padding, tuple) else (padding, padding) self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], - self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) - self.output_padding = output_padding if isinstance(output_padding, tuple) else (output_padding, output_padding) + self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) + self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \ - self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ + self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ "output padding must be smaller than max(stride, dilation)" self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float") if bias: - fan = 1 + fan=1 for i in self.weight.shape[1:]: fan *= i bound = 1 / math.sqrt(fan) @@ -771,96 +709,92 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ self.bias = None def execute(self, x): - N, C, H, W = x.shape - i, o, h, w = self.weight.shape - assert C == i + N,C,H,W = x.shape + i,o,h,w = self.weight.shape + assert C==i stride_h, stride_w = self.stride padding_h, padding_w = self.padding dilation_h, dilation_w = self.dilation - h_out = (H - 1) * stride_h + self.output_padding[0] - 2 * padding_h + 1 + (h - 1) * dilation_h - w_out = (W - 1) * stride_w + self.output_padding[1] - 2 * padding_w + 1 + (w - 1) * dilation_w + h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h + w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w out_shape = (N, o, h_out, w_out) 
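        # Editor's note (hedged): h_out/w_out above follow the usual transposed
        # convolution size relation, the inverse of the Conv output formula:
        #     H_out = (H - 1)*stride_h - 2*padding_h + dilation_h*(h - 1)
        #             + output_padding_h + 1
        # The reindex_reduce below scatter-adds every input position into the
        # h x w kernel window it would have been gathered from in a forward conv.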
shape = (N, i, o, H, W, h, w) - xx = x.broadcast(shape, (2, 5, 6)) # i,h,w - ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W - y = (ww * xx).reindex_reduce("add", out_shape, [ - 'i0', # N - 'i2', # o - f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid - f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid + xx = x.broadcast(shape, (2, 5, 6)) # i,h,w + ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W + y = (ww*xx).reindex_reduce("add", out_shape, [ + 'i0', # N + 'i2', # o + f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid + f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid ]) if self.bias is not None: - b = self.bias.broadcast(y.shape, [0, 2, 3]) + b = self.bias.broadcast(y.shape, [0,2,3]) y = y + b return y - def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): x = input - N, C, H, W = x.shape - i, o, h, w = weight.shape - assert C == i - assert groups == 1, "Group conv not supported yet." + N,C,H,W = x.shape + i,o,h,w = weight.shape + assert C==i + assert groups==1, "Group conv not supported yet." stride = stride if isinstance(stride, tuple) else (stride, stride) dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) # added padding = padding if isinstance(padding, tuple) else (padding, padding) - output_padding = output_padding if isinstance(output_padding, tuple) else (output_padding, output_padding) + output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) assert output_padding[0] < max(stride[0], dilation[0]) and \ - output_padding[1] < max(stride[1], dilation[1]), \ + output_padding[1] < max(stride[1], dilation[1]), \ "output padding must be smaller than max(stride, dilation)" stride_h, stride_w = stride padding_h, padding_w = padding dilation_h, dilation_w = dilation - h_out = (H - 1) * stride_h + output_padding[0] - 2 * padding_h + 1 + (h - 1) * dilation_h - w_out = (W - 1) * stride_w + output_padding[1] - 2 * padding_w + 1 + (w - 1) * dilation_w + h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h + w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w out_shape = (N, o, h_out, w_out) shape = (N, i, o, H, W, h, w) - xx = x.broadcast(shape, (2, 5, 6)) # i,h,w - ww = weight.broadcast(shape, (0, 3, 4)) # N,H,W - y = (ww * xx).reindex_reduce("add", out_shape, [ - 'i0', # N - 'i2', # o - f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid - f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid + xx = x.broadcast(shape, (2, 5, 6)) # i,h,w + ww = weight.broadcast(shape, (0, 3, 4)) # N,H,W + y = (ww*xx).reindex_reduce("add", out_shape, [ + 'i0', # N + 'i2', # o + f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid + f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid ]) if isinstance(bias, jt.Var): - b = bias.broadcast(y.shape, [0, 2, 3]) + b = bias.broadcast(y.shape, [0,2,3]) y = y + b else: assert not bias, "Bias should be none or jittor var" return y - conv_transpose2d = conv_transpose - -def pad(x, padding, mode='constant', value=0): - assert mode in ['constant', 'replicate', 'reflect', - 'circular'], 'only support constant,replicate,reflect,circular pad' - assert len(padding) % 2 == 0 and len(padding) // 2 <= x.ndim +def pad(x,padding, mode='constant', value=0): + assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad' + assert len(padding)%2==0 and len(padding)//2<=x.ndim padding = 
list(padding) - left = [0] * (x.ndim - len(padding) // 2) + padding[::2][::-1] - right = [0] * (x.ndim - len(padding) // 2) + padding[1::2][::-1] + left = [0]*(x.ndim-len(padding)//2)+padding[::2][::-1] + right = [0]*(x.ndim-len(padding)//2)+padding[1::2][::-1] out_dims = [] out_shape = [] - for i, n, l, r in zip(range(x.ndim), x.shape, left, right): - out_shape.append(n + l + r) + for i,n,l,r in zip(range(x.ndim),x.shape,left,right): + out_shape.append(n+l+r) if mode == 'constant': out_dims.append(f'i{i}-{l}') elif mode == 'replicate': - out_dims.append(f"i{i}<{l} ? 0 : i{i} > {n + l - 1} ? {n - 1} : i{i}-{l}") + out_dims.append(f"i{i}<{l} ? 0 : i{i} > {n+l-1} ? {n-1} : i{i}-{l}") elif mode == 'reflect': - out_dims.append(f"i{i}<{l} ? {l}-i{i} : i{i} > {n + l - 1} ? {2 * (n - 1) + l}-i{i} : i{i}-{l}") + out_dims.append(f"i{i}<{l} ? {l}-i{i} : i{i} > {n+l-1} ? {2*(n-1)+l}-i{i} : i{i}-{l}") elif mode == 'circular': - out_dims.append(f"i{i}<{l} ? {n - l}+i{i} : i{i} > {n + l - 1} ? i{i}-{n + l} : i{i}-{l}") + out_dims.append(f"i{i}<{l} ? {n-l}+i{i} : i{i} > {n+l-1} ? i{i}-{n+l} : i{i}-{l}") - return x.reindex(out_shape, out_dims, overflow_value=value) + return x.reindex(out_shape,out_dims,overflow_value=value) class ReflectionPad2d(Module): @@ -877,20 +811,19 @@ def __init__(self, padding): raise TypeError(f"ReflectionPad2d padding just support int or tuple, but found {type(padding)}") def execute(self, x): - n, c, h, w = x.shape + n,c,h,w = x.shape assert (self.pl < w and self.pr < w), f"padding_left and padding_right should be smaller than input width" assert (self.pt < h and self.pb < h), f"padding_top and padding_bottom should be smaller than input height" - oh = h + self.pt + self.pb - ow = w + self.pl + self.pr + oh=h+self.pt+self.pb + ow=w+self.pl+self.pr l = self.pl r = self.pl + w - 1 t = self.pt b = self.pt + h - 1 - return x.reindex([n, c, oh, ow], ["i0", "i1", - f"i2<{t} ? {t}-i2 : i2 > {b} ? {h - 1 + b}-i2 : i2-{t}", - f"i3<{l} ? {l}-i3 : i3 > {r} ? {w - 1 + r}-i3 : i3-{l}", - ]) - + return x.reindex([n,c,oh,ow], ["i0","i1", + f"i2<{t} ? {t}-i2 : i2 > {b} ? {h-1+b}-i2 : i2-{t}", + f"i3<{l} ? {l}-i3 : i3 > {r} ? 
{w-1+r}-i3 : i3-{l}", + ]) class ZeroPad2d(Module): def __init__(self, padding): @@ -906,10 +839,8 @@ def __init__(self, padding): raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}") def execute(self, x): - n, c, h, w = x.shape - return x.reindex([n, c, h + self.pt + self.pb, w + self.pl + self.pr], - ["i0", "i1", f"i2-{self.pt}", f"i3-{self.pl}"]) - + n,c,h,w = x.shape + return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"]) class ConstantPad2d(Module): def __init__(self, padding, value): @@ -928,15 +859,14 @@ def __init__(self, padding, value): def execute(self, x): assert len(x.shape) >= 2 shape = x.shape - tar_shape = shape[0:-2] + [shape[-2] + self.pt + self.pb, shape[-1] + self.pl + self.pr] + tar_shape = shape[0:-2] + [shape[-2]+self.pt+self.pb,shape[-1]+self.pl+self.pr] tar_dims = [] - for i in range(len(shape) - 2): + for i in range(len(shape)-2): tar_dims.append(f"i{i}") - tar_dims.append(f"i{i + 1}-{self.pt}") - tar_dims.append(f"i{i + 2}-{self.pl}") + tar_dims.append(f"i{i+1}-{self.pt}") + tar_dims.append(f"i{i+2}-{self.pl}") return x.reindex(tar_shape, tar_dims, overflow_value=self.value) - class ReplicationPad2d(Module): def __init__(self, padding): self.padding = padding @@ -951,82 +881,71 @@ def __init__(self, padding): raise TypeError(f"ReplicationPad2d padding just support int or tuple, but found {type(padding)}") def execute(self, x): - n, c, h, w = x.shape - oh = h + self.pt + self.pb - ow = w + self.pl + self.pr + n,c,h,w = x.shape + oh=h+self.pt+self.pb + ow=w+self.pl+self.pr l = self.pl r = self.pl + w - 1 t = self.pt b = self.pt + h - 1 - return x.reindex([n, c, oh, ow], ["i0", "i1", - f"i2<{t} ? 0 : i2 > {b} ? {h - 1} : i2-{t}", - f"i3<{l} ? 0 : i3 > {r} ? {w - 1} : i3-{l}" - ]) - + return x.reindex([n,c,oh,ow], ["i0","i1", + f"i2<{t} ? 0 : i2 > {b} ? {h-1} : i2-{t}", + f"i3<{l} ? 0 : i3 > {r} ? 
{w-1} : i3-{l}" + ]) class Embedding(Module): def __init__(self, num, dim): self.num = num self.dim = dim - self.weight = jt.init.gauss([num, dim], 'float32').stop_grad() + self.weight = jt.init.gauss([num,dim],'float32').stop_grad() def execute(self, x): - res = self.weight[x].reshape([x.shape[0], self.dim]) + res = self.weight[x].reshape([x.shape[0],self.dim]) return res - class PixelShuffle(Module): def __init__(self, upscale_factor): self.upscale_factor = upscale_factor def execute(self, x): - n, c, h, w = x.shape + n,c,h,w = x.shape r = self.upscale_factor - assert c % (r * r) == 0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle" - return x.reindex([n, int(c / r ** 2), h * r, w * r], [ + assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle" + return x.reindex([n,int(c/r**2),h*r,w*r], [ "i0", - f"i1*{r * r}+i2%{r}*{r}+i3%{r}", + f"i1*{r*r}+i2%{r}*{r}+i3%{r}", f"i2/{r}", f"i3/{r}" ]) - class Tanh(Module): def __init__(self): super().__init__() - - def execute(self, x): + def execute(self, x) : return x.tanh() - class Sigmoid(Module): def __init__(self): super().__init__() - - def execute(self, x): + def execute(self, x) : return x.sigmoid() - -def softplus(x, beta=1.0, threshold=20.0): +def softplus(x,beta=1.0,threshold=20.0): return 1 / beta * jt.log(1 + (beta * x).minimum(threshold).exp()) + \ - (x - threshold / beta).maximum(0.0) - + (x - threshold/beta).maximum(0.0) -def hardtanh(x, min_val=-1, max_val=1): - return jt.clamp(x, min_v=min_val, max_v=max_val) +def hardtanh(x,min_val=-1,max_val=1): + return jt.clamp(x,min_v=min_val,max_v=max_val) class Softplus(Module): r''' SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive. - Args: - + [in] beta (float): the beta value for the Softplus formulation. Default: 1. - [in] threshold (float): values above this revert to a linear function. Default: 20. 
''' - def __init__(self, beta=1, threshold=20): self.beta = beta self.threshold = threshold @@ -1034,14 +953,12 @@ def __init__(self, beta=1, threshold=20): def execute(self, x): return softplus(x, self.beta, self.threshold) - class Resize(Module): def __init__(self, size, mode="nearest", align_corners=False): super().__init__() self.size = size self.mode = mode self.align_corners = align_corners - def execute(self, x): return resize(x, self.size, self.mode, self.align_corners) @@ -1054,7 +971,7 @@ def _bicubic(x, a, func): return a*(jt.abs(x)**3)-5*a*(x**2)+8*a*(jt.abs(x))-4*a return 0 - + def _interpolate(img, x, y, ids, mode): if mode == "nearest": return img.reindex([*ids, x.floor(), y.floor()]) @@ -1179,188 +1096,182 @@ def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'): return _interpolate(input, x, y, (nid, cid), mode) -def linspace_from_neg_one(grid, num_steps, align_corners): - if num_steps <= 1: - return jt.array([], dtype=grid.dtype) +def linspace_from_neg_one(grid,num_steps,align_corners): + if num_steps <= 1: + return jt.array([],dtype=grid.dtype) # TODO: use jt.index - ra = np.linspace(-1, 1, num_steps) + ra = np.linspace(-1,1,num_steps) if not align_corners: - ra = ra * (num_steps - 1) / num_steps - return jt.array(ra, dtype=grid.dtype) - + ra = ra*(num_steps-1)/num_steps + return jt.array(ra,dtype=grid.dtype) -def make_base_grid_4D(theta, N, C, H, W, align_corners): +def make_base_grid_4D(theta,N,C,H,W,align_corners): base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype); - base_grid[..., 0] = linspace_from_neg_one(theta, W, align_corners) - base_grid[..., 1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners), -1) - base_grid[..., -1] = 1 + base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners) + base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1) + base_grid[...,-1] = 1 return base_grid - -def make_base_grid_5D(theta, N, C, D, H, W, align_corners): +def make_base_grid_5D(theta,N,C,D,H,W,align_corners): base_grid = jt.zeros((N, D, H, W, 4), dtype=theta.dtype) - base_grid[..., 0] = linspace_from_neg_one(theta, W, align_corners) - base_grid[..., 1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners), -1) - base_grid[..., 2] = jt.unsqueeze(jt.unsqueeze(linspace_from_neg_one(theta, D, align_corners), -1), -1) - base_grid[..., -1] = 1 + base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners) + base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1) + base_grid[...,2] = jt.unsqueeze(jt.unsqueeze(linspace_from_neg_one(theta, D, align_corners),-1),-1) + base_grid[...,-1] = 1 return base_grid +def affine_grid_generator_4D(theta,N,C,H,W,align_corners): + base_grid = make_base_grid_4D(theta, N, C, H, W, align_corners) + grid = jt.nn.bmm(base_grid.reshape(N, H * W, 3),theta.transpose(0,2,1)) + return grid.reshape(N, H, W, 2) -def affine_grid_generator_4D(theta, N, C, H, W, align_corners): - base_grid = make_base_grid_4D(theta, N, C, H, W, align_corners) - grid = jt.nn.bmm(base_grid.reshape(N, H * W, 3), theta.transpose(0, 2, 1)) - return grid.reshape(N, H, W, 2) - - -def affine_grid_generator_5D(theta, N, C, D, H, W, align_corners): +def affine_grid_generator_5D(theta,N,C,D,H,W,align_corners): base_grid = make_base_grid_5D(theta, N, C, D, H, W, align_corners) - grid = jt.nn.bmm(base_grid.reshape(N, D * H * W, 4), theta.transpose(0, 2, 1)) + grid = jt.nn.bmm(base_grid.reshape(N, D * H * W, 4),theta.transpose(0,2,1)) return grid.reshape(N, D, H, W, 3) - def 
affine_grid(theta, size, align_corners=False): - assert str(theta.dtype) in ['float', 'float32', 'float64'] - assert min(size) > 0 - assert len(size) in [4, 5] - if len(size) == 4: + assert str(theta.dtype) in ['float','float32','float64'] + assert min(size)>0 + assert len(size) in [4,5] + if len(size)== 4: assert theta.ndim == 3 and theta.shape[-2] == 2 and theta.shape[-1] == 3 return affine_grid_generator_4D(theta, size[0], size[1], size[2], size[3], align_corners) - elif len(size) == 5: + elif len(size)==5: assert theta.ndim == 3 and theta.shape[-2] == 3 and theta.shape[-1] == 4 return affine_grid_generator_5D(theta, size[0], size[1], size[2], size[3], size[4], align_corners) -def grid_sampler_unnormalize(coord, size, align_corners): +def grid_sampler_unnormalize(coord,size,align_corners): if align_corners: - # unnormalize coord from [-1, 1] to [0, size - 1] + #unnormalize coord from [-1, 1] to [0, size - 1] return ((coord + 1) / 2) * (size - 1) else: - # unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + #unnormalize coord from [-1, 1] to [-0.5, size - 0.5] return ((coord + 1) * size - 1) / 2 -def clip_coordinates(x, clip_limit): - return jt.clamp(x, min_v=0, max_v=clip_limit - 1) +def clip_coordinates(x,clip_limit): + return jt.clamp(x,min_v=0,max_v=clip_limit-1) - -def reflect_coordinates(x, twice_low, twice_high): +def reflect_coordinates(x,twice_low,twice_high): if twice_low == twice_high: return jt.zeros_like(x) m = twice_low / 2 span = (twice_high - twice_low) / 2 x = (x - m).abs() - # `fmod` returns same sign as `in`, which is positive after the `fabs` above. + #`fmod` returns same sign as `in`, which is positive after the `fabs` above. extra = x.mod(span) flips = (x / span).floor() - result1 = extra + m - result2 = span - extra + m - con = flips % 2 == 0 - not_con = flips % 2 != 0 - result1[not_con] = 0.0 - result2[con] = 0.0 - return result1 + result2 + result1 = extra+m + result2 = span-extra+m + con = flips%2==0 + not_con = flips%2!=0 + result1[not_con]=0.0 + result2[con]=0.0 + return result1+result2 -def grid_sampler_compute_source_index(coord, size, padding_mode, align_corners): +def grid_sampler_compute_source_index(coord,size,padding_mode,align_corners): coord = grid_sampler_unnormalize(coord, size, align_corners) if padding_mode == 'border': - # clip coordinates to image borders + #clip coordinates to image borders coord = clip_coordinates(coord, size) elif padding_mode == 'reflection': - # reflect coordinates by image borders + #reflect coordinates by image borders if align_corners: - coord = reflect_coordinates(coord, 0, 2 * (size - 1)) + coord = reflect_coordinates(coord, 0, 2*(size - 1)) else: - coord = reflect_coordinates(coord, -1, 2 * size - 1) - # clip coordinates to image borders + coord = reflect_coordinates(coord, -1, 2*size - 1) + #clip coordinates to image borders coord = clip_coordinates(coord, size) return coord -def grid_sampler_3d(X, grid, mode, padding_mode, align_corners): + +def grid_sampler_3d(X,grid,mode,padding_mode,align_corners): N = X.shape[0] C = X.shape[1] inp_D = X.shape[2] inp_H = X.shape[3] inp_W = X.shape[4] - D = grid.shape[1] + D = grid.shape[1] H = grid.shape[2] W = grid.shape[3] - x = grid[:, :, :, :, 0] - y = grid[:, :, :, :, 1] - z = grid[:, :, :, :, 2] - shape = [N, C, D, H, W] + x = grid[:,:,:,:,0] + y = grid[:,:,:,:,1] + z = grid[:,:,:,:,2] + shape = [N,C,D,H,W] cid = jt.index(shape, dim=1) nid = jt.index(shape, dim=0) - x = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners) - y = 
grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners) - z = grid_sampler_compute_source_index(z, inp_D, padding_mode, align_corners) - xid = x.reindex(shape, ['i0', 'i2', 'i3', 'i4']) - yid = y.reindex(shape, ['i0', 'i2', 'i3', 'i4']) - zid = z.reindex(shape, ['i0', 'i2', 'i3', 'i4']) - - if mode == 'nearest': - return X.reindex([nid, cid, zid.round(), yid.round(), xid.round()]) - elif mode == 'bilinear': - fx, fy, fz = xid.floor(), yid.floor(), zid.floor() - cx, cy, cz = fx + 1, fy + 1, fz + 1 - dx, dy, dz = xid - fx, yid - fy, zid - fz - dnx, dny, dnz = cx - xid, cy - yid, cz - zid - a = X.reindex([nid, cid, fz, fy, fx]) - b = X.reindex([nid, cid, cz, fy, fx]) - c = X.reindex([nid, cid, fz, cy, fx]) - d = X.reindex([nid, cid, fz, fy, cx]) - e = X.reindex([nid, cid, fz, cy, cx]) - f = X.reindex([nid, cid, cz, fy, cx]) - g = X.reindex([nid, cid, cz, cy, fx]) - h = X.reindex([nid, cid, cz, cy, cx]) - o = a * dnx * dny * dnz + b * dnx * dny * dz + c * dnx * dy * dnz + d * dx * dny * dnz + e * dx * dy * dnz + f * dx * dny * dz + g * dnx * dy * dz + h * dx * dy * dz + x = grid_sampler_compute_source_index(x,inp_W,padding_mode,align_corners) + y = grid_sampler_compute_source_index(y,inp_H,padding_mode,align_corners) + z = grid_sampler_compute_source_index(z,inp_D,padding_mode,align_corners) + xid = x.reindex(shape,['i0','i2','i3','i4']) + yid = y.reindex(shape,['i0','i2','i3','i4']) + zid = z.reindex(shape,['i0','i2','i3','i4']) + + if mode=='nearest': + return X.reindex([nid,cid,zid.round(),yid.round(),xid.round()]) + elif mode=='bilinear': + fx,fy,fz = xid.floor(),yid.floor(),zid.floor() + cx,cy,cz = fx+1,fy+1,fz+1 + dx,dy,dz = xid-fx,yid-fy,zid-fz + dnx,dny,dnz = cx-xid,cy-yid,cz-zid + a = X.reindex([nid,cid,fz,fy,fx]) + b = X.reindex([nid,cid,cz,fy,fx]) + c = X.reindex([nid,cid,fz,cy,fx]) + d = X.reindex([nid,cid,fz,fy,cx]) + e = X.reindex([nid,cid,fz,cy,cx]) + f = X.reindex([nid,cid,cz,fy,cx]) + g = X.reindex([nid,cid,cz,cy,fx]) + h = X.reindex([nid,cid,cz,cy,cx]) + o = a*dnx*dny*dnz+b*dnx*dny*dz+c*dnx*dy*dnz+d*dx*dny*dnz+e*dx*dy*dnz+f*dx*dny*dz+g*dnx*dy*dz+h*dx*dy*dz return o - -def grid_sampler_2d(X, grid, mode, padding_mode, align_corners): +def grid_sampler_2d(X,grid,mode,padding_mode,align_corners): N = X.shape[0] C = X.shape[1] inp_H = X.shape[2] inp_W = X.shape[3] - H = grid.shape[1] + H = grid.shape[1] W = grid.shape[2] - x = grid[:, :, :, 0] - y = grid[:, :, :, 1] - shape = [N, C, H, W] + x = grid[:,:,:,0] + y = grid[:,:,:,1] + shape = [N,C,H,W] cid = jt.index(shape, dim=1) nid = jt.index(shape, dim=0) - x = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners) - y = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners) - xid = x.reindex(shape, ['i0', 'i2', 'i3']) - yid = y.reindex(shape, ['i0', 'i2', 'i3']) - - if mode == 'nearest': - return X.reindex([nid, cid, yid.round(), xid.round()]) - elif mode == 'bilinear': - # xid,yid = (xid+0.00001),(yid+0.00001) - fx, fy = (xid).floor(), (yid).floor() - cx, cy = fx + 1, fy + 1 - dx, dy = xid - fx, yid - fy - dnx, dny = cx - xid, cy - yid - - a = X.reindex([nid, cid, fy, fx], overflow_value=0.0) - b = X.reindex([nid, cid, cy, fx], overflow_value=0.0) - c = X.reindex([nid, cid, fy, cx], overflow_value=0.0) - d = X.reindex([nid, cid, cy, cx], overflow_value=0.0) - o = a * dnx * dny + b * dnx * dy + c * dx * dny + d * dx * dy + x = grid_sampler_compute_source_index(x,inp_W,padding_mode,align_corners) + y = grid_sampler_compute_source_index(y,inp_H,padding_mode,align_corners) + xid = 
x.reindex(shape,['i0','i2','i3']) + yid = y.reindex(shape,['i0','i2','i3']) + + if mode=='nearest': + return X.reindex([nid,cid,yid.round(),xid.round()]) + elif mode=='bilinear': + #xid,yid = (xid+0.00001),(yid+0.00001) + fx,fy = (xid).floor(),(yid).floor() + cx,cy = fx+1,fy+1 + dx,dy = xid-fx,yid-fy + dnx,dny = cx-xid,cy-yid + + a = X.reindex([nid,cid,fy,fx],overflow_value=0.0) + b = X.reindex([nid,cid,cy,fx],overflow_value=0.0) + c = X.reindex([nid,cid,fy,cx],overflow_value=0.0) + d = X.reindex([nid,cid,cy,cx],overflow_value=0.0) + o = a*dnx*dny+b*dnx*dy+c*dx*dny+d*dx*dy return o def grid_sampler(X, grid, mode, padding_mode, align_corners): - assert X.dtype == grid.dtype - assert ((X.ndim == 4 or X.ndim == 5) and X.ndim == grid.ndim) - assert X.shape[0] == grid.shape[0] and grid.shape[-1] == X.ndim - 2 - assert X.numel() > 0 + assert X.dtype==grid.dtype + assert ((X.ndim==4 or X.ndim==5) and X.ndim==grid.ndim) + assert X.shape[0]==grid.shape[0] and grid.shape[-1]==X.ndim-2 + assert X.numel()>0 if X.ndim == 4: return grid_sampler_2d(X, grid, mode, padding_mode, align_corners) else: @@ -1368,8 +1279,8 @@ def grid_sampler(X, grid, mode, padding_mode, align_corners): def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False): - assert mode in ['bilinear', 'nearest'] - assert padding_mode in ['zeros', 'border', 'reflection'] + assert mode in ['bilinear','nearest'] + assert padding_mode in ['zeros','border','reflection'] return grid_sampler(input, grid, mode, padding_mode, align_corners) @@ -1377,14 +1288,13 @@ class Upsample(Module): def __init__(self, scale_factor=None, mode='nearest'): self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor) self.mode = mode - + def execute(self, x): return upsample(x, - size=( - int(x.shape[2] * self.scale_factor[0]), - int(x.shape[3] * self.scale_factor[1])), - mode=self.mode) - + size=( + int(x.shape[2]*self.scale_factor[0]), + int(x.shape[3]*self.scale_factor[1])), + mode=self.mode) class Sequential(Module): def __init__(self, *args): @@ -1393,56 +1303,47 @@ def __init__(self, *args): if isinstance(mod, collections.OrderedDict): for k, m in mod.items(): self.add_module(k, m) - elif isinstance(mod, list): + elif isinstance(mod,list): for m in mod: self.append(m) else: self.append(mod) - def __getitem__(self, idx): if idx not in self.layers: return list(self.layers.values())[idx] return self.layers[idx] - def __iter__(self): return self.layers.values().__iter__() - def keys(self): return self.layers.keys() - def values(self): return self.layers.values() - def items(self): return self.layers.items() - def execute(self, x): for k, layer in self.layers.items(): x = layer(x) return x - def dfs(self, parents, k, callback, callback_leave): n_children = len(self.layers) ret = callback(parents, k, self, n_children) if ret == False: return parents.append(self) - for k, v in self.layers.items(): + for k,v in self.layers.items(): v.dfs(parents, k, callback, callback_leave) parents.pop() if callback_leave: callback_leave(parents, k, self, n_children) - def append(self, mod): assert callable(mod), f"Module <{type(mod)}> is not callable" assert not isinstance(mod, type), f"Module is not a type" - self.layers[len(self.layers)] = mod - + self.layers[len(self.layers)]=mod def add_module(self, name, mod): assert callable(mod), f"Module <{type(mod)}> is not callable" assert not isinstance(mod, type), f"Module is not a type" - self.layers[name] = mod + self.layers[name]=mod def __len__(self): return 
len(self.layers) From e5d6262e062471c3fe12c23a4783f12b6d05941b Mon Sep 17 00:00:00 2001 From: Gword <471184555@qq.com> Date: Tue, 2 Mar 2021 16:37:21 +0800 Subject: [PATCH 08/36] fix nn --- python/jittor/nn.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 41ef869b..37e0f419 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -55,10 +55,14 @@ def bmm(a, b): shape of input a is [batch, n, m], shape of input b is [batch, m, k], return shape is [batch, n, k] + Example:: + import jittor as jt from jittor import nn + batch, n, m, k = 100, 5, 6, 7 + a = jt.random((batch, n, m)) b = jt.random((batch, m, k)) c = nn.bmm(a, b) @@ -68,27 +72,34 @@ def bmm(a, b): def matmul(a, b): ''' matrix multiply, + Example:: + a = jt.random([3]) b = jt.random([3]) c = jt.matmul(a, b) assert c.shape == [1] + a = jt.random([3, 4]) b = jt.random([4]) c = jt.matmul(a, b) assert c.shape == [3] + a = jt.random([10, 3, 4]) b = jt.random([4]) c = jt.matmul(a, b) assert c.shape == [10, 3] + a = jt.random([10, 3, 4]) b = jt.random([4, 5]) c = jt.matmul(a, b) assert c.shape == [10, 3, 5] + a = jt.random([10, 3, 4]) b = jt.random([10, 4, 5]) c = jt.matmul(a, b) assert c.shape == [10, 3, 5] + a = jt.random([8, 1, 3, 4]) b = jt.random([10, 4, 5]) c = jt.matmul(a, b) @@ -218,6 +229,7 @@ def l1_loss(output, target): def smooth_l1_loss(y_true, y_pred,reduction="mean"): """Implements Smooth-L1 loss. y_true and y_pred are typically: [N, 4], but could be any shape. + Args: y_true - ground truth y_pred - predictions @@ -941,9 +953,11 @@ def hardtanh(x,min_val=-1,max_val=1): class Softplus(Module): r''' SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive. + Args: [in] beta (float): the beta value for the Softplus formulation. Default: 1. + [in] threshold (float): values above this revert to a linear function. Default: 20. ''' def __init__(self, beta=1, threshold=20): @@ -971,7 +985,7 @@ def _bicubic(x, a, func): return a*(jt.abs(x)**3)-5*a*(x**2)+8*a*(jt.abs(x))-4*a return 0 - + def _interpolate(img, x, y, ids, mode): if mode == "nearest": return img.reindex([*ids, x.floor(), y.floor()]) From a15cb63b2d91d393af09a46e5d28b5bac5495c2b Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Thu, 18 Mar 2021 20:21:10 +0800 Subject: [PATCH 09/36] add sampler and subset. 
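A minimal usage sketch of the new samplers (assuming the 'from .sampler import *'
export added in this patch, and a Dataset subclass whose total_len is set via
set_attrs with batch_size=1, as in the test added below; the samplers only yield
indices and do not by themselves drive Dataset.__iter__):

    from jittor.dataset import Dataset, SequentialSampler, BatchSampler

    class ToyDataset(Dataset):
        def __init__(self):
            super().__init__()
            self.set_attrs(total_len=8, batch_size=1)

        def __getitem__(self, k):
            return k

    ds = ToyDataset()
    list(SequentialSampler(ds))
    # [0, 1, 2, 3, 4, 5, 6, 7]
    list(BatchSampler(SequentialSampler(ds), batch_size=3, drop_last=False))
    # [[0, 1, 2], [3, 4, 5], [6, 7]]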
--- python/jittor/dataset/__init__.py | 3 +- python/jittor/dataset/dataset.py | 10 ++++ python/jittor/dataset/sampler.py | 85 ++++++++++++++++++++++++++++++ python/jittor/test/test_sampler.py | 44 ++++++++++++++++ 4 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 python/jittor/dataset/sampler.py create mode 100644 python/jittor/test/test_sampler.py diff --git a/python/jittor/dataset/__init__.py b/python/jittor/dataset/__init__.py index 04aaa012..89373cc6 100644 --- a/python/jittor/dataset/__init__.py +++ b/python/jittor/dataset/__init__.py @@ -1,4 +1,5 @@ from .dataset import Dataset, ImageFolder from .mnist import MNIST -from .voc import VOC \ No newline at end of file +from .voc import VOC +from .sampler import * \ No newline at end of file diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index e77c980d..9cb1f306 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -80,6 +80,7 @@ def __init__(self, self.buffer_size = buffer_size self.stop_grad = stop_grad self.keep_numpy_array = keep_numpy_array + self.subset = None def __getitem__(self, index): raise NotImplementedError @@ -348,6 +349,8 @@ def __iter__(self): self.total_len = len(self) if self.shuffle == False: index_list = get_order_list(self.total_len) + elif self.subset is not None: + index_list = [x for x in range(self.subset[0], self.subset[1])] else: index_list = get_random_list(self.total_len) @@ -465,6 +468,13 @@ def __iter__(self): batch_data = self.to_jittor(batch_data) yield batch_data + def set_subset(self, start, end): + if start < 0 or end > self.total_len: + return + if start > end: + return + self.subset = (start, end) + self.total_len = end - start class ImageFolder(Dataset): """ diff --git a/python/jittor/dataset/sampler.py b/python/jittor/dataset/sampler.py new file mode 100644 index 00000000..510c7b05 --- /dev/null +++ b/python/jittor/dataset/sampler.py @@ -0,0 +1,85 @@ +import jittor as jt +from .dataset import Dataset +import numpy as np +from PIL import Image + +class Sampler(): + def __init__(self,data_source): + self.data_source = data_source + + def __iter__(self): + pass + + def __len__(self): + pass + +class SequentialSampler(Sampler): + def __init__(self,data_source): + self.data_source = data_source + + def __iter__(self): + return iter(range(len(self.data_source))) + + def __len__(self): + return len(self.data_source) + + +class RandomSampler(Sampler): + def __init__(self,data_source,replacement=False, num_samples=None): + self.data_source = data_source + self.rep = replacement + self._num_samples = num_samples + + @property + def num_samples(self): + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __len__(self): + return self.num_samples + + def __iter__(self): + n = len(self.data_source) + if self.rep: + return iter(np.random.randint(low=0, high=n, size=(self.num_samples,), dtype=np.int64).tolist()) + return iter(jt.randperm(n).numpy().tolist()) + + +class SubsetRandomSampler(Sampler): + def __init__(self,indices): + self.indices = indices + + def __iter__(self): + return (self.indices[jt.to_int(i)] for i in jt.randperm((len(self.indices)))) + + def __len__(self): + return len(self.indices) + + +class BatchSampler(Sampler): + def __init__(self,sampler,batch_size,drop_last): + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self): + batch = [] + for idx in self.sampler: + batch.append(idx) + if len(batch) == 
self.batch_size: + yield batch + batch = [] + if len(batch) > 0 and not self.drop_last: + yield batch + + def __len__(self): + if self.drop_last: + return len(self.sampler) // self.batch_size + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size + + + + + \ No newline at end of file diff --git a/python/jittor/test/test_sampler.py b/python/jittor/test/test_sampler.py new file mode 100644 index 00000000..2e7826cb --- /dev/null +++ b/python/jittor/test/test_sampler.py @@ -0,0 +1,44 @@ +import jittor as jt +from jittor.dataset import * +from PIL import Image +import numpy as np +import unittest + +test_img = np.random.normal(size=(40,3,100,100)) +class TestSamplerDataset(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=40,batch_size=1) + + def __getitem__(self,idx): + return test_img[idx:(idx+1),...] + +testdataset = TestSamplerDataset() + +class TestSampler(unittest.TestCase): + def test_sequential_sampler(self): + seqsampler = SequentialSampler(testdataset) + assert len(seqsampler) == 40 + for idx,batch in enumerate(seqsampler): + assert idx == batch + + def test_random_sampler(self): + randomsampler = RandomSampler(testdataset) + assert len(randomsampler) == 40 + + def test_subset_random_sampler(self): + testdataset.set_subset(20,30) + subsetsampler = SubsetRandomSampler(testdataset) + assert len(subsetsampler) == 10 + + def test_batch_sampler(self): + testdataset.subset = None + seqforbatch = SequentialSampler(testdataset) + batchsampler = BatchSampler(seqforbatch,4,drop_last=False) + assert len(batchsampler) == 10 + for batch in batchsampler: + assert len(batch) == 4 + + +if __name__ == "__main__": + unittest.main() From 5c6c160aeccb00c27a3cc8ec056a64af0439024e Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Fri, 19 Mar 2021 10:10:40 +0800 Subject: [PATCH 10/36] fix subset,change jittor op to np op. --- python/jittor/dataset/dataset.py | 630 ++++++++++++++++++++++++++--- python/jittor/dataset/sampler.py | 53 ++- python/jittor/test/test_sampler.py | 26 +- 3 files changed, 618 insertions(+), 91 deletions(-) diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index 9cb1f306..52e30ca8 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -1,9 +1,9 @@ # *************************************************************** -# Copyright (c) 2021 Jittor. All Rights Reserved. -# Maintainers: +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: # Meng-Hao Guo -# Dun Liang . -# +# Dun Liang . +# # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** @@ -23,20 +23,22 @@ import time dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset") -mp_log_v = os.environ.get("mp_log_v", 0) +mp_log_v = os.environ.get("mp_log_v", 0) mpi = jt.mpi img_open_hook = HookTimer(Image, "open") + class Worker: def __init__(self, target, args, buffer_size, keep_numpy_array=False): self.buffer = jt.RingBuffer(buffer_size) self.buffer.keep_numpy_array(keep_numpy_array) self.status = mp.Array('f', 5, lock=False) - self.p = mp.Process(target=target, args=args+(self.buffer,self.status)) + self.p = mp.Process(target=target, args=args + (self.buffer, self.status)) self.p.daemon = True self.p.start() + class Dataset(object): ''' Base class for reading data. 
@@ -48,7 +50,7 @@ class Dataset(object): [in] drop_last(bool): if true, the last batch of dataset might smaller than batch_size, default True. [in] num_workers(int): number of workers for loading data. [in] buffer_size(int): buffer size for each worker in bytes, default(512MB). - + Example:: class YourDataset(Dataset): @@ -63,14 +65,15 @@ def __getitem__(self, k): for x, y in dataset: ...... ''' + def __init__(self, - batch_size = 16, - shuffle = False, - drop_last = False, - num_workers = 0, - buffer_size = 512*1024*1024, - stop_grad = True, - keep_numpy_array = False): + batch_size=16, + shuffle=False, + drop_last=False, + num_workers=0, + buffer_size=512 * 1024 * 1024, + stop_grad=True, + keep_numpy_array=False): super().__init__() self.total_len = None self.batch_size = batch_size @@ -80,7 +83,6 @@ def __init__(self, self.buffer_size = buffer_size self.stop_grad = stop_grad self.keep_numpy_array = keep_numpy_array - self.subset = None def __getitem__(self, index): raise NotImplementedError @@ -90,15 +92,15 @@ def __batch_len__(self): assert self.batch_size > 0 if self.drop_last: return self.total_len // self.batch_size - return (self.total_len-1) // self.batch_size + 1 + return (self.total_len - 1) // self.batch_size + 1 def __len__(self): return self.__batch_len__() def set_attrs(self, **kw): - ''' + ''' You can set attributes of dataset by using set_attrs function, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size. - + Example:: dataset = YourDataset().set_attrs(batch_size=256, shuffle=True) @@ -113,7 +115,7 @@ def set_attrs(self, **kw): * buffer_size: buffer size for each worker in bytes, default(512MB). * stop_grad: stop grad for data, default(True). ''' - for k,v in kw.items(): + for k, v in kw.items(): assert hasattr(self, k), k setattr(self, k, v) self.reset() @@ -134,8 +136,8 @@ def to_jittor(self, batch): new_batch = [] for a in batch: if isinstance(a, np.ndarray) or \ - isinstance(a, int) or \ - isinstance(a, float): + isinstance(a, int) or \ + isinstance(a, float): new_batch.append(to_jt(a)) else: new_batch.append(self.to_jittor(a)) @@ -159,7 +161,7 @@ def terminate(self): if hasattr(self, "workers"): for w in self.workers: w.p.terminate() - + def _worker_main(self, worker_id, buffer, status): import jittor_utils jittor_utils.cc.init_subprocess() @@ -193,8 +195,9 @@ def _worker_main(self, worker_id, buffer, status): # load and transform data batch = [] if mp_log_v: - print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)) - for i in range(cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)): + print(f"#{worker_id} {os.getpid()} load batch", cid * self.real_batch_size, + min(self.real_len, (cid + 1) * self.real_batch_size)) + for i in range(cid * self.real_batch_size, min(self.real_len, (cid + 1) * self.real_batch_size)): batch.append(self[self.index_list[i]]) batch = self.collate_batch(batch) now = time.time() @@ -203,7 +206,8 @@ def _worker_main(self, worker_id, buffer, status): # send data to main process if mp_log_v: - print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer) + print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [type(b).__name__ for b in batch], + buffer) try: buffer.send(batch) except: @@ -262,7 +266,7 @@ def display_worker_status(self): * buffer: ring buffer status, such as how many free space, left index, right index, total size(bytes). 
Example:: - + from jittor.dataset import Dataset class YourDataset(Dataset): pass @@ -276,7 +280,7 @@ class YourDataset(Dataset): msg.append(f"progress:{self.last_id}/{self.batch_len}") msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}") msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}") - msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id-9):self.last_id+1]}") + msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id - 9):self.last_id + 1]}") msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)") for i in range(self.num_workers): w = self.workers[i] @@ -301,7 +305,7 @@ def _stop_all_workers(self): # clean workers' buffer for w in self.workers: w.buffer.clear() - + def _init_workers(self): jt.clean() jt.gc() @@ -318,7 +322,7 @@ def _init_workers(self): # number of idle workers condition self.num_idle_c = mp.Condition(self.gid.get_lock()) for i in range(self.num_workers): - w = Worker(target=self._worker_main, args=(i,), + w = Worker(target=self._worker_main, args=(i,), buffer_size=self.buffer_size, keep_numpy_array=self.keep_numpy_array) workers.append(w) @@ -349,11 +353,9 @@ def __iter__(self): self.total_len = len(self) if self.shuffle == False: index_list = get_order_list(self.total_len) - elif self.subset is not None: - index_list = [x for x in range(self.subset[0], self.subset[1])] else: index_list = get_random_list(self.total_len) - + # scatter index_list for all mpi process # scatter rule: # batch 1 batch 2 @@ -371,26 +373,26 @@ def __iter__(self): assert self.batch_size >= world_size, \ f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})" - real_batch_size = (self.batch_size-1) // world_size + 1 + real_batch_size = (self.batch_size - 1) // world_size + 1 if real_batch_size * world_size != self.batch_size: LOG.w("Batch size is not divisible by MPI world size, " "The distributed version may be different from " "the single-process version.") fix_batch = self.total_len // self.batch_size last_batch = self.total_len - fix_batch * self.batch_size - fix_batch_l = index_list[0:fix_batch*self.batch_size] \ - .reshape(-1,self.batch_size) + fix_batch_l = index_list[0:fix_batch * self.batch_size] \ + .reshape(-1, self.batch_size) fix_batch_l = fix_batch_l[ - :,real_batch_size*world_rank:real_batch_size*(world_rank+1)] + :, real_batch_size * world_rank:real_batch_size * (world_rank + 1)] real_batch_size = fix_batch_l.shape[1] fix_batch_l = fix_batch_l.flatten() if not self.drop_last and last_batch > 0: last_batch_l = index_list[-last_batch:] - real_last_batch = (last_batch-1)//world_size+1 + real_last_batch = (last_batch - 1) // world_size + 1 l = real_last_batch * world_rank r = l + real_last_batch if r > last_batch: r = last_batch - if l >= r: l = r-1 + if l >= r: l = r - 1 index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]]) else: index_list = fix_batch_l @@ -398,16 +400,16 @@ def __iter__(self): self.real_len = len(index_list) self.real_batch_size = real_batch_size assert self.total_len // self.batch_size == \ - self.real_len // self.real_batch_size, f"Number of batches({self.total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {self.total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}" + self.real_len // self.real_batch_size, f"Number of batches({self.total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {self.total_len}, batch_size: 
{self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}" else: self.real_len = self.total_len self.real_batch_size = self.batch_size - + self.batch_len = self.__batch_len__() - + if not hasattr(self, "workers") and self.num_workers: self._init_workers() - + if self.num_workers: self._stop_all_workers() self.index_list_numpy[:] = index_list @@ -441,7 +443,7 @@ def __iter__(self): start = now if mp_log_v: - print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ]) + print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [type(b).__name__ for b in batch]) batch = self.to_jittor(batch) now = time.time() self.to_jittor_time = now - start @@ -468,18 +470,541 @@ def __iter__(self): batch_data = self.to_jittor(batch_data) yield batch_data - def set_subset(self, start, end): - if start < 0 or end > self.total_len: + +class ImageFolder(Dataset): + """ + A image classify dataset, load image and label from directory:: + + * root/label1/img1.png + * root/label1/img2.png + * ... + * root/label2/img1.png + * root/label2/img2.png + * ... + + Args:: + + [in] root(string): Root directory path. + + Attributes:: + + * classes(list): List of the class names. + * class_to_idx(dict): map from class_name to class_index. + * imgs(list): List of (image_path, class_index) tuples + + Example:: + + train_dir = './data/celebA_train' + train_loader = ImageFolder(train_dir).set_attrs(batch_size=batch_size, shuffle=True) + for batch_idx, (x_, target) in enumerate(train_loader): + ... + + """ + + def __init__(self, root, transform=None): + super().__init__() + self.root = root + self.transform = transform + self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()]) + self.class_to_idx = {v: k for k, v in enumerate(self.classes)} + self.imgs = [] + image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')) + + for i, class_name in enumerate(self.classes): + class_dir = os.path.join(root, class_name) + for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)): + for fname in sorted(fnames): + if os.path.splitext(fname)[-1].lower() in image_exts: + path = os.path.join(class_dir, fname) + self.imgs.append((path, i)) + LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.") + self.set_attrs(total_len=len(self.imgs)) + + def __getitem__(self, k): + with open(self.imgs[k][0], 'rb') as f: + img = Image.open(f).convert('RGB') + if self.transform: + img = self.transform(img) + return img, self.imgs[k][1] + + +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: +# Meng-Hao Guo +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. 
+# *************************************************************** +import numpy as np +from urllib import request +import gzip +import pickle +import os +from jittor.dataset.utils import get_random_list, get_order_list, collate_batch, HookTimer +from collections.abc import Sequence, Mapping +import pathlib +from PIL import Image +import multiprocessing as mp +import signal +from jittor_utils import LOG +import jittor as jt +import time + +dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset") +mp_log_v = os.environ.get("mp_log_v", 0) +mpi = jt.mpi +img_open_hook = HookTimer(Image, "open") + + +class Worker: + def __init__(self, target, args, buffer_size, keep_numpy_array=False): + self.buffer = jt.RingBuffer(buffer_size) + self.buffer.keep_numpy_array(keep_numpy_array) + + self.status = mp.Array('f', 5, lock=False) + self.p = mp.Process(target=target, args=args + (self.buffer, self.status)) + self.p.daemon = True + self.p.start() + + +class Dataset(object): + ''' + Base class for reading data. + + Args:: + + [in] batch_size(int): batch size, default 16. + [in] shuffle(bool): shuffle at each epoch, default False. + [in] drop_last(bool): if true, the last batch of dataset might smaller than batch_size, default True. + [in] num_workers(int): number of workers for loading data. + [in] buffer_size(int): buffer size for each worker in bytes, default(512MB). + + Example:: + + class YourDataset(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=1024) + + def __getitem__(self, k): + return k, k*k + + dataset = YourDataset().set_attrs(batch_size=256, shuffle=True) + for x, y in dataset: + ...... + ''' + + def __init__(self, + batch_size=16, + shuffle=False, + drop_last=False, + num_workers=0, + buffer_size=512 * 1024 * 1024, + stop_grad=True, + keep_numpy_array=False): + super().__init__() + self.total_len = None + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + self.num_workers = num_workers + self.buffer_size = buffer_size + self.stop_grad = stop_grad + self.keep_numpy_array = keep_numpy_array + + def __getitem__(self, index): + raise NotImplementedError + + def __batch_len__(self): + assert self.total_len >= 0 + assert self.batch_size > 0 + if self.drop_last: + return self.total_len // self.batch_size + return (self.total_len - 1) // self.batch_size + 1 + + def __len__(self): + return self.__batch_len__() + + def set_attrs(self, **kw): + ''' + You can set attributes of dataset by using set_attrs function, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size. + + Example:: + + dataset = YourDataset().set_attrs(batch_size=256, shuffle=True) + + Attrs: + + * batch_size(int): batch size, default 16. + * total_len(int): total lenght. + * shuffle(bool): shuffle at each epoch, default False. + * drop_last(bool): if true, the last batch of dataset might smaller than batch_size, default True. + * num_workers: number of workers for loading data + * buffer_size: buffer size for each worker in bytes, default(512MB). + * stop_grad: stop grad for data, default(True). + ''' + for k, v in kw.items(): + assert hasattr(self, k), k + setattr(self, k, v) + self.reset() + return self + + def to_jittor(self, batch): + ''' + Change batch data to jittor array, such as np.ndarray, int, and float. 
+ ''' + if self.keep_numpy_array: return batch + if isinstance(batch, jt.Var): return batch + to_jt = lambda x: jt.array(x).stop_grad() \ + if self.stop_grad else jt.array(x) + if isinstance(batch, np.ndarray): + return to_jt(batch) + if not isinstance(batch, (list, tuple)): + return batch + new_batch = [] + for a in batch: + if isinstance(a, np.ndarray) or \ + isinstance(a, int) or \ + isinstance(a, float): + new_batch.append(to_jt(a)) + else: + new_batch.append(self.to_jittor(a)) + return new_batch + + def collate_batch(self, batch): + ''' + Puts each data field into a tensor with outer dimension batch size. + + Args:: + + [in] batch(list): A list of variables, such as jt.var, Image.Image, np.ndarray, int, float, str and so on. + + ''' + return collate_batch(batch) + + def terminate(self): + ''' + Terminate is used to terminate multi-process worker reading data. + ''' + if hasattr(self, "workers"): + for w in self.workers: + w.p.terminate() + + def _worker_main(self, worker_id, buffer, status): + import jittor_utils + jittor_utils.cc.init_subprocess() + jt.jt_init_subprocess() + # parallel_op_compiler still problematic, + # it is not work on ubuntu 16.04. but worked on ubuntu 20.04 + # it seems like the static value of parallel compiler + # is not correctly init. + jt.flags.use_parallel_op_compiler = 0 + import time + try: + gid_obj = self.gid.get_obj() + gid_lock = self.gid.get_lock() + start = time.time() + while True: + # get id + with gid_lock: + while gid_obj.value >= self.batch_len or buffer.is_stop(): + self.num_idle.value += 1 + self.num_idle_c.notify() + self.gidc.wait() + self.num_idle.value -= 1 + cid = gid_obj.value + self.idmap[cid] = worker_id + gid_obj.value += 1 + self.gidc.notify() + now = time.time() + other_time = now - start + start = now + + # load and transform data + batch = [] + if mp_log_v: + print(f"#{worker_id} {os.getpid()} load batch", cid * self.real_batch_size, + min(self.real_len, (cid + 1) * self.real_batch_size)) + for i in range(cid * self.real_batch_size, min(self.real_len, (cid + 1) * self.real_batch_size)): + batch.append(self[self.index_list[i]]) + batch = self.collate_batch(batch) + now = time.time() + data_time = now - start + start = now + + # send data to main process + if mp_log_v: + print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [type(b).__name__ for b in batch], + buffer) + try: + buffer.send(batch) + except: + if buffer.is_stop(): + continue + raise + now = time.time() + send_time = now - start + start = now + status[0], status[1], status[2], status[3], status[4] = \ + other_time, data_time, send_time, \ + other_time + data_time + send_time, \ + img_open_hook.duration + img_open_hook.duration = 0.0 + except: + import traceback + line = traceback.format_exc() + print(line) + os.kill(os.getppid(), signal.SIGINT) + exit(0) + + def display_worker_status(self): + ''' Display dataset worker status, when dataset.num_workers > 0, it will display infomation blow: + +.. 
code-block:: console + + progress:479/5005 + batch(s): 0.302 wait(s):0.000 + recv(s): 0.069 to_jittor(s):0.021 + recv_raw_call: 6720.0 + last 10 workers: [6, 7, 3, 0, 2, 4, 7, 5, 6, 1] + ID wait(s) load(s) send(s) total + #0 0.000 1.340 2.026 3.366 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #1 0.000 1.451 3.607 5.058 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #2 0.000 1.278 1.235 2.513 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #3 0.000 1.426 1.927 3.353 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #4 0.000 1.452 1.074 2.526 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #5 0.000 1.422 3.204 4.625 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #6 0.000 1.445 1.953 3.398 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #7 0.000 1.582 0.507 2.090 Buffer(free=0.000% l=308283552 r=308283552 size=536870912) + +Meaning of the outputs: + +* progress: dataset loading progress (current/total) +* batch: batch time, exclude data loading time +* wait: time of main proc wait worker proc +* recv: time of recv batch data +* to_jittor: time of batch data to jittor variable +* recv_raw_call: total number of underlying recv_raw called +* last 10 workers: id of last 10 workers which main proc load from. +* table meaning + * ID: worker id + * wait: worker wait time + * open: worker image open time + * load: worker load time + * buffer: ring buffer status, such as how many free space, left index, right index, total size(bytes). + +Example:: + + from jittor.dataset import Dataset + class YourDataset(Dataset): + pass + dataset = YourDataset().set_attrs(num_workers=8) + for x, y in dataset: + dataset.display_worker_status() + ''' + if not hasattr(self, "workers"): return - if start > end: + msg = [""] + msg.append(f"progress:{self.last_id}/{self.batch_len}") + msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}") + msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}") + msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id - 9):self.last_id + 1]}") + msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)") + for i in range(self.num_workers): + w = self.workers[i] + s = w.status + msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}") + LOG.i('\n'.join(msg)) + + def _stop_all_workers(self): + # stop workers + for w in self.workers: + w.buffer.stop() + # wait until all workers idle + if self.num_idle.value < self.num_workers: + with self.gid.get_lock(): + self.gid.get_obj().value = self.batch_len + if mp_log_v: + print("idle num", self.num_idle.value) + while self.num_idle.value < self.num_workers: + self.num_idle_c.wait() + if mp_log_v: + print("idle num", self.num_idle.value) + # clean workers' buffer + for w in self.workers: + w.buffer.clear() + + def _init_workers(self): + jt.clean() + jt.gc() + self.index_list = mp.Array('i', self.real_len, lock=False) + workers = [] + # batch id to worker id + self.idmap = mp.Array('i', self.batch_len, lock=False) + # global token index + self.gid = mp.Value('i', self.batch_len) + # global token index condition + self.gidc = mp.Condition(self.gid.get_lock()) + # number of idle workers + self.num_idle = mp.Value('i', 0, lock=False) + # number of idle workers condition + self.num_idle_c = mp.Condition(self.gid.get_lock()) + for i in range(self.num_workers): + w = Worker(target=self._worker_main, args=(i,), + buffer_size=self.buffer_size, + 
keep_numpy_array=self.keep_numpy_array) + workers.append(w) + self.workers = workers + self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list) + + def reset(self): + if not hasattr(self, "workers"): return - self.subset = (start, end) - self.total_len = end - start + self._stop_all_workers() + self.terminate() + del self.index_list + del self.idmap + del self.gid + del self.gidc + del self.num_idle + del self.num_idle_c + del self.workers + del self.index_list_numpy + + def __del__(self): + if mp_log_v: + print("dataset deleted") + self.terminate() + + def __iter__(self): + if self.total_len is None: + self.total_len = len(self) + if self.shuffle == False: + index_list = get_order_list(self.total_len) + else: + index_list = get_random_list(self.total_len) + + # scatter index_list for all mpi process + # scatter rule: + # batch 1 batch 2 + # [........] [........] ... + # 00011122 00011122 + # if last batch is smaller than world_size + # pad to world_size + # last batch + # [.] -> [012] + if jt.in_mpi: + world_size = mpi.world_size() + world_rank = mpi.world_rank() + index_list = np.int32(index_list) + mpi.broadcast(index_list, 0) + + assert self.batch_size >= world_size, \ + f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})" + real_batch_size = (self.batch_size - 1) // world_size + 1 + if real_batch_size * world_size != self.batch_size: + LOG.w("Batch size is not divisible by MPI world size, " + "The distributed version may be different from " + "the single-process version.") + fix_batch = self.total_len // self.batch_size + last_batch = self.total_len - fix_batch * self.batch_size + fix_batch_l = index_list[0:fix_batch * self.batch_size] \ + .reshape(-1, self.batch_size) + fix_batch_l = fix_batch_l[ + :, real_batch_size * world_rank:real_batch_size * (world_rank + 1)] + real_batch_size = fix_batch_l.shape[1] + fix_batch_l = fix_batch_l.flatten() + if not self.drop_last and last_batch > 0: + last_batch_l = index_list[-last_batch:] + real_last_batch = (last_batch - 1) // world_size + 1 + l = real_last_batch * world_rank + r = l + real_last_batch + if r > last_batch: r = last_batch + if l >= r: l = r - 1 + index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]]) + else: + index_list = fix_batch_l + + self.real_len = len(index_list) + self.real_batch_size = real_batch_size + assert self.total_len // self.batch_size == \ + self.real_len // self.real_batch_size, f"Number of batches({self.total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {self.total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}" + else: + self.real_len = self.total_len + self.real_batch_size = self.batch_size + + self.batch_len = self.__batch_len__() + + if not hasattr(self, "workers") and self.num_workers: + self._init_workers() + + if self.num_workers: + self._stop_all_workers() + self.index_list_numpy[:] = index_list + gid_obj = self.gid.get_obj() + gid_lock = self.gid.get_lock() + with gid_lock: + gid_obj.value = 0 + self.gidc.notify_all() + start = time.time() + self.batch_time = 0 + for i in range(self.batch_len): + # try not get lock first + if gid_obj.value <= i: + with gid_lock: + if gid_obj.value <= i: + if mp_log_v: + print("wait") + self.gidc.wait() + now = time.time() + self.wait_time = now - start + start = now + + self.last_id = i + worker_id = self.idmap[i] + w = self.workers[worker_id] + if mp_log_v: + print(f"#{worker_id} {os.getpid()} recv 
buffer", w.buffer) + batch = w.buffer.recv() + now = time.time() + self.recv_time = now - start + start = now + + if mp_log_v: + print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [type(b).__name__ for b in batch]) + batch = self.to_jittor(batch) + now = time.time() + self.to_jittor_time = now - start + start = now + + yield batch + + now = time.time() + self.batch_time = now - start + start = now + else: + batch_data = [] + for idx in index_list: + batch_data.append(self[int(idx)]) + if len(batch_data) == self.real_batch_size: + batch_data = self.collate_batch(batch_data) + batch_data = self.to_jittor(batch_data) + yield batch_data + batch_data = [] + + # depend on drop_last + if not self.drop_last and len(batch_data) > 0: + batch_data = self.collate_batch(batch_data) + batch_data = self.to_jittor(batch_data) + yield batch_data + class ImageFolder(Dataset): """ A image classify dataset, load image and label from directory:: - + * root/label1/img1.png * root/label1/img2.png * ... @@ -505,15 +1030,16 @@ class ImageFolder(Dataset): ... """ + def __init__(self, root, transform=None): super().__init__() self.root = root self.transform = transform self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()]) - self.class_to_idx = {v:k for k,v in enumerate(self.classes)} + self.class_to_idx = {v: k for k, v in enumerate(self.classes)} self.imgs = [] image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')) - + for i, class_name in enumerate(self.classes): class_dir = os.path.join(root, class_name) for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)): @@ -523,7 +1049,7 @@ def __init__(self, root, transform=None): self.imgs.append((path, i)) LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.") self.set_attrs(total_len=len(self.imgs)) - + def __getitem__(self, k): with open(self.imgs[k][0], 'rb') as f: img = Image.open(f).convert('RGB') diff --git a/python/jittor/dataset/sampler.py b/python/jittor/dataset/sampler.py index 510c7b05..0f018b07 100644 --- a/python/jittor/dataset/sampler.py +++ b/python/jittor/dataset/sampler.py @@ -3,66 +3,70 @@ import numpy as np from PIL import Image + class Sampler(): - def __init__(self,data_source): + def __init__(self, data_source): self.data_source = data_source - + def __iter__(self): pass - + def __len__(self): pass + class SequentialSampler(Sampler): - def __init__(self,data_source): + def __init__(self, data_source): self.data_source = data_source - + def __iter__(self): return iter(range(len(self.data_source))) - + def __len__(self): return len(self.data_source) class RandomSampler(Sampler): - def __init__(self,data_source,replacement=False, num_samples=None): + def __init__(self, data_source, replacement=False, num_samples=None): self.data_source = data_source self.rep = replacement self._num_samples = num_samples - + @property def num_samples(self): if self._num_samples is None: return len(self.data_source) return self._num_samples - + def __len__(self): return self.num_samples - + def __iter__(self): n = len(self.data_source) if self.rep: - return iter(np.random.randint(low=0, high=n, size=(self.num_samples,), dtype=np.int64).tolist()) - return iter(jt.randperm(n).numpy().tolist()) + return iter(np.random.randint(low=0, high=n, size=(self.num_samples,), dtype=np.int64).tolist()) + return iter(np.random.permutation(n).tolist()) class SubsetRandomSampler(Sampler): - def __init__(self,indices): - self.indices = indices - - def __iter__(self): - return (self.indices[jt.to_int(i)] for i in 
jt.randperm((len(self.indices)))) - + def __init__(self, data_source, indice): + self.data_source = data_source + self.indices = indice + assert indice[0] >= 0 and indice[1] < data_source.total_len and indice[0] < indice[1] + + def __iter__(self): + return (self.data_source[i + self.indices[0]] for i in np.random.permutation(self.indices[1] - self.indices[0])) + def __len__(self): - return len(self.indices) + return self.indices[1] - self.indices[0] class BatchSampler(Sampler): - def __init__(self,sampler,batch_size,drop_last): + def __init__(self, sampler, batch_size, drop_last): self.sampler = sampler self.batch_size = batch_size self.drop_last = drop_last - + def __iter__(self): batch = [] for idx in self.sampler: @@ -72,14 +76,9 @@ def __iter__(self): batch = [] if len(batch) > 0 and not self.drop_last: yield batch - + def __len__(self): if self.drop_last: return len(self.sampler) // self.batch_size else: return (len(self.sampler) + self.batch_size - 1) // self.batch_size - - - - - \ No newline at end of file diff --git a/python/jittor/test/test_sampler.py b/python/jittor/test/test_sampler.py index 2e7826cb..824ff18a 100644 --- a/python/jittor/test/test_sampler.py +++ b/python/jittor/test/test_sampler.py @@ -4,41 +4,43 @@ import numpy as np import unittest -test_img = np.random.normal(size=(40,3,100,100)) +test_img = np.random.normal(size=(40, 1, 2, 2)) + + class TestSamplerDataset(Dataset): def __init__(self): super().__init__() - self.set_attrs(total_len=40,batch_size=1) + self.set_attrs(total_len=40, batch_size=1) + + def __getitem__(self, idx): + return test_img[idx:(idx + 1), ...] - def __getitem__(self,idx): - return test_img[idx:(idx+1),...] testdataset = TestSamplerDataset() + class TestSampler(unittest.TestCase): def test_sequential_sampler(self): seqsampler = SequentialSampler(testdataset) assert len(seqsampler) == 40 - for idx,batch in enumerate(seqsampler): + for idx, batch in enumerate(seqsampler): assert idx == batch - + def test_random_sampler(self): randomsampler = RandomSampler(testdataset) assert len(randomsampler) == 40 - + def test_subset_random_sampler(self): - testdataset.set_subset(20,30) - subsetsampler = SubsetRandomSampler(testdataset) + subsetsampler = SubsetRandomSampler(testdataset, (20, 30)) assert len(subsetsampler) == 10 def test_batch_sampler(self): - testdataset.subset = None seqforbatch = SequentialSampler(testdataset) - batchsampler = BatchSampler(seqforbatch,4,drop_last=False) + batchsampler = BatchSampler(seqforbatch, 4, drop_last=False) assert len(batchsampler) == 10 for batch in batchsampler: assert len(batch) == 4 - + if __name__ == "__main__": unittest.main() From 37d723a0f281ed5b5564be5f97fb41d00b6c1bcf Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Sat, 20 Mar 2021 18:27:41 +0800 Subject: [PATCH 11/36] fix. --- python/jittor/dataset/dataset.py | 530 ------------------------------- 1 file changed, 530 deletions(-) diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index 52e30ca8..1f9ee9a1 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -526,533 +526,3 @@ def __getitem__(self, k): if self.transform: img = self.transform(img) return img, self.imgs[k][1] - - -# *************************************************************** -# Copyright (c) 2021 Jittor. All Rights Reserved. -# Maintainers: -# Meng-Hao Guo -# Dun Liang . -# -# This file is subject to the terms and conditions defined in -# file 'LICENSE.txt', which is part of this source code package. 
-# *************************************************************** -import numpy as np -from urllib import request -import gzip -import pickle -import os -from jittor.dataset.utils import get_random_list, get_order_list, collate_batch, HookTimer -from collections.abc import Sequence, Mapping -import pathlib -from PIL import Image -import multiprocessing as mp -import signal -from jittor_utils import LOG -import jittor as jt -import time - -dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset") -mp_log_v = os.environ.get("mp_log_v", 0) -mpi = jt.mpi -img_open_hook = HookTimer(Image, "open") - - -class Worker: - def __init__(self, target, args, buffer_size, keep_numpy_array=False): - self.buffer = jt.RingBuffer(buffer_size) - self.buffer.keep_numpy_array(keep_numpy_array) - - self.status = mp.Array('f', 5, lock=False) - self.p = mp.Process(target=target, args=args + (self.buffer, self.status)) - self.p.daemon = True - self.p.start() - - -class Dataset(object): - ''' - Base class for reading data. - - Args:: - - [in] batch_size(int): batch size, default 16. - [in] shuffle(bool): shuffle at each epoch, default False. - [in] drop_last(bool): if true, the last batch of dataset might smaller than batch_size, default True. - [in] num_workers(int): number of workers for loading data. - [in] buffer_size(int): buffer size for each worker in bytes, default(512MB). - - Example:: - - class YourDataset(Dataset): - def __init__(self): - super().__init__() - self.set_attrs(total_len=1024) - - def __getitem__(self, k): - return k, k*k - - dataset = YourDataset().set_attrs(batch_size=256, shuffle=True) - for x, y in dataset: - ...... - ''' - - def __init__(self, - batch_size=16, - shuffle=False, - drop_last=False, - num_workers=0, - buffer_size=512 * 1024 * 1024, - stop_grad=True, - keep_numpy_array=False): - super().__init__() - self.total_len = None - self.batch_size = batch_size - self.shuffle = shuffle - self.drop_last = drop_last - self.num_workers = num_workers - self.buffer_size = buffer_size - self.stop_grad = stop_grad - self.keep_numpy_array = keep_numpy_array - - def __getitem__(self, index): - raise NotImplementedError - - def __batch_len__(self): - assert self.total_len >= 0 - assert self.batch_size > 0 - if self.drop_last: - return self.total_len // self.batch_size - return (self.total_len - 1) // self.batch_size + 1 - - def __len__(self): - return self.__batch_len__() - - def set_attrs(self, **kw): - ''' - You can set attributes of dataset by using set_attrs function, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size. - - Example:: - - dataset = YourDataset().set_attrs(batch_size=256, shuffle=True) - - Attrs: - - * batch_size(int): batch size, default 16. - * total_len(int): total lenght. - * shuffle(bool): shuffle at each epoch, default False. - * drop_last(bool): if true, the last batch of dataset might smaller than batch_size, default True. - * num_workers: number of workers for loading data - * buffer_size: buffer size for each worker in bytes, default(512MB). - * stop_grad: stop grad for data, default(True). - ''' - for k, v in kw.items(): - assert hasattr(self, k), k - setattr(self, k, v) - self.reset() - return self - - def to_jittor(self, batch): - ''' - Change batch data to jittor array, such as np.ndarray, int, and float. 
- ''' - if self.keep_numpy_array: return batch - if isinstance(batch, jt.Var): return batch - to_jt = lambda x: jt.array(x).stop_grad() \ - if self.stop_grad else jt.array(x) - if isinstance(batch, np.ndarray): - return to_jt(batch) - if not isinstance(batch, (list, tuple)): - return batch - new_batch = [] - for a in batch: - if isinstance(a, np.ndarray) or \ - isinstance(a, int) or \ - isinstance(a, float): - new_batch.append(to_jt(a)) - else: - new_batch.append(self.to_jittor(a)) - return new_batch - - def collate_batch(self, batch): - ''' - Puts each data field into a tensor with outer dimension batch size. - - Args:: - - [in] batch(list): A list of variables, such as jt.var, Image.Image, np.ndarray, int, float, str and so on. - - ''' - return collate_batch(batch) - - def terminate(self): - ''' - Terminate is used to terminate multi-process worker reading data. - ''' - if hasattr(self, "workers"): - for w in self.workers: - w.p.terminate() - - def _worker_main(self, worker_id, buffer, status): - import jittor_utils - jittor_utils.cc.init_subprocess() - jt.jt_init_subprocess() - # parallel_op_compiler still problematic, - # it is not work on ubuntu 16.04. but worked on ubuntu 20.04 - # it seems like the static value of parallel compiler - # is not correctly init. - jt.flags.use_parallel_op_compiler = 0 - import time - try: - gid_obj = self.gid.get_obj() - gid_lock = self.gid.get_lock() - start = time.time() - while True: - # get id - with gid_lock: - while gid_obj.value >= self.batch_len or buffer.is_stop(): - self.num_idle.value += 1 - self.num_idle_c.notify() - self.gidc.wait() - self.num_idle.value -= 1 - cid = gid_obj.value - self.idmap[cid] = worker_id - gid_obj.value += 1 - self.gidc.notify() - now = time.time() - other_time = now - start - start = now - - # load and transform data - batch = [] - if mp_log_v: - print(f"#{worker_id} {os.getpid()} load batch", cid * self.real_batch_size, - min(self.real_len, (cid + 1) * self.real_batch_size)) - for i in range(cid * self.real_batch_size, min(self.real_len, (cid + 1) * self.real_batch_size)): - batch.append(self[self.index_list[i]]) - batch = self.collate_batch(batch) - now = time.time() - data_time = now - start - start = now - - # send data to main process - if mp_log_v: - print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [type(b).__name__ for b in batch], - buffer) - try: - buffer.send(batch) - except: - if buffer.is_stop(): - continue - raise - now = time.time() - send_time = now - start - start = now - status[0], status[1], status[2], status[3], status[4] = \ - other_time, data_time, send_time, \ - other_time + data_time + send_time, \ - img_open_hook.duration - img_open_hook.duration = 0.0 - except: - import traceback - line = traceback.format_exc() - print(line) - os.kill(os.getppid(), signal.SIGINT) - exit(0) - - def display_worker_status(self): - ''' Display dataset worker status, when dataset.num_workers > 0, it will display infomation blow: - -.. 
code-block:: console - - progress:479/5005 - batch(s): 0.302 wait(s):0.000 - recv(s): 0.069 to_jittor(s):0.021 - recv_raw_call: 6720.0 - last 10 workers: [6, 7, 3, 0, 2, 4, 7, 5, 6, 1] - ID wait(s) load(s) send(s) total - #0 0.000 1.340 2.026 3.366 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) - #1 0.000 1.451 3.607 5.058 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) - #2 0.000 1.278 1.235 2.513 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) - #3 0.000 1.426 1.927 3.353 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) - #4 0.000 1.452 1.074 2.526 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) - #5 0.000 1.422 3.204 4.625 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) - #6 0.000 1.445 1.953 3.398 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) - #7 0.000 1.582 0.507 2.090 Buffer(free=0.000% l=308283552 r=308283552 size=536870912) - -Meaning of the outputs: - -* progress: dataset loading progress (current/total) -* batch: batch time, exclude data loading time -* wait: time of main proc wait worker proc -* recv: time of recv batch data -* to_jittor: time of batch data to jittor variable -* recv_raw_call: total number of underlying recv_raw called -* last 10 workers: id of last 10 workers which main proc load from. -* table meaning - * ID: worker id - * wait: worker wait time - * open: worker image open time - * load: worker load time - * buffer: ring buffer status, such as how many free space, left index, right index, total size(bytes). - -Example:: - - from jittor.dataset import Dataset - class YourDataset(Dataset): - pass - dataset = YourDataset().set_attrs(num_workers=8) - for x, y in dataset: - dataset.display_worker_status() - ''' - if not hasattr(self, "workers"): - return - msg = [""] - msg.append(f"progress:{self.last_id}/{self.batch_len}") - msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}") - msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}") - msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id - 9):self.last_id + 1]}") - msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)") - for i in range(self.num_workers): - w = self.workers[i] - s = w.status - msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}") - LOG.i('\n'.join(msg)) - - def _stop_all_workers(self): - # stop workers - for w in self.workers: - w.buffer.stop() - # wait until all workers idle - if self.num_idle.value < self.num_workers: - with self.gid.get_lock(): - self.gid.get_obj().value = self.batch_len - if mp_log_v: - print("idle num", self.num_idle.value) - while self.num_idle.value < self.num_workers: - self.num_idle_c.wait() - if mp_log_v: - print("idle num", self.num_idle.value) - # clean workers' buffer - for w in self.workers: - w.buffer.clear() - - def _init_workers(self): - jt.clean() - jt.gc() - self.index_list = mp.Array('i', self.real_len, lock=False) - workers = [] - # batch id to worker id - self.idmap = mp.Array('i', self.batch_len, lock=False) - # global token index - self.gid = mp.Value('i', self.batch_len) - # global token index condition - self.gidc = mp.Condition(self.gid.get_lock()) - # number of idle workers - self.num_idle = mp.Value('i', 0, lock=False) - # number of idle workers condition - self.num_idle_c = mp.Condition(self.gid.get_lock()) - for i in range(self.num_workers): - w = Worker(target=self._worker_main, args=(i,), - buffer_size=self.buffer_size, - keep_numpy_array=self.keep_numpy_array) - 
workers.append(w) - self.workers = workers - self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list) - - def reset(self): - if not hasattr(self, "workers"): - return - self._stop_all_workers() - self.terminate() - del self.index_list - del self.idmap - del self.gid - del self.gidc - del self.num_idle - del self.num_idle_c - del self.workers - del self.index_list_numpy - - def __del__(self): - if mp_log_v: - print("dataset deleted") - self.terminate() - - def __iter__(self): - if self.total_len is None: - self.total_len = len(self) - if self.shuffle == False: - index_list = get_order_list(self.total_len) - else: - index_list = get_random_list(self.total_len) - - # scatter index_list for all mpi process - # scatter rule: - # batch 1 batch 2 - # [........] [........] ... - # 00011122 00011122 - # if last batch is smaller than world_size - # pad to world_size - # last batch - # [.] -> [012] - if jt.in_mpi: - world_size = mpi.world_size() - world_rank = mpi.world_rank() - index_list = np.int32(index_list) - mpi.broadcast(index_list, 0) - - assert self.batch_size >= world_size, \ - f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})" - real_batch_size = (self.batch_size - 1) // world_size + 1 - if real_batch_size * world_size != self.batch_size: - LOG.w("Batch size is not divisible by MPI world size, " - "The distributed version may be different from " - "the single-process version.") - fix_batch = self.total_len // self.batch_size - last_batch = self.total_len - fix_batch * self.batch_size - fix_batch_l = index_list[0:fix_batch * self.batch_size] \ - .reshape(-1, self.batch_size) - fix_batch_l = fix_batch_l[ - :, real_batch_size * world_rank:real_batch_size * (world_rank + 1)] - real_batch_size = fix_batch_l.shape[1] - fix_batch_l = fix_batch_l.flatten() - if not self.drop_last and last_batch > 0: - last_batch_l = index_list[-last_batch:] - real_last_batch = (last_batch - 1) // world_size + 1 - l = real_last_batch * world_rank - r = l + real_last_batch - if r > last_batch: r = last_batch - if l >= r: l = r - 1 - index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]]) - else: - index_list = fix_batch_l - - self.real_len = len(index_list) - self.real_batch_size = real_batch_size - assert self.total_len // self.batch_size == \ - self.real_len // self.real_batch_size, f"Number of batches({self.total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {self.total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}" - else: - self.real_len = self.total_len - self.real_batch_size = self.batch_size - - self.batch_len = self.__batch_len__() - - if not hasattr(self, "workers") and self.num_workers: - self._init_workers() - - if self.num_workers: - self._stop_all_workers() - self.index_list_numpy[:] = index_list - gid_obj = self.gid.get_obj() - gid_lock = self.gid.get_lock() - with gid_lock: - gid_obj.value = 0 - self.gidc.notify_all() - start = time.time() - self.batch_time = 0 - for i in range(self.batch_len): - # try not get lock first - if gid_obj.value <= i: - with gid_lock: - if gid_obj.value <= i: - if mp_log_v: - print("wait") - self.gidc.wait() - now = time.time() - self.wait_time = now - start - start = now - - self.last_id = i - worker_id = self.idmap[i] - w = self.workers[worker_id] - if mp_log_v: - print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer) - batch = w.buffer.recv() - now = time.time() - self.recv_time = now - start - 
start = now - - if mp_log_v: - print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [type(b).__name__ for b in batch]) - batch = self.to_jittor(batch) - now = time.time() - self.to_jittor_time = now - start - start = now - - yield batch - - now = time.time() - self.batch_time = now - start - start = now - else: - batch_data = [] - for idx in index_list: - batch_data.append(self[int(idx)]) - if len(batch_data) == self.real_batch_size: - batch_data = self.collate_batch(batch_data) - batch_data = self.to_jittor(batch_data) - yield batch_data - batch_data = [] - - # depend on drop_last - if not self.drop_last and len(batch_data) > 0: - batch_data = self.collate_batch(batch_data) - batch_data = self.to_jittor(batch_data) - yield batch_data - - -class ImageFolder(Dataset): - """ - A image classify dataset, load image and label from directory:: - - * root/label1/img1.png - * root/label1/img2.png - * ... - * root/label2/img1.png - * root/label2/img2.png - * ... - - Args:: - - [in] root(string): Root directory path. - - Attributes:: - - * classes(list): List of the class names. - * class_to_idx(dict): map from class_name to class_index. - * imgs(list): List of (image_path, class_index) tuples - - Example:: - - train_dir = './data/celebA_train' - train_loader = ImageFolder(train_dir).set_attrs(batch_size=batch_size, shuffle=True) - for batch_idx, (x_, target) in enumerate(train_loader): - ... - - """ - - def __init__(self, root, transform=None): - super().__init__() - self.root = root - self.transform = transform - self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()]) - self.class_to_idx = {v: k for k, v in enumerate(self.classes)} - self.imgs = [] - image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')) - - for i, class_name in enumerate(self.classes): - class_dir = os.path.join(root, class_name) - for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)): - for fname in sorted(fnames): - if os.path.splitext(fname)[-1].lower() in image_exts: - path = os.path.join(class_dir, fname) - self.imgs.append((path, i)) - LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.") - self.set_attrs(total_len=len(self.imgs)) - - def __getitem__(self, k): - with open(self.imgs[k][0], 'rb') as f: - img = Image.open(f).convert('RGB') - if self.transform: - img = self.transform(img) - return img, self.imgs[k][1] From 9452f8b8f1999506dcdef072892abded89b6c1bb Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Sat, 20 Mar 2021 19:03:03 +0800 Subject: [PATCH 12/36] fix space. --- python/jittor/dataset/dataset.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index 1f9ee9a1..60e19fb4 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -27,7 +27,6 @@ mpi = jt.mpi img_open_hook = HookTimer(Image, "open") - class Worker: def __init__(self, target, args, buffer_size, keep_numpy_array=False): self.buffer = jt.RingBuffer(buffer_size) @@ -38,7 +37,6 @@ def __init__(self, target, args, buffer_size, keep_numpy_array=False): self.p.daemon = True self.p.start() - class Dataset(object): ''' Base class for reading data. @@ -65,7 +63,6 @@ def __getitem__(self, k): for x, y in dataset: ...... 
''' - def __init__(self, batch_size=16, shuffle=False, @@ -195,9 +192,8 @@ def _worker_main(self, worker_id, buffer, status): # load and transform data batch = [] if mp_log_v: - print(f"#{worker_id} {os.getpid()} load batch", cid * self.real_batch_size, - min(self.real_len, (cid + 1) * self.real_batch_size)) - for i in range(cid * self.real_batch_size, min(self.real_len, (cid + 1) * self.real_batch_size)): + print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)) + for i in range(cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)): batch.append(self[self.index_list[i]]) batch = self.collate_batch(batch) now = time.time() @@ -206,8 +202,7 @@ def _worker_main(self, worker_id, buffer, status): # send data to main process if mp_log_v: - print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [type(b).__name__ for b in batch], - buffer) + print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [type(b).__name__ for b in batch], buffer) try: buffer.send(batch) except: @@ -500,7 +495,6 @@ class ImageFolder(Dataset): ... """ - def __init__(self, root, transform=None): super().__init__() self.root = root @@ -525,4 +519,4 @@ def __getitem__(self, k): img = Image.open(f).convert('RGB') if self.transform: img = self.transform(img) - return img, self.imgs[k][1] + return img, self.imgs[k][1] \ No newline at end of file From 2bc9ee1133db76f91cba9bf3a2e1c2394d258860 Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Sat, 20 Mar 2021 21:55:01 +0800 Subject: [PATCH 13/36] fix?. --- python/jittor/dataset/.idea/.gitignore | 2 + .../.idea/codeStyles/codeStyleConfig.xml | 5 + python/jittor/dataset/.idea/dataset.iml | 9 + .../inspectionProfiles/Project_Default.xml | 23 ++ python/jittor/dataset/.idea/misc.xml | 6 + python/jittor/dataset/.idea/modules.xml | 8 + python/jittor/dataset/.idea/vcs.xml | 6 + python/jittor/dataset/dataset.py | 372 +++++++++--------- 8 files changed, 245 insertions(+), 186 deletions(-) create mode 100644 python/jittor/dataset/.idea/.gitignore create mode 100644 python/jittor/dataset/.idea/codeStyles/codeStyleConfig.xml create mode 100644 python/jittor/dataset/.idea/dataset.iml create mode 100644 python/jittor/dataset/.idea/inspectionProfiles/Project_Default.xml create mode 100644 python/jittor/dataset/.idea/misc.xml create mode 100644 python/jittor/dataset/.idea/modules.xml create mode 100644 python/jittor/dataset/.idea/vcs.xml diff --git a/python/jittor/dataset/.idea/.gitignore b/python/jittor/dataset/.idea/.gitignore new file mode 100644 index 00000000..5c98b428 --- /dev/null +++ b/python/jittor/dataset/.idea/.gitignore @@ -0,0 +1,2 @@ +# Default ignored files +/workspace.xml \ No newline at end of file diff --git a/python/jittor/dataset/.idea/codeStyles/codeStyleConfig.xml b/python/jittor/dataset/.idea/codeStyles/codeStyleConfig.xml new file mode 100644 index 00000000..a55e7a17 --- /dev/null +++ b/python/jittor/dataset/.idea/codeStyles/codeStyleConfig.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/python/jittor/dataset/.idea/dataset.iml b/python/jittor/dataset/.idea/dataset.iml new file mode 100644 index 00000000..d6ebd480 --- /dev/null +++ b/python/jittor/dataset/.idea/dataset.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/python/jittor/dataset/.idea/inspectionProfiles/Project_Default.xml b/python/jittor/dataset/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 00000000..95a8d909 --- 
/dev/null +++ b/python/jittor/dataset/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,23 @@ + + + + \ No newline at end of file diff --git a/python/jittor/dataset/.idea/misc.xml b/python/jittor/dataset/.idea/misc.xml new file mode 100644 index 00000000..37e641e9 --- /dev/null +++ b/python/jittor/dataset/.idea/misc.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/python/jittor/dataset/.idea/modules.xml b/python/jittor/dataset/.idea/modules.xml new file mode 100644 index 00000000..77bc753c --- /dev/null +++ b/python/jittor/dataset/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/python/jittor/dataset/.idea/vcs.xml b/python/jittor/dataset/.idea/vcs.xml new file mode 100644 index 00000000..c2365ab1 --- /dev/null +++ b/python/jittor/dataset/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index 60e19fb4..9cb5106c 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -12,8 +12,8 @@ import gzip import pickle import os -from jittor.dataset.utils import get_random_list, get_order_list, collate_batch, HookTimer -from collections.abc import Sequence, Mapping +from jittor.dataset.utils import get_random_list,get_order_list,collate_batch,HookTimer +from collections.abc import Sequence,Mapping import pathlib from PIL import Image import multiprocessing as mp @@ -22,21 +22,23 @@ import jittor as jt import time -dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset") -mp_log_v = os.environ.get("mp_log_v", 0) -mpi = jt.mpi -img_open_hook = HookTimer(Image, "open") +dataset_root=os.path.join(pathlib.Path.home(),".cache","jittor","dataset") +mp_log_v=os.environ.get("mp_log_v",0) +mpi=jt.mpi +img_open_hook=HookTimer(Image,"open") + class Worker: - def __init__(self, target, args, buffer_size, keep_numpy_array=False): - self.buffer = jt.RingBuffer(buffer_size) + def __init__(self,target,args,buffer_size,keep_numpy_array=False): + self.buffer=jt.RingBuffer(buffer_size) self.buffer.keep_numpy_array(keep_numpy_array) - self.status = mp.Array('f', 5, lock=False) - self.p = mp.Process(target=target, args=args + (self.buffer, self.status)) - self.p.daemon = True + self.status=mp.Array('f',5,lock=False) + self.p=mp.Process(target=target,args=args+(self.buffer,self.status)) + self.p.daemon=True self.p.start() + class Dataset(object): ''' Base class for reading data. @@ -55,10 +57,8 @@ class YourDataset(Dataset): def __init__(self): super().__init__() self.set_attrs(total_len=1024) - def __getitem__(self, k): return k, k*k - dataset = YourDataset().set_attrs(batch_size=256, shuffle=True) for x, y in dataset: ...... 
@@ -68,33 +68,33 @@ def __init__(self, shuffle=False, drop_last=False, num_workers=0, - buffer_size=512 * 1024 * 1024, + buffer_size=512*1024*1024, stop_grad=True, keep_numpy_array=False): super().__init__() - self.total_len = None - self.batch_size = batch_size - self.shuffle = shuffle - self.drop_last = drop_last - self.num_workers = num_workers - self.buffer_size = buffer_size - self.stop_grad = stop_grad - self.keep_numpy_array = keep_numpy_array - - def __getitem__(self, index): + self.total_len=None + self.batch_size=batch_size + self.shuffle=shuffle + self.drop_last=drop_last + self.num_workers=num_workers + self.buffer_size=buffer_size + self.stop_grad=stop_grad + self.keep_numpy_array=keep_numpy_array + + def __getitem__(self,index): raise NotImplementedError def __batch_len__(self): - assert self.total_len >= 0 - assert self.batch_size > 0 + assert self.total_len>=0 + assert self.batch_size>0 if self.drop_last: - return self.total_len // self.batch_size - return (self.total_len - 1) // self.batch_size + 1 + return self.total_len//self.batch_size + return (self.total_len-1)//self.batch_size+1 def __len__(self): return self.__batch_len__() - def set_attrs(self, **kw): + def set_attrs(self,**kw): ''' You can set attributes of dataset by using set_attrs function, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size. @@ -112,35 +112,35 @@ def set_attrs(self, **kw): * buffer_size: buffer size for each worker in bytes, default(512MB). * stop_grad: stop grad for data, default(True). ''' - for k, v in kw.items(): - assert hasattr(self, k), k - setattr(self, k, v) + for k,v in kw.items(): + assert hasattr(self,k),k + setattr(self,k,v) self.reset() return self - def to_jittor(self, batch): + def to_jittor(self,batch): ''' Change batch data to jittor array, such as np.ndarray, int, and float. ''' if self.keep_numpy_array: return batch - if isinstance(batch, jt.Var): return batch - to_jt = lambda x: jt.array(x).stop_grad() \ + if isinstance(batch,jt.Var): return batch + to_jt=lambda x:jt.array(x).stop_grad()\ if self.stop_grad else jt.array(x) - if isinstance(batch, np.ndarray): + if isinstance(batch,np.ndarray): return to_jt(batch) - if not isinstance(batch, (list, tuple)): + if not isinstance(batch,(list,tuple)): return batch - new_batch = [] + new_batch=[] for a in batch: - if isinstance(a, np.ndarray) or \ - isinstance(a, int) or \ - isinstance(a, float): + if isinstance(a,np.ndarray) or\ + isinstance(a,int) or\ + isinstance(a,float): new_batch.append(to_jt(a)) else: new_batch.append(self.to_jittor(a)) return new_batch - def collate_batch(self, batch): + def collate_batch(self,batch): ''' Puts each data field into a tensor with outer dimension batch size. @@ -155,11 +155,11 @@ def terminate(self): ''' Terminate is used to terminate multi-process worker reading data. ''' - if hasattr(self, "workers"): + if hasattr(self,"workers"): for w in self.workers: w.p.terminate() - def _worker_main(self, worker_id, buffer, status): + def _worker_main(self,worker_id,buffer,status): import jittor_utils jittor_utils.cc.init_subprocess() jt.jt_init_subprocess() @@ -167,61 +167,61 @@ def _worker_main(self, worker_id, buffer, status): # it is not work on ubuntu 16.04. but worked on ubuntu 20.04 # it seems like the static value of parallel compiler # is not correctly init. 
- jt.flags.use_parallel_op_compiler = 0 + jt.flags.use_parallel_op_compiler=0 import time try: - gid_obj = self.gid.get_obj() - gid_lock = self.gid.get_lock() - start = time.time() + gid_obj=self.gid.get_obj() + gid_lock=self.gid.get_lock() + start=time.time() while True: # get id with gid_lock: - while gid_obj.value >= self.batch_len or buffer.is_stop(): - self.num_idle.value += 1 + while gid_obj.value>=self.batch_len or buffer.is_stop(): + self.num_idle.value+=1 self.num_idle_c.notify() self.gidc.wait() - self.num_idle.value -= 1 - cid = gid_obj.value - self.idmap[cid] = worker_id - gid_obj.value += 1 + self.num_idle.value-=1 + cid=gid_obj.value + self.idmap[cid]=worker_id + gid_obj.value+=1 self.gidc.notify() - now = time.time() - other_time = now - start - start = now + now=time.time() + other_time=now-start + start=now # load and transform data - batch = [] + batch=[] if mp_log_v: - print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)) - for i in range(cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)): + print(f"#{worker_id} {os.getpid()} load batch",cid*self.real_batch_size, min(self.real_len,(cid+1)*self.real_batch_size)) + for i in range(cid*self.real_batch_size,min(self.real_len,(cid+1)*self.real_batch_size)): batch.append(self[self.index_list[i]]) - batch = self.collate_batch(batch) - now = time.time() - data_time = now - start - start = now + batch=self.collate_batch(batch) + now=time.time() + data_time=now-start + start=now # send data to main process if mp_log_v: - print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [type(b).__name__ for b in batch], buffer) + print(f"#{worker_id} {os.getpid()} send",type(batch).__name__,[type(b).__name__ for b in batch], buffer) try: buffer.send(batch) except: if buffer.is_stop(): continue raise - now = time.time() - send_time = now - start - start = now - status[0], status[1], status[2], status[3], status[4] = \ - other_time, data_time, send_time, \ - other_time + data_time + send_time, \ + now=time.time() + send_time=now-start + start=now + status[0],status[1],status[2],status[3],status[4]=\ + other_time,data_time,send_time,\ + other_time+data_time+send_time,\ img_open_hook.duration - img_open_hook.duration = 0.0 + img_open_hook.duration=0.0 except: import traceback - line = traceback.format_exc() + line=traceback.format_exc() print(line) - os.kill(os.getppid(), signal.SIGINT) + os.kill(os.getppid(),signal.SIGINT) exit(0) def display_worker_status(self): @@ -269,17 +269,17 @@ class YourDataset(Dataset): for x, y in dataset: dataset.display_worker_status() ''' - if not hasattr(self, "workers"): + if not hasattr(self,"workers"): return - msg = [""] + msg=[""] msg.append(f"progress:{self.last_id}/{self.batch_len}") msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}") msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}") - msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id - 9):self.last_id + 1]}") + msg.append(f"last 10 workers: {self.idmap[max(0,self.last_id-9):self.last_id+1]}") msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)") for i in range(self.num_workers): - w = self.workers[i] - s = w.status + w=self.workers[i] + s=w.status msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}") LOG.i('\n'.join(msg)) @@ -288,15 +288,15 @@ def _stop_all_workers(self): for w in self.workers: w.buffer.stop() # wait until all workers 
idle - if self.num_idle.value < self.num_workers: + if self.num_idle.value [012] if jt.in_mpi: - world_size = mpi.world_size() - world_rank = mpi.world_rank() - index_list = np.int32(index_list) - mpi.broadcast(index_list, 0) + world_size=mpi.world_size() + world_rank=mpi.world_rank() + index_list=np.int32(index_list) + mpi.broadcast(index_list,0) - assert self.batch_size >= world_size, \ + assert self.batch_size>=world_size,\ f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})" - real_batch_size = (self.batch_size - 1) // world_size + 1 - if real_batch_size * world_size != self.batch_size: + real_batch_size=(self.batch_size-1)//world_size+1 + if real_batch_size*world_size!=self.batch_size: LOG.w("Batch size is not divisible by MPI world size, " "The distributed version may be different from " "the single-process version.") - fix_batch = self.total_len // self.batch_size - last_batch = self.total_len - fix_batch * self.batch_size - fix_batch_l = index_list[0:fix_batch * self.batch_size] \ - .reshape(-1, self.batch_size) - fix_batch_l = fix_batch_l[ - :, real_batch_size * world_rank:real_batch_size * (world_rank + 1)] - real_batch_size = fix_batch_l.shape[1] - fix_batch_l = fix_batch_l.flatten() - if not self.drop_last and last_batch > 0: - last_batch_l = index_list[-last_batch:] - real_last_batch = (last_batch - 1) // world_size + 1 - l = real_last_batch * world_rank - r = l + real_last_batch - if r > last_batch: r = last_batch - if l >= r: l = r - 1 - index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]]) + fix_batch=self.total_len//self.batch_size + last_batch=self.total_len-fix_batch*self.batch_size + fix_batch_l=index_list[0:fix_batch*self.batch_size]\ + .reshape(-1,self.batch_size) + fix_batch_l=fix_batch_l[ + :,real_batch_size*world_rank:real_batch_size*(world_rank+1)] + real_batch_size=fix_batch_l.shape[1] + fix_batch_l=fix_batch_l.flatten() + if not self.drop_last and last_batch>0: + last_batch_l=index_list[-last_batch:] + real_last_batch=(last_batch-1)//world_size+1 + l=real_last_batch*world_rank + r=l+real_last_batch + if r>last_batch: r=last_batch + if l>=r: l=r-1 + index_list=np.concatenate([fix_batch_l,last_batch_l[l:r]]) else: - index_list = fix_batch_l + index_list=fix_batch_l - self.real_len = len(index_list) - self.real_batch_size = real_batch_size - assert self.total_len // self.batch_size == \ - self.real_len // self.real_batch_size, f"Number of batches({self.total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {self.total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}" + self.real_len=len(index_list) + self.real_batch_size=real_batch_size + assert self.total_len//self.batch_size==\ + self.real_len//self.real_batch_size,f"Number of batches({self.total_len//self.batch_size}!={self.real_len//self.real_batch_size}) not match, total_len: {self.total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}" else: - self.real_len = self.total_len - self.real_batch_size = self.batch_size + self.real_len=self.total_len + self.real_batch_size=self.batch_size - self.batch_len = self.__batch_len__() + self.batch_len=self.__batch_len__() - if not hasattr(self, "workers") and self.num_workers: + if not hasattr(self,"workers") and self.num_workers: self._init_workers() if self.num_workers: self._stop_all_workers() - self.index_list_numpy[:] = index_list - gid_obj = self.gid.get_obj() - gid_lock = 
self.gid.get_lock() + self.index_list_numpy[:]=index_list + gid_obj=self.gid.get_obj() + gid_lock=self.gid.get_lock() with gid_lock: - gid_obj.value = 0 + gid_obj.value=0 self.gidc.notify_all() - start = time.time() - self.batch_time = 0 + start=time.time() + self.batch_time=0 for i in range(self.batch_len): # try not get lock first - if gid_obj.value <= i: + if gid_obj.value<=i: with gid_lock: - if gid_obj.value <= i: + if gid_obj.value<=i: if mp_log_v: print("wait") self.gidc.wait() - now = time.time() - self.wait_time = now - start - start = now + now=time.time() + self.wait_time=now-start + start=now - self.last_id = i - worker_id = self.idmap[i] - w = self.workers[worker_id] + self.last_id=i + worker_id=self.idmap[i] + w=self.workers[worker_id] if mp_log_v: - print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer) - batch = w.buffer.recv() - now = time.time() - self.recv_time = now - start - start = now + print(f"#{worker_id} {os.getpid()} recv buffer",w.buffer) + batch=w.buffer.recv() + now=time.time() + self.recv_time=now-start + start=now if mp_log_v: - print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [type(b).__name__ for b in batch]) - batch = self.to_jittor(batch) - now = time.time() - self.to_jittor_time = now - start - start = now + print(f"#{worker_id} {os.getpid()} recv",type(batch).__name__,[type(b).__name__ for b in batch]) + batch=self.to_jittor(batch) + now=time.time() + self.to_jittor_time=now-start + start=now yield batch - now = time.time() - self.batch_time = now - start - start = now + now=time.time() + self.batch_time=now-start + start=now else: - batch_data = [] + batch_data=[] for idx in index_list: batch_data.append(self[int(idx)]) - if len(batch_data) == self.real_batch_size: - batch_data = self.collate_batch(batch_data) - batch_data = self.to_jittor(batch_data) + if len(batch_data)==self.real_batch_size: + batch_data=self.collate_batch(batch_data) + batch_data=self.to_jittor(batch_data) yield batch_data - batch_data = [] + batch_data=[] # depend on drop_last - if not self.drop_last and len(batch_data) > 0: - batch_data = self.collate_batch(batch_data) - batch_data = self.to_jittor(batch_data) + if not self.drop_last and len(batch_data)>0: + batch_data=self.collate_batch(batch_data) + batch_data=self.to_jittor(batch_data) yield batch_data @@ -495,28 +495,28 @@ class ImageFolder(Dataset): ... 
""" - def __init__(self, root, transform=None): + def __init__(self,root,transform=None): super().__init__() - self.root = root - self.transform = transform - self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()]) - self.class_to_idx = {v: k for k, v in enumerate(self.classes)} - self.imgs = [] - image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')) - - for i, class_name in enumerate(self.classes): - class_dir = os.path.join(root, class_name) - for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)): + self.root=root + self.transform=transform + self.classes=sorted([d.name for d in os.scandir(root) if d.is_dir()]) + self.class_to_idx={v:k for k,v in enumerate(self.classes)} + self.imgs=[] + image_exts=set(('.jpg','.jpeg','.png','.bmp','.tif','.tiff')) + + for i,class_name in enumerate(self.classes): + class_dir=os.path.join(root,class_name) + for dname,_,fnames in sorted(os.walk(class_dir,followlinks=True)): for fname in sorted(fnames): if os.path.splitext(fname)[-1].lower() in image_exts: - path = os.path.join(class_dir, fname) - self.imgs.append((path, i)) + path=os.path.join(class_dir,fname) + self.imgs.append((path,i)) LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.") self.set_attrs(total_len=len(self.imgs)) - def __getitem__(self, k): - with open(self.imgs[k][0], 'rb') as f: - img = Image.open(f).convert('RGB') + def __getitem__(self,k): + with open(self.imgs[k][0],'rb') as f: + img=Image.open(f).convert('RGB') if self.transform: - img = self.transform(img) - return img, self.imgs[k][1] \ No newline at end of file + img=self.transform(img) + return img,self.imgs[k][1] \ No newline at end of file From 98605812be5ffd2e9780d179553c3c4f2476d2b1 Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Sat, 20 Mar 2021 22:16:12 +0800 Subject: [PATCH 14/36] copy. --- python/jittor/dataset/dataset.py | 412 +++++++++++++++---------------- 1 file changed, 206 insertions(+), 206 deletions(-) diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index 9cb5106c..2f1ba092 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -1,9 +1,9 @@ # *************************************************************** -# Copyright (c) 2021 Jittor. All Rights Reserved. -# Maintainers: +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: # Meng-Hao Guo -# Dun Liang . -# +# Dun Liang . +# # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. 
# *************************************************************** @@ -12,8 +12,8 @@ import gzip import pickle import os -from jittor.dataset.utils import get_random_list,get_order_list,collate_batch,HookTimer -from collections.abc import Sequence,Mapping +from jittor.dataset.utils import get_random_list, get_order_list, collate_batch, HookTimer +from collections.abc import Sequence, Mapping import pathlib from PIL import Image import multiprocessing as mp @@ -22,23 +22,21 @@ import jittor as jt import time -dataset_root=os.path.join(pathlib.Path.home(),".cache","jittor","dataset") -mp_log_v=os.environ.get("mp_log_v",0) -mpi=jt.mpi -img_open_hook=HookTimer(Image,"open") - +dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset") +mp_log_v = os.environ.get("mp_log_v", 0) +mpi = jt.mpi +img_open_hook = HookTimer(Image, "open") class Worker: - def __init__(self,target,args,buffer_size,keep_numpy_array=False): - self.buffer=jt.RingBuffer(buffer_size) + def __init__(self, target, args, buffer_size, keep_numpy_array=False): + self.buffer = jt.RingBuffer(buffer_size) self.buffer.keep_numpy_array(keep_numpy_array) - self.status=mp.Array('f',5,lock=False) - self.p=mp.Process(target=target,args=args+(self.buffer,self.status)) - self.p.daemon=True + self.status = mp.Array('f', 5, lock=False) + self.p = mp.Process(target=target, args=args+(self.buffer,self.status)) + self.p.daemon = True self.p.start() - class Dataset(object): ''' Base class for reading data. @@ -50,54 +48,56 @@ class Dataset(object): [in] drop_last(bool): if true, the last batch of dataset might smaller than batch_size, default True. [in] num_workers(int): number of workers for loading data. [in] buffer_size(int): buffer size for each worker in bytes, default(512MB). - + Example:: class YourDataset(Dataset): def __init__(self): super().__init__() self.set_attrs(total_len=1024) + def __getitem__(self, k): return k, k*k + dataset = YourDataset().set_attrs(batch_size=256, shuffle=True) for x, y in dataset: ...... ''' def __init__(self, - batch_size=16, - shuffle=False, - drop_last=False, - num_workers=0, - buffer_size=512*1024*1024, - stop_grad=True, - keep_numpy_array=False): + batch_size = 16, + shuffle = False, + drop_last = False, + num_workers = 0, + buffer_size = 512*1024*1024, + stop_grad = True, + keep_numpy_array = False): super().__init__() - self.total_len=None - self.batch_size=batch_size - self.shuffle=shuffle - self.drop_last=drop_last - self.num_workers=num_workers - self.buffer_size=buffer_size - self.stop_grad=stop_grad - self.keep_numpy_array=keep_numpy_array - - def __getitem__(self,index): + self.total_len = None + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + self.num_workers = num_workers + self.buffer_size = buffer_size + self.stop_grad = stop_grad + self.keep_numpy_array = keep_numpy_array + + def __getitem__(self, index): raise NotImplementedError def __batch_len__(self): - assert self.total_len>=0 - assert self.batch_size>0 + assert self.total_len >= 0 + assert self.batch_size > 0 if self.drop_last: - return self.total_len//self.batch_size - return (self.total_len-1)//self.batch_size+1 + return self.total_len // self.batch_size + return (self.total_len-1) // self.batch_size + 1 def __len__(self): return self.__batch_len__() - def set_attrs(self,**kw): - ''' + def set_attrs(self, **kw): + ''' You can set attributes of dataset by using set_attrs function, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size. 
- + Example:: dataset = YourDataset().set_attrs(batch_size=256, shuffle=True) @@ -113,34 +113,34 @@ def set_attrs(self,**kw): * stop_grad: stop grad for data, default(True). ''' for k,v in kw.items(): - assert hasattr(self,k),k - setattr(self,k,v) + assert hasattr(self, k), k + setattr(self, k, v) self.reset() return self - def to_jittor(self,batch): + def to_jittor(self, batch): ''' Change batch data to jittor array, such as np.ndarray, int, and float. ''' if self.keep_numpy_array: return batch - if isinstance(batch,jt.Var): return batch - to_jt=lambda x:jt.array(x).stop_grad()\ + if isinstance(batch, jt.Var): return batch + to_jt = lambda x: jt.array(x).stop_grad() \ if self.stop_grad else jt.array(x) - if isinstance(batch,np.ndarray): + if isinstance(batch, np.ndarray): return to_jt(batch) - if not isinstance(batch,(list,tuple)): + if not isinstance(batch, (list, tuple)): return batch - new_batch=[] + new_batch = [] for a in batch: - if isinstance(a,np.ndarray) or\ - isinstance(a,int) or\ - isinstance(a,float): + if isinstance(a, np.ndarray) or \ + isinstance(a, int) or \ + isinstance(a, float): new_batch.append(to_jt(a)) else: new_batch.append(self.to_jittor(a)) return new_batch - def collate_batch(self,batch): + def collate_batch(self, batch): ''' Puts each data field into a tensor with outer dimension batch size. @@ -155,11 +155,11 @@ def terminate(self): ''' Terminate is used to terminate multi-process worker reading data. ''' - if hasattr(self,"workers"): + if hasattr(self, "workers"): for w in self.workers: w.p.terminate() - - def _worker_main(self,worker_id,buffer,status): + + def _worker_main(self, worker_id, buffer, status): import jittor_utils jittor_utils.cc.init_subprocess() jt.jt_init_subprocess() @@ -167,61 +167,61 @@ def _worker_main(self,worker_id,buffer,status): # it is not work on ubuntu 16.04. but worked on ubuntu 20.04 # it seems like the static value of parallel compiler # is not correctly init. 
- jt.flags.use_parallel_op_compiler=0 + jt.flags.use_parallel_op_compiler = 0 import time try: - gid_obj=self.gid.get_obj() - gid_lock=self.gid.get_lock() - start=time.time() + gid_obj = self.gid.get_obj() + gid_lock = self.gid.get_lock() + start = time.time() while True: # get id with gid_lock: - while gid_obj.value>=self.batch_len or buffer.is_stop(): - self.num_idle.value+=1 + while gid_obj.value >= self.batch_len or buffer.is_stop(): + self.num_idle.value += 1 self.num_idle_c.notify() self.gidc.wait() - self.num_idle.value-=1 - cid=gid_obj.value - self.idmap[cid]=worker_id - gid_obj.value+=1 + self.num_idle.value -= 1 + cid = gid_obj.value + self.idmap[cid] = worker_id + gid_obj.value += 1 self.gidc.notify() - now=time.time() - other_time=now-start - start=now + now = time.time() + other_time = now - start + start = now # load and transform data - batch=[] + batch = [] if mp_log_v: - print(f"#{worker_id} {os.getpid()} load batch",cid*self.real_batch_size, min(self.real_len,(cid+1)*self.real_batch_size)) - for i in range(cid*self.real_batch_size,min(self.real_len,(cid+1)*self.real_batch_size)): + print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)) + for i in range(cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)): batch.append(self[self.index_list[i]]) - batch=self.collate_batch(batch) - now=time.time() - data_time=now-start - start=now + batch = self.collate_batch(batch) + now = time.time() + data_time = now - start + start = now # send data to main process if mp_log_v: - print(f"#{worker_id} {os.getpid()} send",type(batch).__name__,[type(b).__name__ for b in batch], buffer) + print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer) try: buffer.send(batch) except: if buffer.is_stop(): continue raise - now=time.time() - send_time=now-start - start=now - status[0],status[1],status[2],status[3],status[4]=\ - other_time,data_time,send_time,\ - other_time+data_time+send_time,\ + now = time.time() + send_time = now - start + start = now + status[0], status[1], status[2], status[3], status[4] = \ + other_time, data_time, send_time, \ + other_time + data_time + send_time, \ img_open_hook.duration - img_open_hook.duration=0.0 + img_open_hook.duration = 0.0 except: import traceback - line=traceback.format_exc() + line = traceback.format_exc() print(line) - os.kill(os.getppid(),signal.SIGINT) + os.kill(os.getppid(), signal.SIGINT) exit(0) def display_worker_status(self): @@ -261,7 +261,7 @@ def display_worker_status(self): * buffer: ring buffer status, such as how many free space, left index, right index, total size(bytes). 
Example:: - + from jittor.dataset import Dataset class YourDataset(Dataset): pass @@ -269,17 +269,17 @@ class YourDataset(Dataset): for x, y in dataset: dataset.display_worker_status() ''' - if not hasattr(self,"workers"): + if not hasattr(self, "workers"): return - msg=[""] + msg = [""] msg.append(f"progress:{self.last_id}/{self.batch_len}") msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}") msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}") - msg.append(f"last 10 workers: {self.idmap[max(0,self.last_id-9):self.last_id+1]}") + msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id-9):self.last_id+1]}") msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)") for i in range(self.num_workers): - w=self.workers[i] - s=w.status + w = self.workers[i] + s = w.status msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}") LOG.i('\n'.join(msg)) @@ -288,44 +288,44 @@ def _stop_all_workers(self): for w in self.workers: w.buffer.stop() # wait until all workers idle - if self.num_idle.value [012] if jt.in_mpi: - world_size=mpi.world_size() - world_rank=mpi.world_rank() - index_list=np.int32(index_list) - mpi.broadcast(index_list,0) + world_size = mpi.world_size() + world_rank = mpi.world_rank() + index_list = np.int32(index_list) + mpi.broadcast(index_list, 0) - assert self.batch_size>=world_size,\ + assert self.batch_size >= world_size, \ f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})" - real_batch_size=(self.batch_size-1)//world_size+1 - if real_batch_size*world_size!=self.batch_size: + real_batch_size = (self.batch_size-1) // world_size + 1 + if real_batch_size * world_size != self.batch_size: LOG.w("Batch size is not divisible by MPI world size, " "The distributed version may be different from " "the single-process version.") - fix_batch=self.total_len//self.batch_size - last_batch=self.total_len-fix_batch*self.batch_size - fix_batch_l=index_list[0:fix_batch*self.batch_size]\ + fix_batch = self.total_len // self.batch_size + last_batch = self.total_len - fix_batch * self.batch_size + fix_batch_l = index_list[0:fix_batch*self.batch_size] \ .reshape(-1,self.batch_size) - fix_batch_l=fix_batch_l[ - :,real_batch_size*world_rank:real_batch_size*(world_rank+1)] - real_batch_size=fix_batch_l.shape[1] - fix_batch_l=fix_batch_l.flatten() - if not self.drop_last and last_batch>0: - last_batch_l=index_list[-last_batch:] - real_last_batch=(last_batch-1)//world_size+1 - l=real_last_batch*world_rank - r=l+real_last_batch - if r>last_batch: r=last_batch - if l>=r: l=r-1 - index_list=np.concatenate([fix_batch_l,last_batch_l[l:r]]) + fix_batch_l = fix_batch_l[ + :,real_batch_size*world_rank:real_batch_size*(world_rank+1)] + real_batch_size = fix_batch_l.shape[1] + fix_batch_l = fix_batch_l.flatten() + if not self.drop_last and last_batch > 0: + last_batch_l = index_list[-last_batch:] + real_last_batch = (last_batch-1)//world_size+1 + l = real_last_batch * world_rank + r = l + real_last_batch + if r > last_batch: r = last_batch + if l >= r: l = r-1 + index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]]) else: - index_list=fix_batch_l + index_list = fix_batch_l - self.real_len=len(index_list) - self.real_batch_size=real_batch_size - assert self.total_len//self.batch_size==\ - self.real_len//self.real_batch_size,f"Number of batches({self.total_len//self.batch_size}!={self.real_len//self.real_batch_size}) not match, total_len: {self.total_len}, batch_size: 
{self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}" + self.real_len = len(index_list) + self.real_batch_size = real_batch_size + assert self.total_len // self.batch_size == \ + self.real_len // self.real_batch_size, f"Number of batches({self.total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {self.total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}" else: - self.real_len=self.total_len - self.real_batch_size=self.batch_size - - self.batch_len=self.__batch_len__() - - if not hasattr(self,"workers") and self.num_workers: + self.real_len = self.total_len + self.real_batch_size = self.batch_size + + self.batch_len = self.__batch_len__() + + if not hasattr(self, "workers") and self.num_workers: self._init_workers() - + if self.num_workers: self._stop_all_workers() - self.index_list_numpy[:]=index_list - gid_obj=self.gid.get_obj() - gid_lock=self.gid.get_lock() + self.index_list_numpy[:] = index_list + gid_obj = self.gid.get_obj() + gid_lock = self.gid.get_lock() with gid_lock: - gid_obj.value=0 + gid_obj.value = 0 self.gidc.notify_all() - start=time.time() - self.batch_time=0 + start = time.time() + self.batch_time = 0 for i in range(self.batch_len): # try not get lock first - if gid_obj.value<=i: + if gid_obj.value <= i: with gid_lock: - if gid_obj.value<=i: + if gid_obj.value <= i: if mp_log_v: print("wait") self.gidc.wait() - now=time.time() - self.wait_time=now-start - start=now + now = time.time() + self.wait_time = now - start + start = now - self.last_id=i - worker_id=self.idmap[i] - w=self.workers[worker_id] + self.last_id = i + worker_id = self.idmap[i] + w = self.workers[worker_id] if mp_log_v: - print(f"#{worker_id} {os.getpid()} recv buffer",w.buffer) - batch=w.buffer.recv() - now=time.time() - self.recv_time=now-start - start=now + print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer) + batch = w.buffer.recv() + now = time.time() + self.recv_time = now - start + start = now if mp_log_v: - print(f"#{worker_id} {os.getpid()} recv",type(batch).__name__,[type(b).__name__ for b in batch]) - batch=self.to_jittor(batch) - now=time.time() - self.to_jittor_time=now-start - start=now + print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ]) + batch = self.to_jittor(batch) + now = time.time() + self.to_jittor_time = now - start + start = now yield batch - now=time.time() - self.batch_time=now-start - start=now + now = time.time() + self.batch_time = now - start + start = now else: - batch_data=[] + batch_data = [] for idx in index_list: batch_data.append(self[int(idx)]) - if len(batch_data)==self.real_batch_size: - batch_data=self.collate_batch(batch_data) - batch_data=self.to_jittor(batch_data) + if len(batch_data) == self.real_batch_size: + batch_data = self.collate_batch(batch_data) + batch_data = self.to_jittor(batch_data) yield batch_data - batch_data=[] + batch_data = [] # depend on drop_last - if not self.drop_last and len(batch_data)>0: - batch_data=self.collate_batch(batch_data) - batch_data=self.to_jittor(batch_data) + if not self.drop_last and len(batch_data) > 0: + batch_data = self.collate_batch(batch_data) + batch_data = self.to_jittor(batch_data) yield batch_data class ImageFolder(Dataset): """ A image classify dataset, load image and label from directory:: - + * root/label1/img1.png * root/label1/img2.png * ... @@ -495,28 +495,28 @@ class ImageFolder(Dataset): ... 
""" - def __init__(self,root,transform=None): + def __init__(self, root, transform=None): super().__init__() - self.root=root - self.transform=transform - self.classes=sorted([d.name for d in os.scandir(root) if d.is_dir()]) - self.class_to_idx={v:k for k,v in enumerate(self.classes)} - self.imgs=[] - image_exts=set(('.jpg','.jpeg','.png','.bmp','.tif','.tiff')) - - for i,class_name in enumerate(self.classes): - class_dir=os.path.join(root,class_name) - for dname,_,fnames in sorted(os.walk(class_dir,followlinks=True)): + self.root = root + self.transform = transform + self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()]) + self.class_to_idx = {v:k for k,v in enumerate(self.classes)} + self.imgs = [] + image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')) + + for i, class_name in enumerate(self.classes): + class_dir = os.path.join(root, class_name) + for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)): for fname in sorted(fnames): if os.path.splitext(fname)[-1].lower() in image_exts: - path=os.path.join(class_dir,fname) - self.imgs.append((path,i)) + path = os.path.join(class_dir, fname) + self.imgs.append((path, i)) LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.") self.set_attrs(total_len=len(self.imgs)) - - def __getitem__(self,k): - with open(self.imgs[k][0],'rb') as f: - img=Image.open(f).convert('RGB') + + def __getitem__(self, k): + with open(self.imgs[k][0], 'rb') as f: + img = Image.open(f).convert('RGB') if self.transform: - img=self.transform(img) - return img,self.imgs[k][1] \ No newline at end of file + img = self.transform(img) + return img, self.imgs[k][1] \ No newline at end of file From 9711e85b37f32d9b2afb2d5a5ff475fb9a609453 Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Sat, 20 Mar 2021 22:20:34 +0800 Subject: [PATCH 15/36] copy. 
--- python/jittor/dataset/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index 2f1ba092..e77c980d 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -519,4 +519,4 @@ def __getitem__(self, k): img = Image.open(f).convert('RGB') if self.transform: img = self.transform(img) - return img, self.imgs[k][1] \ No newline at end of file + return img, self.imgs[k][1] From 1d76ae2a680d40558d31c20dc0602b003ab4816a Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Mon, 22 Mar 2021 10:33:00 +0800 Subject: [PATCH 16/36] delete .idea --- python/jittor/dataset/.idea/.gitignore | 2 -- .../.idea/codeStyles/codeStyleConfig.xml | 5 ---- python/jittor/dataset/.idea/dataset.iml | 9 -------- .../inspectionProfiles/Project_Default.xml | 23 ------------------- python/jittor/dataset/.idea/misc.xml | 6 ----- python/jittor/dataset/.idea/modules.xml | 8 ------- python/jittor/dataset/.idea/vcs.xml | 6 ----- 7 files changed, 59 deletions(-) delete mode 100644 python/jittor/dataset/.idea/.gitignore delete mode 100644 python/jittor/dataset/.idea/codeStyles/codeStyleConfig.xml delete mode 100644 python/jittor/dataset/.idea/dataset.iml delete mode 100644 python/jittor/dataset/.idea/inspectionProfiles/Project_Default.xml delete mode 100644 python/jittor/dataset/.idea/misc.xml delete mode 100644 python/jittor/dataset/.idea/modules.xml delete mode 100644 python/jittor/dataset/.idea/vcs.xml diff --git a/python/jittor/dataset/.idea/.gitignore b/python/jittor/dataset/.idea/.gitignore deleted file mode 100644 index 5c98b428..00000000 --- a/python/jittor/dataset/.idea/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -# Default ignored files -/workspace.xml \ No newline at end of file diff --git a/python/jittor/dataset/.idea/codeStyles/codeStyleConfig.xml b/python/jittor/dataset/.idea/codeStyles/codeStyleConfig.xml deleted file mode 100644 index a55e7a17..00000000 --- a/python/jittor/dataset/.idea/codeStyles/codeStyleConfig.xml +++ /dev/null @@ -1,5 +0,0 @@ - - - - \ No newline at end of file diff --git a/python/jittor/dataset/.idea/dataset.iml b/python/jittor/dataset/.idea/dataset.iml deleted file mode 100644 index d6ebd480..00000000 --- a/python/jittor/dataset/.idea/dataset.iml +++ /dev/null @@ -1,9 +0,0 @@ - - - - - - - - - \ No newline at end of file diff --git a/python/jittor/dataset/.idea/inspectionProfiles/Project_Default.xml b/python/jittor/dataset/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index 95a8d909..00000000 --- a/python/jittor/dataset/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,23 +0,0 @@ - - - - \ No newline at end of file diff --git a/python/jittor/dataset/.idea/misc.xml b/python/jittor/dataset/.idea/misc.xml deleted file mode 100644 index 37e641e9..00000000 --- a/python/jittor/dataset/.idea/misc.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/python/jittor/dataset/.idea/modules.xml b/python/jittor/dataset/.idea/modules.xml deleted file mode 100644 index 77bc753c..00000000 --- a/python/jittor/dataset/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/python/jittor/dataset/.idea/vcs.xml b/python/jittor/dataset/.idea/vcs.xml deleted file mode 100644 index c2365ab1..00000000 --- a/python/jittor/dataset/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file From fa62b3055078a10e7fcc26fb65261dfa114e7b19 Mon Sep 17 
00:00:00 2001 From: Dun Liang Date: Tue, 23 Mar 2021 18:56:50 +0800 Subject: [PATCH 17/36] add sampler hook in dataset --- python/jittor/dataset/dataset.py | 24 ++++++++++++--- python/jittor/dataset/sampler.py | 49 ++++++++++++++++++++---------- python/jittor/test/test_sampler.py | 23 +++++++++++--- 3 files changed, 70 insertions(+), 26 deletions(-) diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index e77c980d..423c76d5 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -80,6 +80,7 @@ def __init__(self, self.buffer_size = buffer_size self.stop_grad = stop_grad self.keep_numpy_array = keep_numpy_array + self.sampler = None def __getitem__(self, index): raise NotImplementedError @@ -343,10 +344,23 @@ def __del__(self): print("dataset deleted") self.terminate() + def __real_len__(self): + if self.total_len is None: + self.total_len = len(self) + return self.total_len + def __iter__(self): if self.total_len is None: self.total_len = len(self) - if self.shuffle == False: + # maybe rewrite by sampler + total_len = self.total_len + if self.sampler: + index_list = list(self.sampler.__iter__()) + total_len = len(index_list) + # check is not batch sampler + if len(index_list): + assert not isinstance(index_list[0], (list,tuple)), "Batch sampler not support yet." + elif self.shuffle == False: index_list = get_order_list(self.total_len) else: index_list = get_random_list(self.total_len) @@ -373,8 +387,8 @@ def __iter__(self): LOG.w("Batch size is not divisible by MPI world size, " "The distributed version may be different from " "the single-process version.") - fix_batch = self.total_len // self.batch_size - last_batch = self.total_len - fix_batch * self.batch_size + fix_batch = total_len // self.batch_size + last_batch = total_len - fix_batch * self.batch_size fix_batch_l = index_list[0:fix_batch*self.batch_size] \ .reshape(-1,self.batch_size) fix_batch_l = fix_batch_l[ @@ -394,8 +408,8 @@ def __iter__(self): self.real_len = len(index_list) self.real_batch_size = real_batch_size - assert self.total_len // self.batch_size == \ - self.real_len // self.real_batch_size, f"Number of batches({self.total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {self.total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}" + assert total_len // self.batch_size == \ + self.real_len // self.real_batch_size, f"Number of batches({total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}" else: self.real_len = self.total_len self.real_batch_size = self.batch_size diff --git a/python/jittor/dataset/sampler.py b/python/jittor/dataset/sampler.py index 0f018b07..4ace78a4 100644 --- a/python/jittor/dataset/sampler.py +++ b/python/jittor/dataset/sampler.py @@ -1,3 +1,12 @@ +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: +# Hao-Yang Peng +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. 
+# *************************************************************** import jittor as jt from .dataset import Dataset import numpy as np @@ -5,57 +14,65 @@ class Sampler(): - def __init__(self, data_source): - self.data_source = data_source + def __init__(self, dataset): + self.dataset = dataset + # MUST set sampler here + dataset.sampler = self def __iter__(self): - pass + raise NotImplementedError def __len__(self): - pass + raise NotImplementedError class SequentialSampler(Sampler): - def __init__(self, data_source): - self.data_source = data_source + def __init__(self, dataset): + # MUST set sampler here + dataset.sampler = self + self.dataset = dataset def __iter__(self): - return iter(range(len(self.data_source))) + return iter(range(self.dataset.__real_len__())) def __len__(self): - return len(self.data_source) + return self.dataset.__real_len__() class RandomSampler(Sampler): - def __init__(self, data_source, replacement=False, num_samples=None): - self.data_source = data_source + def __init__(self, dataset, replacement=False, num_samples=None): + # MUST set sampler here + dataset.sampler = self + self.dataset = dataset self.rep = replacement self._num_samples = num_samples @property def num_samples(self): if self._num_samples is None: - return len(self.data_source) + return self.dataset.__real_len__() return self._num_samples def __len__(self): return self.num_samples def __iter__(self): - n = len(self.data_source) + n = self.dataset.__real_len__() if self.rep: return iter(np.random.randint(low=0, high=n, size=(self.num_samples,), dtype=np.int64).tolist()) return iter(np.random.permutation(n).tolist()) class SubsetRandomSampler(Sampler): - def __init__(self, data_source, indice): - self.data_source = data_source + def __init__(self, dataset, indice): + # MUST set sampler here + dataset.sampler = self + self.dataset = dataset self.indices = indice - assert indice[0] >= 0 and indice[1] < data_source.total_len and indice[0] < indice[1] + assert indice[0] >= 0 and indice[1] < dataset.__real_len__() and indice[0] < indice[1] def __iter__(self): - return (self.data_source[i + self.indices[0]] for i in np.random.permutation(self.indices[1] - self.indices[0])) + return (int(i) + self.indices[0] for i in np.random.permutation(self.indices[1] - self.indices[0])) def __len__(self): return self.indices[1] - self.indices[0] diff --git a/python/jittor/test/test_sampler.py b/python/jittor/test/test_sampler.py index 824ff18a..2cb2d864 100644 --- a/python/jittor/test/test_sampler.py +++ b/python/jittor/test/test_sampler.py @@ -4,7 +4,6 @@ import numpy as np import unittest -test_img = np.random.normal(size=(40, 1, 2, 2)) class TestSamplerDataset(Dataset): @@ -13,28 +12,42 @@ def __init__(self): self.set_attrs(total_len=40, batch_size=1) def __getitem__(self, idx): - return test_img[idx:(idx + 1), ...] 
- - -testdataset = TestSamplerDataset() + return idx**2 class TestSampler(unittest.TestCase): def test_sequential_sampler(self): + testdataset = TestSamplerDataset() seqsampler = SequentialSampler(testdataset) assert len(seqsampler) == 40 for idx, batch in enumerate(seqsampler): assert idx == batch + for i, data in enumerate(testdataset): + assert data.item() == i**2 def test_random_sampler(self): + testdataset = TestSamplerDataset() randomsampler = RandomSampler(testdataset) assert len(randomsampler) == 40 + diff = 0 + for i, data in enumerate(testdataset): + diff += data.item() == i**2 + assert diff < 10 def test_subset_random_sampler(self): + testdataset = TestSamplerDataset() subsetsampler = SubsetRandomSampler(testdataset, (20, 30)) assert len(subsetsampler) == 10 + s = 0 + for i, data in enumerate(testdataset): + s += data.item() + s2 = 0 + for i in range(20,30): + s2 += i**2 + assert s == s2, (s, s2) def test_batch_sampler(self): + testdataset = TestSamplerDataset() seqforbatch = SequentialSampler(testdataset) batchsampler = BatchSampler(seqforbatch, 4, drop_last=False) assert len(batchsampler) == 10 From 74a93e78ccd2ad6a5bdfcfcc0cfda2255d9f4e22 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Tue, 23 Mar 2021 18:59:59 +0800 Subject: [PATCH 18/36] add doc --- python/jittor/dataset/sampler.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/jittor/dataset/sampler.py b/python/jittor/dataset/sampler.py index 4ace78a4..465ee344 100644 --- a/python/jittor/dataset/sampler.py +++ b/python/jittor/dataset/sampler.py @@ -65,6 +65,15 @@ def __iter__(self): class SubsetRandomSampler(Sampler): def __init__(self, dataset, indice): + ''' + testdataset = TestSamplerDataset() + subsetsampler = SubsetRandomSampler(testdataset, (20, 30)) + + for i, data in enumerate(testdataset): + # data between 20 ~ 29 + ...... + + ''' # MUST set sampler here dataset.sampler = self self.dataset = dataset From a20a1bd7c50436f37e5938ef960265856e050876 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Tue, 23 Mar 2021 19:07:01 +0800 Subject: [PATCH 19/36] add sampler --- python/jittor/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 84e20bb0..b3a0973f 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -8,7 +8,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.2.55' +__version__ = '1.2.2.56' from . import lock with lock.lock_scope(): ori_int = int From 0b9c927e51e496bf6d01eae355b1c2d47f84ed5e Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Thu, 25 Mar 2021 10:18:59 +0800 Subject: [PATCH 20/36] add onednn install+support.pass mkl test. 
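The sampler hook added above registers itself on dataset.sampler during construction, so attaching a sampler is enough to redirect iteration. A rough usage sketch, with a made-up dataset class::

    from jittor.dataset import Dataset
    from jittor.dataset.sampler import SubsetRandomSampler

    class SquareDataset(Dataset):
        def __init__(self):
            super().__init__()
            self.set_attrs(total_len=40, batch_size=1)

        def __getitem__(self, idx):
            return idx ** 2

    ds = SquareDataset()
    SubsetRandomSampler(ds, (20, 30))   # construction sets ds.sampler as a side effect
    for data in ds:                     # only visits indices 20..29, in random order
        ...
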
--- python/jittor/compile_extern.py | 95 ++++++++++++++++++++++----------- 1 file changed, 64 insertions(+), 31 deletions(-) diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 37fda96b..323739c8 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -20,21 +20,42 @@ def search_file(dirs, name): def install_mkl(root_folder): # origin url is # url = "https://github.com/intel/mkl-dnn/releases/download/v1.0.2/mkldnn_lnx_1.0.2_cpu_gomp.tgz" - url = "https://cloud.tsinghua.edu.cn/f/da02bf62b55b4aa3b8ee/?dl=1" - filename = "mkldnn_lnx_1.0.2_cpu_gomp.tgz" - fullname = os.path.join(root_folder, filename) - dirname = os.path.join(root_folder, filename.replace(".tgz","")) - - if not os.path.isfile(os.path.join(dirname, "examples", "test")): - LOG.i("Downloading mkl...") - download_url_to_local(url, filename, root_folder, "47187284ede27ad3bd64b5f0e7d5e730") + if os.environ.get("use_onednn","1")=="1": + print("get in") + url = "https://cloud.tsinghua.edu.cn/f/cd63e0df3c5c4c52b76d/?dl=1" + filename = "oneDNN-2.2-rc.tar.gz" + fullname = os.path.join(root_folder, filename) + dirname = os.path.join(root_folder, filename.replace(".tar.gz","")) + download_url_to_local(url, filename, root_folder, "fd6e22bb49dedcf0430495098b3dcf1f") import tarfile - with tarfile.open(fullname, "r") as tar: tar.extractall(root_folder) - - assert 0 == os.system(f"cd {dirname}/examples && " - f"{cc_path} -std=c++14 cpu_cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test") + import platform + if platform.machine() == "aarch64": + os.system(f"cd {dirname} && mkdir -p build && cd build && export CC=aarch64-linux-gnu-gcc && export CXX=aarch64-linux-gnu-g++ && cmake .. \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=AARCH64 \ + -DCMAKE_LIBRARY_PATH=/usr/aarch64-linux-gnu/lib \ + -D CMAKE_INSTALL_PREFIX={root_folder} && make -j && make install") + else: + os.system(f"cd {dirname} && mkdir -p build && cd build && cmake -D CMAKE_INSTALL_PREFIX={root_folder} .. && make -j && make install") + # TODO add completition test. 
+ else: + url = "https://cloud.tsinghua.edu.cn/f/da02bf62b55b4aa3b8ee/?dl=1" + filename = "mkldnn_lnx_1.0.2_cpu_gomp.tgz" + fullname = os.path.join(root_folder, filename) + dirname = os.path.join(root_folder, filename.replace(".tgz","")) + + if not os.path.isfile(os.path.join(dirname, "examples", "test")): + LOG.i("Downloading mkl...") + download_url_to_local(url, filename, root_folder, "47187284ede27ad3bd64b5f0e7d5e730") + import tarfile + + with tarfile.open(fullname, "r") as tar: + tar.extractall(root_folder) + + assert 0 == os.system(f"cd {dirname}/examples && " + f"{cc_path} -std=c++14 cpu_cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test") def setup_mkl(): global mkl_ops, use_mkl @@ -43,25 +64,35 @@ def setup_mkl(): if not use_mkl: return mkl_include_path = os.environ.get("mkl_include_path") mkl_lib_path = os.environ.get("mkl_lib_path") - + if mkl_lib_path is None or mkl_include_path is None: - mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh") - LOG.v("setup mkl...") - # mkl_path = os.path.join(cache_path, "mkl") - # mkl_path decouple with cc_path - from pathlib import Path - mkl_path = os.path.join(str(Path.home()), ".cache", "jittor", "mkl") - - make_cache_dir(mkl_path) - install_mkl(mkl_path) - mkl_home = "" - for name in os.listdir(mkl_path): - if name.startswith("mkldnn_lnx") and os.path.isdir(os.path.join(mkl_path, name)): - mkl_home = os.path.join(mkl_path, name) - break - assert mkl_home!="" - mkl_include_path = os.path.join(mkl_home, "include") - mkl_lib_path = os.path.join(mkl_home, "lib") + if os.environ.get("use_onednn","1")=="1": + LOG.v("setup onednn...") + from pathlib import Path + one_path = os.path.join(str(Path.home()),".cache", "jittor", "one") + make_cache_dir(one_path) + mkl_include_path = os.path.join(one_path,"include") + mkl_lib_path = os.path.join(one_path,"lib") + if not os.path.isdir(mkl_include_path) or not os.path.isdir(mkl_lib_path) or not os.path.isfile(os.path.join(mkl_lib_path, "libmkldnn.so")): + install_mkl(one_path) + else: + mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh") + LOG.v("setup mkl...") + # mkl_path = os.path.join(cache_path, "mkl") + # mkl_path decouple with cc_path + from pathlib import Path + mkl_path = os.path.join(str(Path.home()), ".cache", "jittor", "mkl") + + make_cache_dir(mkl_path) + install_mkl(mkl_path) + mkl_home = "" + for name in os.listdir(mkl_path): + if name.startswith("mkldnn_lnx") and os.path.isdir(os.path.join(mkl_path, name)): + mkl_home = os.path.join(mkl_path, name) + break + assert mkl_home!="" + mkl_include_path = os.path.join(mkl_home, "include") + mkl_lib_path = os.path.join(mkl_home, "lib") mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so") assert os.path.isdir(mkl_include_path) @@ -72,7 +103,9 @@ def setup_mkl(): LOG.v(f"mkl_lib_name: {mkl_lib_name}") # We do not link manualy, link in custom ops # ctypes.CDLL(mkl_lib_name, dlopen_flags) - + print(f"mkl_include_path: {mkl_include_path}") + print(f"mkl_lib_path: {mkl_lib_path}") + print(f"mkl_lib_name: {mkl_lib_name}") mkl_op_dir = os.path.join(jittor_path, "extern", "mkl", "ops") mkl_op_files = [os.path.join(mkl_op_dir, name) for name in os.listdir(mkl_op_dir)] mkl_ops = compile_custom_ops(mkl_op_files, From b001f87ecc57d6bce5c54722ad4188ea52397c7d Mon Sep 17 00:00:00 2001 From: Exusial <2247838039@qq.com> Date: Fri, 16 Apr 2021 08:11:53 +0800 Subject: [PATCH 21/36] add opencl. 
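Which branch of install_mkl above runs is chosen by the use_onednn environment variable that setup_mkl reads. A minimal sketch of selecting the backend before Jittor compiles its extern ops::

    import os

    # pick the DNN backend before jittor compiles its extern ops
    os.environ["use_onednn"] = "1"    # build oneDNN 2.2 from source with cmake
    # os.environ["use_onednn"] = "0"  # use the prebuilt mkldnn_lnx tarball instead
    import jittor as jt
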
--- opencl/inc/opencl_matmul_helper.h | 7 + opencl/lib/kernels.cl | 1327 +++++++++++++++++++++++++++++ opencl/lib/settings.h | 74 ++ opencl/ops/minimal.cpp | 151 ++++ opencl/ops/opencl_matmul_op.cc | 73 ++ opencl/ops/opencl_matmul_op.h | 17 + opencl/src/clGEMM.cc | 427 ++++++++++ opencl/src/kernels.cl | 1326 ++++++++++++++++++++++++++++ opencl/src/settings.h | 74 ++ python/jittor/compile_extern.py | 31 + 10 files changed, 3507 insertions(+) create mode 100644 opencl/inc/opencl_matmul_helper.h create mode 100644 opencl/lib/kernels.cl create mode 100644 opencl/lib/settings.h create mode 100644 opencl/ops/minimal.cpp create mode 100644 opencl/ops/opencl_matmul_op.cc create mode 100644 opencl/ops/opencl_matmul_op.h create mode 100644 opencl/src/clGEMM.cc create mode 100644 opencl/src/kernels.cl create mode 100644 opencl/src/settings.h diff --git a/opencl/inc/opencl_matmul_helper.h b/opencl/inc/opencl_matmul_helper.h new file mode 100644 index 00000000..d08d68b6 --- /dev/null +++ b/opencl/inc/opencl_matmul_helper.h @@ -0,0 +1,7 @@ +#pragma once +namespace jittor{ +// ================================================================================================= +void myclblas(float* A, float* B, float* C, + int K, int M, int N); +// ================================================================================================= +} \ No newline at end of file diff --git a/opencl/lib/kernels.cl b/opencl/lib/kernels.cl new file mode 100644 index 00000000..875ffbc5 --- /dev/null +++ b/opencl/lib/kernels.cl @@ -0,0 +1,1327 @@ + +// ================================================================================================= +// Project: +// Exploring the performance of general matrix-multiplication on an NVIDIA Tesla K40m GPU. +// +// File information: +// Institution.... SURFsara +// Author......... Cedric Nugteren +// Changed at..... 2014-11-06 +// License........ MIT license +// Tab-size....... 4 spaces +// Line length.... 100 characters +// +// ================================================================================================= +// +// Matrices in column-major format +// A: K columns, M rows +// B: N columns, K rows +// C: N columns, M rows +// +// N +// o-----o +// | | +// K | [B] | +// | | +// o-----o +// K N +// o-------o o-----o +// M | [A] | M | [C] | +// | | | | +// o-------o o-----o +// +// +// C-code for column-major matrix multiplication with alpha=1 and beta=0: +// +// for (int m=0; m +// Author......... Cedric Nugteren +// Changed at..... 2014-11-07 +// License........ MIT license +// Tab-size....... 4 spaces +// Line length.... 100 characters +// +// ================================================================================================= + +// Select a kernel +#define KERNEL 8 + +// Constants for kernels 1 -- 5 +#define TS 32 // The square-root of the 2D tile-size (== work-group dims) + +// Constants for kernels 3, 5 +#define WPT 8 // The amount of work-per-thread, i.e. 
the thread-coarsening factor +#define RTS (TS/WPT) // The reduced tile-size in one dimension + +// Constants for kernels 4, 7 -- 10 +#define WIDTH 4 // The vector-width (in number of floats) + +// Constants for kernel 5 +#define TSDK 16 // The tile-size in dimension K (for kernel 5 only) +#define LPT ((TSDK*WPT)/(TS)) // The amount of loads-per-thread (assume TSN==TSM) + +// Constants for kernels 6 -- 10 +#define TSM 128 // The tile-size in dimension M +#define TSN 128 // The tile-size in dimension N +#define TSK 16 // The tile-size in dimension K +#define WPTM 8 // The amount of work-per-thread in dimension M +#define WPTN 8 // The amount of work-per-thread in dimension N +#define RTSM (TSM/WPTM) // The reduced tile-size in dimension M (== number of threads) +#define RTSN (TSN/WPTN) // The reduced tile-size in dimension N (== number of threads) +#define LPTA ((TSK*WPTM*WPTN)/(TSN)) // The amount of loads-per-thread for A +#define LPTB ((TSK*WPTM*WPTN)/(TSM)) // The amount of loads-per-thread for B + +// Constraints on settings for kernels 6 -- 10 +// Note: TSM/WPTM has to be integer +// Note: TSN/WPTN has to be integer +// Note: TSM/WIDTH has to be integer +// Note: TSN/WIDTH has to be integer +// Note: (TSK*WPTM*WPTN)/(TSN*WIDTH) has to be integer +// Note: (TSK*WPTM*WPTN)/(TSM*WIDTH) has to be integer + +// Constants for kernel 11 (mimicing clBlas) +#define THREADSX 8 +#define THREADSY 8 +#define RX 8 +#define RY 4 +#define RK (RY) + +// Constants for the supporting transpose kernel +#define TRANSPOSEX 16 +#define TRANSPOSEY 16 + +// Constants for the supporting padding kernels +#define PADDINGX 16 +#define PADDINGY 16 + +// Macros for host and kernel code +#define MIN(a,b) ((a) > (b)) ? (b) : (a) +#define MAX(a,b) ((a) > (b)) ? (a) : (b) +#define CEIL_DIV(x,y) (((x) + (y) - 1) / (y)) +#define MOD2(x,y) ((x) % (y)) +#define DIV2(x,y) ((x) / (y)) + +// ================================================================================================= diff --git a/opencl/ops/minimal.cpp b/opencl/ops/minimal.cpp new file mode 100644 index 00000000..151af872 --- /dev/null +++ b/opencl/ops/minimal.cpp @@ -0,0 +1,151 @@ + +// ================================================================================================= +// Project: +// Exploring the performance of general matrix-multiplication on an NVIDIA Tesla K40m GPU. +// +// File information: +// Institution.... SURFsara +// Author......... Cedric Nugteren +// Changed at..... 2014-11-07 +// License........ MIT license +// Tab-size....... 4 spaces +// Line length.... 100 characters +// +// Compilation example: +// g++ -O3 -I$OPENCL_DIR/include minimal.cpp -o minimal -lOpenCL +// +// ================================================================================================= + +// Includes +#include +#include +#include +#include "opencl_matmul_helper.h" +// ================================================================================================= + +// Repeat all kernels multiple times to get an average timing result +#define NUM_RUNS 1 + +// Size of the matrices - K, M, N (squared) +#define SIZE 64 + +// Threadblock sizes (e.g. 
for kernels myGEMM1 or myGEMM2) +#define TS 1 + +// ================================================================================================= + +// Set the kernel as a string (better to do this in a separate file though) +namespace jittor{ +const char *kernelstring = + "__kernel void myGEMM1(const int M, const int N, const int K," + " const __global float* A," + " const __global float* B," + " __global float* C) {" + " const int globalRow = get_global_id(0);" + " const int globalCol = get_global_id(1);" + " float acc = 0.0f;" + " for (int k=0; k>> Initializing OpenCL...\n"); + cl_platform_id platform = 0; + clGetPlatformIDs(1, &platform, NULL); + cl_device_id device = 0; + clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &device, NULL); + cl_context context = clCreateContext(NULL, 1, &device, NULL, NULL, NULL); + cl_command_queue queue = clCreateCommandQueue(context, device, 0, NULL); + char deviceName[1024]; + clGetDeviceInfo(device, CL_DEVICE_NAME, 1024, deviceName, NULL); + cl_event event = NULL; + + // Compile the kernel + cl_program program = clCreateProgramWithSource(context, 1, &kernelstring, NULL, NULL); + clBuildProgram(program, 0, NULL, "", NULL, NULL); + // Check for compilation errors + size_t logSize; + clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, 0, NULL, &logSize); + char* messages = (char*)malloc((1+logSize)*sizeof(char)); + clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, logSize, messages, NULL); + messages[logSize] = '\0'; + if (logSize > 10) { printf(">>> Compiler message: %s\n", messages); } + free(messages); + + // Prepare OpenCL memory objects + cl_mem bufA = clCreateBuffer(context, CL_MEM_READ_ONLY, M*K*sizeof(float), NULL, NULL); + cl_mem bufB = clCreateBuffer(context, CL_MEM_READ_ONLY, K*N*sizeof(float), NULL, NULL); + cl_mem bufC = clCreateBuffer(context, CL_MEM_READ_WRITE, M*N*sizeof(float), NULL, NULL); + + // Copy matrices to the GPU + clEnqueueWriteBuffer(queue, bufA, CL_TRUE, 0, M*K*sizeof(float), A, 0, NULL, NULL); + clEnqueueWriteBuffer(queue, bufB, CL_TRUE, 0, K*N*sizeof(float), B, 0, NULL, NULL); + clEnqueueWriteBuffer(queue, bufC, CL_TRUE, 0, M*N*sizeof(float), C, 0, NULL, NULL); + + // Configure the myGEMM kernel and set its arguments + cl_kernel kernel = clCreateKernel(program, "myGEMM1", NULL); + clSetKernelArg(kernel, 0, sizeof(int), (void*)&M); + clSetKernelArg(kernel, 1, sizeof(int), (void*)&N); + clSetKernelArg(kernel, 2, sizeof(int), (void*)&K); + clSetKernelArg(kernel, 3, sizeof(cl_mem), (void*)&bufA); + clSetKernelArg(kernel, 4, sizeof(cl_mem), (void*)&bufB); + clSetKernelArg(kernel, 5, sizeof(cl_mem), (void*)&bufC); + + // Start the timed loop + printf(">>> Starting %d myGEMM runs...\n", NUM_RUNS); + gettimeofday(&Tvalue, &dummy); + double starttime = (double)Tvalue.tv_sec + 1.0e-6*((double)Tvalue.tv_usec); + for (int r=0; r>> Done: took %.3lf seconds per run, %.1lf GFLOPS\n", runtime, gflop/runtime); + + // Copy the output matrix C back to the CPU memory + clEnqueueReadBuffer(queue, bufC, CL_TRUE, 0, M*N*sizeof(float), C, 0, NULL, NULL); + for(int i=0;i +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. 
+// *************************************************************** + +#include "var.h" +#include "opencl_matmul_helper.h" +#include "opencl_matmul_op.h" +#include "common.h" +using namespace std; + +namespace jittor { + +#ifndef JIT + +OpenclMatmulOp::OpenclMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b) + : a(a), b(b), trans_a(trans_a), trans_b(trans_b) { + // TODO: support int8 * int8 + ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same"; + // TODO: support diffrent input type + ASSERT(a->dtype().dsize() == b->dtype().dsize()) << "type of two inputs should be the same"; + c = create_output(nullptr, a->dtype()); +} + +void OpenclMatmulOp::infer_shape() { + ASSERTop(a->shape.size(),==,2); + ASSERTop(b->shape.size(),==,2); + int n = a->shape[0], m = a->shape[1]; + int m_ = b->shape[0], k = b->shape[1]; + if (trans_a) { + swap(n, m); + } + if (trans_b) { + swap(m_, k); + } + ASSERTop(m,==,m_); + c->set_shape({n, k}); +} + +void OpenclMatmulOp::jit_prepare(JK& jk) { + jk << _CS("[T:") << a->dtype(); + jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); + jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N'); + jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D'); + jk << ']'; +} + +#else // JIT +#pragma clang diagnostic ignored "-Wtautological-compare" +void OpenclMatmulOp::jit_run() { + const auto& as = a->shape; + const auto& bs = b->shape; + auto n = as[0]; + auto m = as[1]; + auto k = bs[1]; + if ('@Trans_a'=='T') { + n = as[1]; + m = as[0]; + } + if ('@Trans_b'=='T') { + k = bs[0]; + } + // a: [n,m], b: [m,k], c: [n,k] + myclblas(a->ptr(),b->ptr(),c->ptr(),k,m,n); +} +#endif // JIT + +} // jittor diff --git a/opencl/ops/opencl_matmul_op.h b/opencl/ops/opencl_matmul_op.h new file mode 100644 index 00000000..fd506193 --- /dev/null +++ b/opencl/ops/opencl_matmul_op.h @@ -0,0 +1,17 @@ +#pragma once +#include "op.h" + +namespace jittor { + +struct OpenclMatmulOp : Op { + Var* a, * b, * c; + bool trans_a, trans_b; + OpenclMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b); + + const char* name() const override { return "opencl_matmul"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor + diff --git a/opencl/src/clGEMM.cc b/opencl/src/clGEMM.cc new file mode 100644 index 00000000..e874b067 --- /dev/null +++ b/opencl/src/clGEMM.cc @@ -0,0 +1,427 @@ + +// ================================================================================================= +// Project: +// Exploring the performance of general matrix-multiplication on an NVIDIA Tesla K40m GPU. +// +// File information: +// Institution.... SURFsara +// Author......... Cedric Nugteren +// Changed at..... 2014-11-17 +// License........ MIT license +// Tab-size....... 4 spaces +// Line length.... 
100 characters +// +// ================================================================================================= + +#include +#include "opencl_matmul_helper.h" +#include +#include +// Set the locations of the OpenCL kernel files +#define CL_INCLUDE_FILE "settings.h" +#define CL_KERNEL_FILE "kernels.cl" +// Determine the location where to output the PTX code +#define CL_PTX_FILE "bin/myGEMM.cl.ptx" + +// Define OpenCL compiler options, such as "-cl-nv-maxrregcount=127" +#define COMPILER_OPTIONS "" + +namespace jittor{ +// Forward declaration of the OpenCL error checking function +void checkError(cl_int error, int line); + +// ================================================================================================= +// ================================================================================================= + +// Matrix-multiplication using a custom OpenCL SGEMM kernel. This function also copies the input +// matrices to the GPU, runs SGEMM, and copies the output matrix back to the CPU. +void myclblas(float* A, float* B, float* C, + int K, int M, int N) { + + // In case of myGEMM10, compute matrix sizes K, M, N as rounded-up to form complete tiles + #if KERNEL == 10 + int K_XL = CEIL_DIV(K, TSK) * TSK; + int M_XL = CEIL_DIV(M, TSM) * TSM; + int N_XL = CEIL_DIV(N, TSN) * TSN; + #else + int K_XL = K; + int M_XL = M; + int N_XL = N; + #endif + + // Define OpenCL variables + cl_int err; + cl_platform_id platform = 0; + cl_device_id device = 0; + cl_device_id devices[MAX_NUM_DEVICES]; + cl_uint numDevices = 0; + cl_context_properties props[3] = {CL_CONTEXT_PLATFORM, 0, 0}; + cl_context context = 0; + cl_command_queue queue = 0; + cl_event event = NULL; + cl_program program = NULL; + char deviceName[MAX_DEVICE_NAME]; + + // Configure the OpenCL environment + err = clGetPlatformIDs(1, &platform, NULL); + err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 0, NULL, &numDevices); + err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, numDevices, devices, NULL); + device = devices[CURRENT_DEVICE]; + props[1] = (cl_context_properties)platform; + context = clCreateContext(props, 1, &device, NULL, NULL, &err); + queue = clCreateCommandQueue(context, device, 0, &err); + err = clGetDeviceInfo(device, CL_DEVICE_NAME, MAX_DEVICE_NAME, deviceName, NULL); + checkError(err,__LINE__); + //printf("## %d devices, running on %d: '%s'\n", numDevices, CURRENT_DEVICE, deviceName); + + // Read the kernel file from disk + long sizeHeader, sizeSource; + char* header = readKernelFile(CL_INCLUDE_FILE, &sizeHeader); + char* source = readKernelFile(CL_KERNEL_FILE, &sizeSource); + long size = 2 + sizeHeader + sizeSource; + char* code = (char*)malloc(size*sizeof(char)); + for (int c=0; c 10) { printf("## Compiler message: %s\n", messages); } + free(messages); + + // Retrieve the PTX code from the OpenCL compiler and output it to disk + size_t binSize; + err = clGetProgramInfo(program, CL_PROGRAM_BINARY_SIZES, sizeof(size_t), &binSize, NULL); + checkError(err,__LINE__); + unsigned char *bin = (unsigned char *)malloc(binSize); + err = clGetProgramInfo(program, CL_PROGRAM_BINARIES, sizeof(unsigned char *), &bin, NULL); + checkError(err,__LINE__); + FILE* file = fopen(CL_PTX_FILE, "wb"); + fwrite(bin, sizeof(char), binSize, file); + fclose(file); + free(bin); + + // Prepare OpenCL memory objects + cl_mem bufA = clCreateBuffer(context, CL_MEM_READ_ONLY, M*K*sizeof(*A), NULL, &err); + cl_mem bufB = clCreateBuffer(context, CL_MEM_READ_ONLY, K*N*sizeof(*B), NULL, &err); + cl_mem bufB_TR = clCreateBuffer(context, 
CL_MEM_READ_ONLY, N*K*sizeof(*B), NULL, &err); + cl_mem bufC = clCreateBuffer(context, CL_MEM_READ_WRITE, M*N*sizeof(*C), NULL, &err); + checkError(err,__LINE__); + + // Copy matrices to the GPU (also C to erase the results of the previous run) + err = clEnqueueWriteBuffer(queue, bufA, CL_TRUE, 0, M*K*sizeof(*A), A, 0, NULL, NULL); + err = clEnqueueWriteBuffer(queue, bufB, CL_TRUE, 0, K*N*sizeof(*B), B, 0, NULL, NULL); + err = clEnqueueWriteBuffer(queue, bufC, CL_TRUE, 0, M*N*sizeof(*C), C, 0, NULL, NULL); + checkError(err,__LINE__); + + // Create extra objects for rounded-up sizes (only needed in case of myGEMM10) + cl_mem bufA_XL = clCreateBuffer(context, CL_MEM_READ_ONLY, M_XL*K_XL*sizeof(*A), NULL, &err); + cl_mem bufB_TR_XL = clCreateBuffer(context, CL_MEM_READ_ONLY, N_XL*K_XL*sizeof(*B), NULL, &err); + cl_mem bufC_XL = clCreateBuffer(context, CL_MEM_READ_WRITE, M_XL*N_XL*sizeof(*C), NULL, &err); + checkError(err,__LINE__); + + // Configure the myGEMM kernel + char kernelname[100]; + sprintf(kernelname, "myGEMM%d", KERNEL); + cl_kernel kernel1 = clCreateKernel(program, kernelname, &err); + checkError(err,__LINE__); + + // Set the arguments of the myGEMM kernel + #if KERNEL == 10 + err = clSetKernelArg(kernel1, 0, sizeof(int), (void*)&M_XL); + err = clSetKernelArg(kernel1, 1, sizeof(int), (void*)&N_XL); + err = clSetKernelArg(kernel1, 2, sizeof(int), (void*)&K_XL); + err = clSetKernelArg(kernel1, 3, sizeof(cl_mem), (void*)&bufA_XL); + err = clSetKernelArg(kernel1, 4, sizeof(cl_mem), (void*)&bufB_TR_XL); + err = clSetKernelArg(kernel1, 5, sizeof(cl_mem), (void*)&bufC_XL); + #else + err = clSetKernelArg(kernel1, 0, sizeof(int), (void*)&M); + err = clSetKernelArg(kernel1, 1, sizeof(int), (void*)&N); + err = clSetKernelArg(kernel1, 2, sizeof(int), (void*)&K); + err = clSetKernelArg(kernel1, 3, sizeof(cl_mem), (void*)&bufA); + #if KERNEL == 5 || KERNEL == 6 || KERNEL == 7 || KERNEL == 8 || KERNEL == 9 + err = clSetKernelArg(kernel1, 4, sizeof(cl_mem), (void*)&bufB_TR); + #else + err = clSetKernelArg(kernel1, 4, sizeof(cl_mem), (void*)&bufB); + #endif + err = clSetKernelArg(kernel1, 5, sizeof(cl_mem), (void*)&bufC); + #endif + checkError(err,__LINE__); + + // Configure the supporting transpose kernel and set its arguments (only for certain myGEMMs) + #if KERNEL == 5 || KERNEL == 6 || KERNEL == 7 || KERNEL == 8 || KERNEL == 9 || KERNEL == 10 + cl_kernel kernel2 = clCreateKernel(program, "transpose", &err); + checkError(err,__LINE__); + err = clSetKernelArg(kernel2, 0, sizeof(int), (void*)&K); + err = clSetKernelArg(kernel2, 1, sizeof(int), (void*)&N); + err = clSetKernelArg(kernel2, 2, sizeof(cl_mem), (void*)&bufB); + err = clSetKernelArg(kernel2, 3, sizeof(cl_mem), (void*)&bufB_TR); + checkError(err,__LINE__); + const size_t tLocal[2] = { TRANSPOSEX, TRANSPOSEY }; + const size_t tGlobal[2] = { (size_t)K, (size_t)N }; + #endif + + // Configure the supporting padding kernels and set their arguments (only for myGEMM10) + #if KERNEL == 10 + cl_kernel kernel3a = clCreateKernel(program, "paddingAddZeroes", &err); + checkError(err,__LINE__); + err = clSetKernelArg(kernel3a, 0, sizeof(int), (void*)&M); + err = clSetKernelArg(kernel3a, 1, sizeof(int), (void*)&K); + err = clSetKernelArg(kernel3a, 2, sizeof(cl_mem), (void*)&bufA); + err = clSetKernelArg(kernel3a, 3, sizeof(int), (void*)&M_XL); + err = clSetKernelArg(kernel3a, 4, sizeof(int), (void*)&K_XL); + err = clSetKernelArg(kernel3a, 5, sizeof(cl_mem), (void*)&bufA_XL); + checkError(err,__LINE__); + cl_kernel kernel3b = clCreateKernel(program, 
"paddingAddZeroes", &err); + checkError(err,__LINE__); + err = clSetKernelArg(kernel3b, 0, sizeof(int), (void*)&N); + err = clSetKernelArg(kernel3b, 1, sizeof(int), (void*)&K); + err = clSetKernelArg(kernel3b, 2, sizeof(cl_mem), (void*)&bufB_TR); + err = clSetKernelArg(kernel3b, 3, sizeof(int), (void*)&N_XL); + err = clSetKernelArg(kernel3b, 4, sizeof(int), (void*)&K_XL); + err = clSetKernelArg(kernel3b, 5, sizeof(cl_mem), (void*)&bufB_TR_XL); + checkError(err,__LINE__); + cl_kernel kernel3c = clCreateKernel(program, "paddingRemoveZeroes", &err); + checkError(err,__LINE__); + err = clSetKernelArg(kernel3c, 0, sizeof(int), (void*)&M_XL); + err = clSetKernelArg(kernel3c, 1, sizeof(int), (void*)&N_XL); + err = clSetKernelArg(kernel3c, 2, sizeof(cl_mem), (void*)&bufC_XL); + err = clSetKernelArg(kernel3c, 3, sizeof(int), (void*)&M); + err = clSetKernelArg(kernel3c, 4, sizeof(int), (void*)&N); + err = clSetKernelArg(kernel3c, 5, sizeof(cl_mem), (void*)&bufC); + checkError(err,__LINE__); + const size_t pLocal[2] = { PADDINGX, PADDINGY }; + const size_t pAGlobal[2] = { (size_t)M_XL, (size_t)K_XL }; + const size_t pBGlobal[2] = { (size_t)N_XL, (size_t)K_XL }; + const size_t pCGlobal[2] = { (size_t)M, (size_t)N }; + #endif + + // Configure the thread/work-group dimensions of the myGEMM kernel + #if KERNEL == 1 || KERNEL == 2 + const size_t local[2] = { TS, TS }; + const size_t global[2] = { (size_t)M, (size_t)N }; + #elif KERNEL == 3 || KERNEL == 5 + const size_t local[2] = { TS, TS/WPT }; + const size_t global[2] = { (size_t)M, (size_t)(N/WPT) }; + #elif KERNEL == 4 + const size_t local[2] = { TS/WIDTH, TS }; + const size_t global[2] = { (size_t)(M/WIDTH), (size_t)N }; + #elif KERNEL == 6 || KERNEL == 7 || KERNEL == 8 || KERNEL == 9 + const size_t local[2] = { TSM/WPTM, TSN/WPTN }; + const size_t global[2] = { (size_t)(M/WPTM), (size_t)(N/WPTN) }; + #elif KERNEL == 10 + const size_t local[2] = { TSM/WPTM, TSN/WPTN }; + const size_t global[2] = { (size_t)(M_XL/WPTM), (size_t)(N_XL/WPTN) }; + #elif KERNEL == 11 + const size_t local[2] = { THREADSX, THREADSY }; + const size_t global[2] = { (size_t)(M/RX), (size_t)(N/RY) }; + #endif + + // Start the timed loop + // double startTime = opencl_timer(); + for (int r=0; r +// Author......... Cedric Nugteren +// Changed at..... 2014-11-06 +// License........ MIT license +// Tab-size....... 4 spaces +// Line length.... 100 characters +// +// ================================================================================================= +// +// Matrices in column-major format +// A: K columns, M rows +// B: N columns, K rows +// C: N columns, M rows +// +// N +// o-----o +// | | +// K | [B] | +// | | +// o-----o +// K N +// o-------o o-----o +// M | [A] | M | [C] | +// | | | | +// o-------o o-----o +// +// +// C-code for column-major matrix multiplication with alpha=1 and beta=0: +// +// for (int m=0; m +// Author......... Cedric Nugteren +// Changed at..... 2014-11-07 +// License........ MIT license +// Tab-size....... 4 spaces +// Line length.... 100 characters +// +// ================================================================================================= + +// Select a kernel +#define KERNEL 8 + +// Constants for kernels 1 -- 5 +#define TS 32 // The square-root of the 2D tile-size (== work-group dims) + +// Constants for kernels 3, 5 +#define WPT 8 // The amount of work-per-thread, i.e. 
the thread-coarsening factor +#define RTS (TS/WPT) // The reduced tile-size in one dimension + +// Constants for kernels 4, 7 -- 10 +#define WIDTH 4 // The vector-width (in number of floats) + +// Constants for kernel 5 +#define TSDK 16 // The tile-size in dimension K (for kernel 5 only) +#define LPT ((TSDK*WPT)/(TS)) // The amount of loads-per-thread (assume TSN==TSM) + +// Constants for kernels 6 -- 10 +#define TSM 128 // The tile-size in dimension M +#define TSN 128 // The tile-size in dimension N +#define TSK 16 // The tile-size in dimension K +#define WPTM 8 // The amount of work-per-thread in dimension M +#define WPTN 8 // The amount of work-per-thread in dimension N +#define RTSM (TSM/WPTM) // The reduced tile-size in dimension M (== number of threads) +#define RTSN (TSN/WPTN) // The reduced tile-size in dimension N (== number of threads) +#define LPTA ((TSK*WPTM*WPTN)/(TSN)) // The amount of loads-per-thread for A +#define LPTB ((TSK*WPTM*WPTN)/(TSM)) // The amount of loads-per-thread for B + +// Constraints on settings for kernels 6 -- 10 +// Note: TSM/WPTM has to be integer +// Note: TSN/WPTN has to be integer +// Note: TSM/WIDTH has to be integer +// Note: TSN/WIDTH has to be integer +// Note: (TSK*WPTM*WPTN)/(TSN*WIDTH) has to be integer +// Note: (TSK*WPTM*WPTN)/(TSM*WIDTH) has to be integer + +// Constants for kernel 11 (mimicing clBlas) +#define THREADSX 8 +#define THREADSY 8 +#define RX 8 +#define RY 4 +#define RK (RY) + +// Constants for the supporting transpose kernel +#define TRANSPOSEX 16 +#define TRANSPOSEY 16 + +// Constants for the supporting padding kernels +#define PADDINGX 16 +#define PADDINGY 16 + +// Macros for host and kernel code +#define MIN(a,b) ((a) > (b)) ? (b) : (a) +#define MAX(a,b) ((a) > (b)) ? (a) : (b) +#define CEIL_DIV(x,y) (((x) + (y) - 1) / (y)) +#define MOD2(x,y) ((x) % (y)) +#define DIV2(x,y) ((x) / (y)) + +// ================================================================================================= diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 323739c8..2a69f1b8 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -148,6 +148,37 @@ def setup_cub(): setup_cuda_lib("cub", link=False, extra_flags=extra_flags) def setup_cuda_extern(): + if os.environ.get("use_opencl","1") == "1": + LOG.vv("setup opencl extern.") + cache_path_opencl = os.path.join(cache_path,"opencl") + opencl_include = os.path.join(jittor_path,"extern","opencl","inc") + opencl_lib = os.path.join(jittor_path,"extern","opencl","lib") + make_cache_dir(cache_path_opencl) + opencl_extern_src = os.path.join(jittor_path,"extern","opencl","src") + from pathlib import Path + jit_path = os.path.join(Path.home(),".cache","jittor","master","g++","jit") + opencl_extern_files = [os.path.join(opencl_extern_src, name) + for name in os.listdir(opencl_extern_src) + ] + so_name = os.path.join(cache_path_opencl,"opencl_extern.so") + os.system(f"cd {jit_path} && cp {opencl_lib}/* .") + # compile(cc_path, cc_flags+f"-I'{opencl_include}' -lOpenCL",opencl_extern_files,so_name) + # output_lib = os.path.join(Path.home(),".cache","jittor","master","g++","opencl") + opencl_extern_op = os.path.join(jittor_path, "extern", "opencl", "ops") + opencl_op_files = [os.path.join(opencl_extern_op, name) for name in os.listdir(opencl_extern_op)] + print("compile ops ",opencl_op_files) + # print(output_lib) + opencllib = compile_custom_ops(opencl_op_files, return_module=True,extra_flags=f"-I'{opencl_include}' -I/usr/local/cuda/include 
-lOpenCL") + # opencl_ops = opencllib.ops + + ''' + culib = compile_custom_ops(culib_src_files, return_module=True, + extra_flags=f" -I'{jt_cuda_include}' -I'{jt_culib_include}' {link_flags} {extra_flags} ") + culib_ops = culib.ops + globals()[lib_name+"_ops"] = culib_ops + globals()[lib_name] = culib + LOG.vv(f"Get {lib_name}_ops: "+str(dir(culib_ops))) + ''' if not has_cuda: return LOG.vv("setup cuda extern...") cache_path_cuda = os.path.join(cache_path, "cuda") From e390d1dbe5deb13f5d7cbc098d9532aeaca7bb7a Mon Sep 17 00:00:00 2001 From: Exusial Date: Fri, 7 May 2021 09:55:33 +0800 Subject: [PATCH 22/36] update extern,dist. --- doc/source/conf.py | 3 +- python/jittor/__init__.pyi | 164 +++++++++++++++++++++++ python/jittor/compile_extern.py | 70 +++++----- python/jittor/distributions.py | 72 +++++++++- python/jittor/test/test_distributions.py | 29 ++++ test_c.py | 38 ++++++ 6 files changed, 339 insertions(+), 37 deletions(-) create mode 100644 python/jittor/__init__.pyi create mode 100644 test_c.py diff --git a/doc/source/conf.py b/doc/source/conf.py index 541e617a..17e10f7b 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -59,7 +59,8 @@ # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = [] - +locale_dir = ['locale/'] +gettext_compact = False # -- Options for HTML output ------------------------------------------------- diff --git a/python/jittor/__init__.pyi b/python/jittor/__init__.pyi new file mode 100644 index 00000000..a65f6969 --- /dev/null +++ b/python/jittor/__init__.pyi @@ -0,0 +1,164 @@ +from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload +import builtins +import math +import pickle +def unary(x,op):... +def cast(x,op):... +def bool(x):... +def int8(x):... +def int16(x):... +def int32(x):... +def int64(x):... +def uint8(x):... +def uint16(x):... +def uint32(x):... +def uint64(x):... +def float32(x):... +def float64(x):... +def abs(x):... +def negative(x):... +def logical_not(x):... +def bitwise_not(x):... +def log(x):... +def exp(x):... +def sqrt(x):... +def round(x):... +def floor(x):... +def ceil(x):... +def sin(x):... +def asin(x):... +def arcsin(x):... +def sinh(x):... +def asinh(x):... +def arcsinh(x):... +def tan(x):... +def atan(x):... +def arctan(x):... +def tanh(x):... +def atanh(x):... +def arctanh(x):... +def cos(x):... +def acos(x):... +def arccos(x):... +def cosh(x):... +def acosh(x):... +def arccosh(x):... +def sigmoid(x):... +def erf(x):... +def broadcast(x,shape,dims):... +def broadcast(x,y,dims):... +def broadcast_var(x,y,dims):... +def tape(x):... +def fetch(inputs,func):... +def transpose(x,axes):... +def code(shape,dtype,inputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... +def code(shapes,dtypes,inputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... +def code(inputs,outputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... +def setitem(x,slices,y,op):... +def candidate(x,fail_cond,dtype):... +def getitem(x,slices):... +def random(shape,dtype,type):... +def reindex_reduce(y,op,shape,indexes,overflow_conditions,extras):... +def clone(x):... +def reshape(x,shape):... +def reduce(x,op,dim,keepdims):... +def reduce(x,op,dims,keepdims):... +def max(x,dim,keepdims):... +def max(x,dims,keepdims):... +def max(x,dims_mask,keepdims_mask):... +def reduce_maximum(x,dim,keepdims):... +def reduce_maximum(x,dims,keepdims):... 
+def reduce_maximum(x,dims_mask,keepdims_mask):... +def min(x,dim,keepdims):... +def min(x,dims,keepdims):... +def min(x,dims_mask,keepdims_mask):... +def reduce_minimum(x,dim,keepdims):... +def reduce_minimum(x,dims,keepdims):... +def reduce_minimum(x,dims_mask,keepdims_mask):... +def sum(x,dim,keepdims):... +def sum(x,dims,keepdims):... +def sum(x,dims_mask,keepdims_mask):... +def reduce_add(x,dim,keepdims):... +def reduce_add(x,dims,keepdims):... +def reduce_add(x,dims_mask,keepdims_mask):... +def prod(x,dim,keepdims):... +def prod(x,dims,keepdims):... +def prod(x,dims_mask,keepdims_mask):... +def product(x,dim,keepdims):... +def product(x,dims,keepdims):... +def product(x,dims_mask,keepdims_mask):... +def reduce_multiply(x,dim,keepdims):... +def reduce_multiply(x,dims,keepdims):... +def reduce_multiply(x,dims_mask,keepdims_mask):... +def reduce_logical_and(x,dim,keepdims):... +def reduce_logical_and(x,dims,keepdims):... +def reduce_logical_and(x,dims_mask,keepdims_mask):... +def all_(x,dim,keepdims):... +def all_(x,dims,keepdims):... +def all_(x,dims_mask,keepdims_mask):... +def reduce_logical_or(x,dim,keepdims):... +def reduce_logical_or(x,dims,keepdims):... +def reduce_logical_or(x,dims_mask,keepdims_mask):... +def any_(x,dim,keepdims):... +def any_(x,dims,keepdims):... +def any_(x,dims_mask,keepdims_mask):... +def reduce_logical_xor(x,dim,keepdims):... +def reduce_logical_xor(x,dims,keepdims):... +def reduce_logical_xor(x,dims_mask,keepdims_mask):... +def reduce_bitwise_and(x,dim,keepdims):... +def reduce_bitwise_and(x,dims,keepdims):... +def reduce_bitwise_and(x,dims_mask,keepdims_mask):... +def reduce_bitwise_or(x,dim,keepdims):... +def reduce_bitwise_or(x,dims,keepdims):... +def reduce_bitwise_or(x,dims_mask,keepdims_mask):... +def reduce_bitwise_xor(x,dim,keepdims):... +def reduce_bitwise_xor(x,dims,keepdims):... +def reduce_bitwise_xor(x,dims_mask,keepdims_mask):... +def mean(x,dim,keepdims):... +def mean(x,dims,keepdims):... +def mean(x,dims_mask,keepdims_mask):... +def copy(x):... +def arg_reduce(x,op,dim,keepdims):... +def argsort(x,dim,descending,dtype):... +def ternary(cond,x,y):... +def binary(x,y,p):... +def pow(x,y):... +def maximum(x,y):... +def minimum(x,y):... +def add(x,y):... +def subtract(x,y):... +def multiply(x,y):... +def divide(x,y):... +def floor_divide(x,y):... +def mod(x,y):... +def less(x,y):... +def less_equal(x,y):... +def greater(x,y):... +def greater_equal(x,y):... +def equal(x,y):... +def not_equal(x,y):... +def left_shift(x,y):... +def right_shift(x,y):... +def logical_and(x,y):... +def logical_or(x,y):... +def logical_xor(x,y):... +def bitwise_and(x,y):... +def bitwise_or(x,y):... +def bitwise_xor(x,y):... +def numpy_code(shape,dtype,inputs,forward,backward):... +def numpy_code(shapes,dtypes,inputs,forward,backward):... +def numpy_code(shape,dtype,inputs,forward):... +def numpy_code(shapes,dtypes,inputs,forward):... +def where(cond,dtype):... +def index(shape,dim,dtype):... +def index(shape,dtype):... +def index(a,dim,dtype):... +def index(a,dtype):... +def index_var(a,dim,dtype):... +def index_var(a,dtype):... +def array_(args):... +def array(obj):... +def reindex(x,shape,indexes,overflow_value,overflow_conditions,extras):... +def reindex(x,indexes,overflow_value,overflow_conditions):... +def reindex_var(x,indexes,overflow_value,overflow_conditions):... +def empty(shape,dtype):... 
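The generated __init__.pyi above only records parameter names, which already lets editors and type checkers resolve the C++ op signatures. Two illustrative calls matched against the stub entries::

    import jittor as jt

    x = jt.random([2, 3])          # resolves against: def random(shape,dtype,type):...
    y = jt.transpose(x, (1, 0))    # resolves against: def transpose(x,axes):...
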
diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 2a69f1b8..2dda7e57 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -8,6 +8,7 @@ from .compiler import * from jittor_utils import run_cmd, get_version, get_int_version from jittor.utils.misc import download_url_to_local +from jittor import ops def search_file(dirs, name): for d in dirs: @@ -20,8 +21,7 @@ def search_file(dirs, name): def install_mkl(root_folder): # origin url is # url = "https://github.com/intel/mkl-dnn/releases/download/v1.0.2/mkldnn_lnx_1.0.2_cpu_gomp.tgz" - if os.environ.get("use_onednn","1")=="1": - print("get in") + if os.environ.get("use_onednn") == "1": url = "https://cloud.tsinghua.edu.cn/f/cd63e0df3c5c4c52b76d/?dl=1" filename = "oneDNN-2.2-rc.tar.gz" fullname = os.path.join(root_folder, filename) @@ -64,9 +64,8 @@ def setup_mkl(): if not use_mkl: return mkl_include_path = os.environ.get("mkl_include_path") mkl_lib_path = os.environ.get("mkl_lib_path") - if mkl_lib_path is None or mkl_include_path is None: - if os.environ.get("use_onednn","1")=="1": + if os.environ.get("use_onednn","0")=="1": LOG.v("setup onednn...") from pathlib import Path one_path = os.path.join(str(Path.home()),".cache", "jittor", "one") @@ -82,7 +81,6 @@ def setup_mkl(): # mkl_path decouple with cc_path from pathlib import Path mkl_path = os.path.join(str(Path.home()), ".cache", "jittor", "mkl") - make_cache_dir(mkl_path) install_mkl(mkl_path) mkl_home = "" @@ -103,9 +101,6 @@ def setup_mkl(): LOG.v(f"mkl_lib_name: {mkl_lib_name}") # We do not link manualy, link in custom ops # ctypes.CDLL(mkl_lib_name, dlopen_flags) - print(f"mkl_include_path: {mkl_include_path}") - print(f"mkl_lib_path: {mkl_lib_path}") - print(f"mkl_lib_name: {mkl_lib_name}") mkl_op_dir = os.path.join(jittor_path, "extern", "mkl", "ops") mkl_op_files = [os.path.join(mkl_op_dir, name) for name in os.listdir(mkl_op_dir)] mkl_ops = compile_custom_ops(mkl_op_files, @@ -147,30 +142,7 @@ def setup_cub(): cub_home += "/" setup_cuda_lib("cub", link=False, extra_flags=extra_flags) -def setup_cuda_extern(): - if os.environ.get("use_opencl","1") == "1": - LOG.vv("setup opencl extern.") - cache_path_opencl = os.path.join(cache_path,"opencl") - opencl_include = os.path.join(jittor_path,"extern","opencl","inc") - opencl_lib = os.path.join(jittor_path,"extern","opencl","lib") - make_cache_dir(cache_path_opencl) - opencl_extern_src = os.path.join(jittor_path,"extern","opencl","src") - from pathlib import Path - jit_path = os.path.join(Path.home(),".cache","jittor","master","g++","jit") - opencl_extern_files = [os.path.join(opencl_extern_src, name) - for name in os.listdir(opencl_extern_src) - ] - so_name = os.path.join(cache_path_opencl,"opencl_extern.so") - os.system(f"cd {jit_path} && cp {opencl_lib}/* .") - # compile(cc_path, cc_flags+f"-I'{opencl_include}' -lOpenCL",opencl_extern_files,so_name) - # output_lib = os.path.join(Path.home(),".cache","jittor","master","g++","opencl") - opencl_extern_op = os.path.join(jittor_path, "extern", "opencl", "ops") - opencl_op_files = [os.path.join(opencl_extern_op, name) for name in os.listdir(opencl_extern_op)] - print("compile ops ",opencl_op_files) - # print(output_lib) - opencllib = compile_custom_ops(opencl_op_files, return_module=True,extra_flags=f"-I'{opencl_include}' -I/usr/local/cuda/include -lOpenCL") - # opencl_ops = opencllib.ops - +def setup_cuda_extern(): ''' culib = compile_custom_ops(culib_src_files, return_module=True, extra_flags=f" -I'{jt_cuda_include}' 
-I'{jt_culib_include}' {link_flags} {extra_flags} ") @@ -494,6 +466,39 @@ def inner(self, *args, **kw): if k == "mpi_test": continue setattr(core.Var, k, warper(mpi_ops.__dict__[k])) +def get_pyi(): + f = open(os.path.join(jittor_path,"__init__.pyi"),"w") + # fundamental declaration + f.write("from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload\n") + f.write("import builtins\nimport math\nimport pickle\n") + # for c++ ops + for func_name,func in ops.__dict__.items(): + if func_name == "__doc__" or func_name == "__name__" or func_name == "__loader__" or func_name == "__spec__" or func_name == "__package__": + continue + # print(func_name) + text = func.__doc__ + declarations = re.findall(r"Declaration:\n(.+)\n",text) + # print(declarations) + for decl in declarations: + f.write(f"def {func_name}(") + params = re.findall(r".+ [a-zA-Z_0-9]+\((.+)", decl) + # print(params) + for param in params: + para = param.split(",") + for i,p in enumerate(para): + pa = p.strip().split(" ")[1] + pf = pa.split("=")[0] + # print(pa) + f.write(pf) + if i != len(para) - 1: + f.write(",") + else: + if len(pa.split("=")) > 1: + f.write("):...\n") + else: + f.write(":...\n") + f.close() + setup_mpi() in_mpi = inside_mpi() rank = mpi.world_rank() if in_mpi else 0 @@ -503,3 +508,4 @@ def inner(self, *args, **kw): setup_mkl() setup_cuda_extern() +get_pyi() diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py index abd5269d..a919c99a 100644 --- a/python/jittor/distributions.py +++ b/python/jittor/distributions.py @@ -6,6 +6,8 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** +import math +import numpy as np import jittor as jt def simple_presum(x): @@ -30,8 +32,11 @@ def __init__(self, probs=None, logits=None): if probs is None: # cannot align to pytorch probs = jt.sigmoid(logits) + elif logits is None: + logits = jt.log(probs) with jt.no_grad(): self.probs = probs / probs.sum(-1, True) + self.logits = logits self.cum_probs = simple_presum(probs) self.cum_probs_l = self.cum_probs[..., :-1] self.cum_probs_r = self.cum_probs[..., 1:] @@ -41,12 +46,31 @@ def sample(self, sample_shape=[]): rand = jt.rand(shape) one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float() return one_hot - - - + + def log_prob(self,x): + return jt.log(self.probs)[0,x] + + def entropy(self): + min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127) + logits = jt.clamp(self.logits,min_v=min_real) + p_log_p = logits * self.probs + return -p_log_p.sum(-1) + + class Categorical: def __init__(self, probs=None, logits=None): - OneHotCategorical.__init__(self, probs, logits) + assert not (probs is None and logits is None) + if probs is None: + # cannot align to pytorch + probs = jt.sigmoid(logits) + elif logits is None: + logits = jt.log(probs) + with jt.no_grad(): + self.probs = probs / probs.sum(-1, True) + self.logits = logits + self.cum_probs = simple_presum(probs) + self.cum_probs_l = self.cum_probs[..., :-1] + self.cum_probs_r = self.cum_probs[..., 1:] def sample(self, sample_shape=[]): shape = sample_shape + self.probs.shape[:-1] + (1,) @@ -54,3 +78,43 @@ def sample(self, sample_shape=[]): one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r) index = one_hot.index(one_hot.ndim-1) return (one_hot * index).sum(-1) + + def log_prob(self, x): + return jt.log(self.probs)[0,x] + + def 
entropy(self): + min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127) + logits = jt.clamp(self.logits,min_v=min_real) + p_log_p = logits * self.probs + return -p_log_p.sum(-1) + + +class Normal: + def __init__(self,mu,sigma): + self.mu = mu + self.sigma = sigma + + def sample(self,sample_shape): + return jt.normal(mu,sigma,sample_shape) + + def log_prob(self,x): + var = self.sigma**2 + log_scale = jt.log(self.sigma) + return -((x-self.mu)**2) / (2*var) - log_scale-np.log(np.sqrt(2*np.pi)) + + def entropy(self): + return 0.5+0.5*np.log(2*np.pi)+jt.log(self.sigma) + + +def kl_divergence(cur_dist,old_dist): + assert isinstance(cur_dist,type(old_dist)) + if isinstance(cur_dist,Normal): + vr = (cur_dist.sigma / old_dist.sigma)**2 + t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2 + return 0.5*(vr+t1-1-jt.log(vr)) + if isinstance(cur_dist,Categorical) or isinstance(cur_dist,OneHotCategorical):# ? + t = cur_dist.probs * (cur_dist.logits-old_dist.logits) + t[jt.array((old_dist.probs == 0))] = math.inf + t[jt.array((cur_dist.probs == 0))] = 0 + return t.sum(-1) + \ No newline at end of file diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py index d33701fe..07fb17cc 100644 --- a/python/jittor/test/test_distributions.py +++ b/python/jittor/test/test_distributions.py @@ -10,6 +10,7 @@ # *************************************************************** import unittest import jittor as jt +import torch import numpy as np import jittor.distributions as jd @@ -40,7 +41,35 @@ def test_cate(self): y = a.sample([2,3]) y.sync() assert y.shape == [2,3] + + def test_normal(self): + for _ in range(10): + mu = np.random.uniform(-1,1) + sigma = np.random.uniform(0,2) + jn = jd.Normal(mu,sigma) + tn = torch.distributions.Normal(mu,sigma) + assert np.allclose(jn.entropy().data,tn.entropy().numpy()) + x = np.random.uniform(-1,1) + # print(jn.log_prob(x)) + # print(tn.log_prob(torch.tensor(x))) + assert np.allclose(jn.log_prob(x),tn.log_prob(torch.tensor(x))) + mu2 = np.random.uniform(-1,1) + sigma2 = np.random.uniform(0,2) + jn2 = jd.Normal(mu2,sigma2) + tn2 = torch.distributions.Normal(mu2,sigma2) + assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy()) + def test_categorical(self): + for _ in range(10): + probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10)) + probs,probs2 = probs / probs.sum(),probs2 / probs2.sum() + jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1)) + tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2)) + assert np.allclose(jc.entropy().data,tc.entropy().numpy()) + x = np.random.randint(0,10) + # print(jc.log_prob(x),tc.log_prob(x)) + assert np.allclose(jc.log_prob(x),tc.log_prob(torch.tensor(x))) + assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) if __name__ == "__main__": diff --git a/test_c.py b/test_c.py new file mode 100644 index 00000000..2c24c1f1 --- /dev/null +++ b/test_c.py @@ -0,0 +1,38 @@ +import os +import re + +def get_pyi(): + f = open(os.path.join(".","python","jittor","__init__.pyi"),"w") + # fundamental declaration + f.write("from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload\n") + f.write("import builtins\nimport math\nimport pickle\n") + # for c++ ops + for func_name,func in ops.__dict__.items(): + if func_name == "__doc__" or func_name == "__name__" or func_name == 
"__loader__" or func_name == "__spec__" or func_name == "__package__": + continue + # print(func_name) + text = func.__doc__ + declarations = re.findall(r"Declaration:\n(.+)\n",text) + # print(declarations) + for decl in declarations: + f.write(f"def {func_name}(") + params = re.findall(r".+ [a-zA-Z_0-9]+\((.+)", decl) + # print(params) + for param in params: + para = param.split(",") + for i,p in enumerate(para): + pa = p.strip().split(" ")[1] + pf = pa.split("=")[0] + # print(pa) + f.write(pf) + if i != len(para) - 1: + f.write(",") + else: + if len(pa.split("=")) > 1: + f.write("):...\n") + else: + f.write(":...\n") + f.close() + +if __name__ == "__main__": + get_pyi() \ No newline at end of file From 2a483cfa7821c0b6cdd60788003208e10fc847f2 Mon Sep 17 00:00:00 2001 From: Exusial Date: Fri, 7 May 2021 09:56:00 +0800 Subject: [PATCH 23/36] update extern,dist. --- python/jittor/__init__.pyi | 164 ------------------------------------- test_c.py | 38 --------- 2 files changed, 202 deletions(-) delete mode 100644 python/jittor/__init__.pyi delete mode 100644 test_c.py diff --git a/python/jittor/__init__.pyi b/python/jittor/__init__.pyi deleted file mode 100644 index a65f6969..00000000 --- a/python/jittor/__init__.pyi +++ /dev/null @@ -1,164 +0,0 @@ -from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload -import builtins -import math -import pickle -def unary(x,op):... -def cast(x,op):... -def bool(x):... -def int8(x):... -def int16(x):... -def int32(x):... -def int64(x):... -def uint8(x):... -def uint16(x):... -def uint32(x):... -def uint64(x):... -def float32(x):... -def float64(x):... -def abs(x):... -def negative(x):... -def logical_not(x):... -def bitwise_not(x):... -def log(x):... -def exp(x):... -def sqrt(x):... -def round(x):... -def floor(x):... -def ceil(x):... -def sin(x):... -def asin(x):... -def arcsin(x):... -def sinh(x):... -def asinh(x):... -def arcsinh(x):... -def tan(x):... -def atan(x):... -def arctan(x):... -def tanh(x):... -def atanh(x):... -def arctanh(x):... -def cos(x):... -def acos(x):... -def arccos(x):... -def cosh(x):... -def acosh(x):... -def arccosh(x):... -def sigmoid(x):... -def erf(x):... -def broadcast(x,shape,dims):... -def broadcast(x,y,dims):... -def broadcast_var(x,y,dims):... -def tape(x):... -def fetch(inputs,func):... -def transpose(x,axes):... -def code(shape,dtype,inputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... -def code(shapes,dtypes,inputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... -def code(inputs,outputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... -def setitem(x,slices,y,op):... -def candidate(x,fail_cond,dtype):... -def getitem(x,slices):... -def random(shape,dtype,type):... -def reindex_reduce(y,op,shape,indexes,overflow_conditions,extras):... -def clone(x):... -def reshape(x,shape):... -def reduce(x,op,dim,keepdims):... -def reduce(x,op,dims,keepdims):... -def max(x,dim,keepdims):... -def max(x,dims,keepdims):... -def max(x,dims_mask,keepdims_mask):... -def reduce_maximum(x,dim,keepdims):... -def reduce_maximum(x,dims,keepdims):... -def reduce_maximum(x,dims_mask,keepdims_mask):... -def min(x,dim,keepdims):... -def min(x,dims,keepdims):... -def min(x,dims_mask,keepdims_mask):... -def reduce_minimum(x,dim,keepdims):... -def reduce_minimum(x,dims,keepdims):... -def reduce_minimum(x,dims_mask,keepdims_mask):... -def sum(x,dim,keepdims):... -def sum(x,dims,keepdims):... -def sum(x,dims_mask,keepdims_mask):... 
-def reduce_add(x,dim,keepdims):... -def reduce_add(x,dims,keepdims):... -def reduce_add(x,dims_mask,keepdims_mask):... -def prod(x,dim,keepdims):... -def prod(x,dims,keepdims):... -def prod(x,dims_mask,keepdims_mask):... -def product(x,dim,keepdims):... -def product(x,dims,keepdims):... -def product(x,dims_mask,keepdims_mask):... -def reduce_multiply(x,dim,keepdims):... -def reduce_multiply(x,dims,keepdims):... -def reduce_multiply(x,dims_mask,keepdims_mask):... -def reduce_logical_and(x,dim,keepdims):... -def reduce_logical_and(x,dims,keepdims):... -def reduce_logical_and(x,dims_mask,keepdims_mask):... -def all_(x,dim,keepdims):... -def all_(x,dims,keepdims):... -def all_(x,dims_mask,keepdims_mask):... -def reduce_logical_or(x,dim,keepdims):... -def reduce_logical_or(x,dims,keepdims):... -def reduce_logical_or(x,dims_mask,keepdims_mask):... -def any_(x,dim,keepdims):... -def any_(x,dims,keepdims):... -def any_(x,dims_mask,keepdims_mask):... -def reduce_logical_xor(x,dim,keepdims):... -def reduce_logical_xor(x,dims,keepdims):... -def reduce_logical_xor(x,dims_mask,keepdims_mask):... -def reduce_bitwise_and(x,dim,keepdims):... -def reduce_bitwise_and(x,dims,keepdims):... -def reduce_bitwise_and(x,dims_mask,keepdims_mask):... -def reduce_bitwise_or(x,dim,keepdims):... -def reduce_bitwise_or(x,dims,keepdims):... -def reduce_bitwise_or(x,dims_mask,keepdims_mask):... -def reduce_bitwise_xor(x,dim,keepdims):... -def reduce_bitwise_xor(x,dims,keepdims):... -def reduce_bitwise_xor(x,dims_mask,keepdims_mask):... -def mean(x,dim,keepdims):... -def mean(x,dims,keepdims):... -def mean(x,dims_mask,keepdims_mask):... -def copy(x):... -def arg_reduce(x,op,dim,keepdims):... -def argsort(x,dim,descending,dtype):... -def ternary(cond,x,y):... -def binary(x,y,p):... -def pow(x,y):... -def maximum(x,y):... -def minimum(x,y):... -def add(x,y):... -def subtract(x,y):... -def multiply(x,y):... -def divide(x,y):... -def floor_divide(x,y):... -def mod(x,y):... -def less(x,y):... -def less_equal(x,y):... -def greater(x,y):... -def greater_equal(x,y):... -def equal(x,y):... -def not_equal(x,y):... -def left_shift(x,y):... -def right_shift(x,y):... -def logical_and(x,y):... -def logical_or(x,y):... -def logical_xor(x,y):... -def bitwise_and(x,y):... -def bitwise_or(x,y):... -def bitwise_xor(x,y):... -def numpy_code(shape,dtype,inputs,forward,backward):... -def numpy_code(shapes,dtypes,inputs,forward,backward):... -def numpy_code(shape,dtype,inputs,forward):... -def numpy_code(shapes,dtypes,inputs,forward):... -def where(cond,dtype):... -def index(shape,dim,dtype):... -def index(shape,dtype):... -def index(a,dim,dtype):... -def index(a,dtype):... -def index_var(a,dim,dtype):... -def index_var(a,dtype):... -def array_(args):... -def array(obj):... -def reindex(x,shape,indexes,overflow_value,overflow_conditions,extras):... -def reindex(x,indexes,overflow_value,overflow_conditions):... -def reindex_var(x,indexes,overflow_value,overflow_conditions):... -def empty(shape,dtype):... 
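Editor's note: the stub lines removed here were generated from the "Declaration:" blocks of the C++ op docstrings (see get_pyi() earlier in this series, and the test_c.py copy removed just below). A simplified, standalone sketch of that parsing step; the helper name and the docstring are made up for illustration:

import re

def stub_lines(func_name, doc):
    # Rebuild "def name(params):..." stub lines from a "Declaration:" block.
    out = []
    for decl in re.findall(r"Declaration:\n(.+)\n", doc):
        m = re.search(r"[a-zA-Z_0-9]+\((.*)\)", decl)
        if not m:
            continue
        # drop the C++ type, keep the parameter name, drop any default value
        names = [p.strip().split(" ")[-1].split("=")[0]
                 for p in m.group(1).split(",") if p.strip()]
        out.append(f"def {func_name}({','.join(names)}):...")
    return out

doc = "Document: add two vars.\n\nDeclaration:\nVarHolder* add(VarHolder* x, VarHolder* y)\n"
print(stub_lines("add", doc))   # ['def add(x,y):...']
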
diff --git a/test_c.py b/test_c.py deleted file mode 100644 index 2c24c1f1..00000000 --- a/test_c.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -import re - -def get_pyi(): - f = open(os.path.join(".","python","jittor","__init__.pyi"),"w") - # fundamental declaration - f.write("from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload\n") - f.write("import builtins\nimport math\nimport pickle\n") - # for c++ ops - for func_name,func in ops.__dict__.items(): - if func_name == "__doc__" or func_name == "__name__" or func_name == "__loader__" or func_name == "__spec__" or func_name == "__package__": - continue - # print(func_name) - text = func.__doc__ - declarations = re.findall(r"Declaration:\n(.+)\n",text) - # print(declarations) - for decl in declarations: - f.write(f"def {func_name}(") - params = re.findall(r".+ [a-zA-Z_0-9]+\((.+)", decl) - # print(params) - for param in params: - para = param.split(",") - for i,p in enumerate(para): - pa = p.strip().split(" ")[1] - pf = pa.split("=")[0] - # print(pa) - f.write(pf) - if i != len(para) - 1: - f.write(",") - else: - if len(pa.split("=")) > 1: - f.write("):...\n") - else: - f.write(":...\n") - f.close() - -if __name__ == "__main__": - get_pyi() \ No newline at end of file From d4f83b075a6a6599336a33425e89b26f64ea7d08 Mon Sep 17 00:00:00 2001 From: Exusial Date: Fri, 7 May 2021 14:04:01 +0800 Subject: [PATCH 24/36] fix. --- python/jittor/__init__.py | 9 ++++----- python/jittor/distributions.py | 6 +++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index c2ebb980..e97b95d9 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -8,7 +8,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.2.67' +__version__ = '1.2.2.68' from . import lock with lock.lock_scope(): ori_int = int @@ -976,17 +976,16 @@ def load(self, path: str): This method also supports loading a state dict from a pytorch .pth file. .. note:: - 当载入的参数与模型定义不一致时, jittor 会输出错误信息, 但是不会抛出异常. + 当载入的参数与模型定义不一致时, jittor 会输出错误信? 但是不会抛出异常. 若载入参数出现模型定义中没有的参数名, 则会输出如下信息, 并忽略此参数: >>> [w 0205 21:49:39.962762 96 __init__.py:723] load parameter w failed ... - 若载入参数的 shape 与模型定义不一致, 则会输出如下信息, 并忽略此参数: + 若载入参数的 shape 与模型定义不一? 则会输出如下信息, 并忽略此参数: >>> [e 0205 21:49:39.962822 96 __init__.py:739] load parameter w failed: expect the shape of w to be [1000,100,], but got [3,100,100,] - 如载入过程中出现错误, jittor 会输出概要信息, 您需要仔细核对错误信息 - + 如载入过程中出现错误, jittor 会输出概要信? 您需要仔细核对错误信? 
>>> [w 0205 21:49:39.962906 96 __init__.py:741] load total 100 params, 3 failed ''' self.load_parameters(load(path)) diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py index a919c99a..054aaa7f 100644 --- a/python/jittor/distributions.py +++ b/python/jittor/distributions.py @@ -17,10 +17,10 @@ def simple_presum(x): void kernel(int n0, int i0, in0_type* x, in0_type* out, int nl) { out[i0*(nl+1)] = 0; for (int i=0; inum/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->num); +kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in0->shape.size()-1]); ''' return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x], cpu_src=src, cuda_src=src) @@ -37,7 +37,7 @@ def __init__(self, probs=None, logits=None): with jt.no_grad(): self.probs = probs / probs.sum(-1, True) self.logits = logits - self.cum_probs = simple_presum(probs) + self.cum_probs = simple_presum(self.probs) self.cum_probs_l = self.cum_probs[..., :-1] self.cum_probs_r = self.cum_probs[..., 1:] From 51a57ab240ac5ab46bd8bc5726fa50317b317179 Mon Sep 17 00:00:00 2001 From: Exusial Date: Fri, 7 May 2021 15:09:18 +0800 Subject: [PATCH 25/36] fix. --- python/jittor/__init__.py | 1 + python/jittor/distributions.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index e97b95d9..8bd83a14 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -8,6 +8,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** + __version__ = '1.2.2.68' from . import lock with lock.lock_scope(): diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py index 054aaa7f..2415fdb5 100644 --- a/python/jittor/distributions.py +++ b/python/jittor/distributions.py @@ -19,7 +19,6 @@ def simple_presum(x): for (int i=0; inum/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in0->shape.size()-1]); ''' return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x], @@ -36,10 +35,10 @@ def __init__(self, probs=None, logits=None): logits = jt.log(probs) with jt.no_grad(): self.probs = probs / probs.sum(-1, True) - self.logits = logits self.cum_probs = simple_presum(self.probs) self.cum_probs_l = self.cum_probs[..., :-1] self.cum_probs_r = self.cum_probs[..., 1:] + self.logits = logits def sample(self, sample_shape=[]): shape = sample_shape + self.probs.shape[:-1] + (1,) From db6a585a3595e97bcd41791a4b13b8af811f71fc Mon Sep 17 00:00:00 2001 From: Exusial Date: Fri, 7 May 2021 15:36:55 +0800 Subject: [PATCH 26/36] delete useless. 
--- opencl/inc/opencl_matmul_helper.h | 7 - opencl/lib/kernels.cl | 1327 ----------------------------- opencl/lib/settings.h | 74 -- opencl/ops/minimal.cpp | 151 ---- opencl/ops/opencl_matmul_op.cc | 73 -- opencl/ops/opencl_matmul_op.h | 17 - opencl/src/clGEMM.cc | 427 ---------- opencl/src/kernels.cl | 1326 ---------------------------- opencl/src/settings.h | 74 -- 9 files changed, 3476 deletions(-) delete mode 100644 opencl/inc/opencl_matmul_helper.h delete mode 100644 opencl/lib/kernels.cl delete mode 100644 opencl/lib/settings.h delete mode 100644 opencl/ops/minimal.cpp delete mode 100644 opencl/ops/opencl_matmul_op.cc delete mode 100644 opencl/ops/opencl_matmul_op.h delete mode 100644 opencl/src/clGEMM.cc delete mode 100644 opencl/src/kernels.cl delete mode 100644 opencl/src/settings.h diff --git a/opencl/inc/opencl_matmul_helper.h b/opencl/inc/opencl_matmul_helper.h deleted file mode 100644 index d08d68b6..00000000 --- a/opencl/inc/opencl_matmul_helper.h +++ /dev/null @@ -1,7 +0,0 @@ -#pragma once -namespace jittor{ -// ================================================================================================= -void myclblas(float* A, float* B, float* C, - int K, int M, int N); -// ================================================================================================= -} \ No newline at end of file diff --git a/opencl/lib/kernels.cl b/opencl/lib/kernels.cl deleted file mode 100644 index 875ffbc5..00000000 --- a/opencl/lib/kernels.cl +++ /dev/null @@ -1,1327 +0,0 @@ - -// ================================================================================================= -// Project: -// Exploring the performance of general matrix-multiplication on an NVIDIA Tesla K40m GPU. -// -// File information: -// Institution.... SURFsara -// Author......... Cedric Nugteren -// Changed at..... 2014-11-06 -// License........ MIT license -// Tab-size....... 4 spaces -// Line length.... 100 characters -// -// ================================================================================================= -// -// Matrices in column-major format -// A: K columns, M rows -// B: N columns, K rows -// C: N columns, M rows -// -// N -// o-----o -// | | -// K | [B] | -// | | -// o-----o -// K N -// o-------o o-----o -// M | [A] | M | [C] | -// | | | | -// o-------o o-----o -// -// -// C-code for column-major matrix multiplication with alpha=1 and beta=0: -// -// for (int m=0; m -// Author......... Cedric Nugteren -// Changed at..... 2014-11-07 -// License........ MIT license -// Tab-size....... 4 spaces -// Line length.... 100 characters -// -// ================================================================================================= - -// Select a kernel -#define KERNEL 8 - -// Constants for kernels 1 -- 5 -#define TS 32 // The square-root of the 2D tile-size (== work-group dims) - -// Constants for kernels 3, 5 -#define WPT 8 // The amount of work-per-thread, i.e. 
the thread-coarsening factor -#define RTS (TS/WPT) // The reduced tile-size in one dimension - -// Constants for kernels 4, 7 -- 10 -#define WIDTH 4 // The vector-width (in number of floats) - -// Constants for kernel 5 -#define TSDK 16 // The tile-size in dimension K (for kernel 5 only) -#define LPT ((TSDK*WPT)/(TS)) // The amount of loads-per-thread (assume TSN==TSM) - -// Constants for kernels 6 -- 10 -#define TSM 128 // The tile-size in dimension M -#define TSN 128 // The tile-size in dimension N -#define TSK 16 // The tile-size in dimension K -#define WPTM 8 // The amount of work-per-thread in dimension M -#define WPTN 8 // The amount of work-per-thread in dimension N -#define RTSM (TSM/WPTM) // The reduced tile-size in dimension M (== number of threads) -#define RTSN (TSN/WPTN) // The reduced tile-size in dimension N (== number of threads) -#define LPTA ((TSK*WPTM*WPTN)/(TSN)) // The amount of loads-per-thread for A -#define LPTB ((TSK*WPTM*WPTN)/(TSM)) // The amount of loads-per-thread for B - -// Constraints on settings for kernels 6 -- 10 -// Note: TSM/WPTM has to be integer -// Note: TSN/WPTN has to be integer -// Note: TSM/WIDTH has to be integer -// Note: TSN/WIDTH has to be integer -// Note: (TSK*WPTM*WPTN)/(TSN*WIDTH) has to be integer -// Note: (TSK*WPTM*WPTN)/(TSM*WIDTH) has to be integer - -// Constants for kernel 11 (mimicing clBlas) -#define THREADSX 8 -#define THREADSY 8 -#define RX 8 -#define RY 4 -#define RK (RY) - -// Constants for the supporting transpose kernel -#define TRANSPOSEX 16 -#define TRANSPOSEY 16 - -// Constants for the supporting padding kernels -#define PADDINGX 16 -#define PADDINGY 16 - -// Macros for host and kernel code -#define MIN(a,b) ((a) > (b)) ? (b) : (a) -#define MAX(a,b) ((a) > (b)) ? (a) : (b) -#define CEIL_DIV(x,y) (((x) + (y) - 1) / (y)) -#define MOD2(x,y) ((x) % (y)) -#define DIV2(x,y) ((x) / (y)) - -// ================================================================================================= diff --git a/opencl/ops/minimal.cpp b/opencl/ops/minimal.cpp deleted file mode 100644 index 151af872..00000000 --- a/opencl/ops/minimal.cpp +++ /dev/null @@ -1,151 +0,0 @@ - -// ================================================================================================= -// Project: -// Exploring the performance of general matrix-multiplication on an NVIDIA Tesla K40m GPU. -// -// File information: -// Institution.... SURFsara -// Author......... Cedric Nugteren -// Changed at..... 2014-11-07 -// License........ MIT license -// Tab-size....... 4 spaces -// Line length.... 100 characters -// -// Compilation example: -// g++ -O3 -I$OPENCL_DIR/include minimal.cpp -o minimal -lOpenCL -// -// ================================================================================================= - -// Includes -#include -#include -#include -#include "opencl_matmul_helper.h" -// ================================================================================================= - -// Repeat all kernels multiple times to get an average timing result -#define NUM_RUNS 1 - -// Size of the matrices - K, M, N (squared) -#define SIZE 64 - -// Threadblock sizes (e.g. 
for kernels myGEMM1 or myGEMM2) -#define TS 1 - -// ================================================================================================= - -// Set the kernel as a string (better to do this in a separate file though) -namespace jittor{ -const char *kernelstring = - "__kernel void myGEMM1(const int M, const int N, const int K," - " const __global float* A," - " const __global float* B," - " __global float* C) {" - " const int globalRow = get_global_id(0);" - " const int globalCol = get_global_id(1);" - " float acc = 0.0f;" - " for (int k=0; k>> Initializing OpenCL...\n"); - cl_platform_id platform = 0; - clGetPlatformIDs(1, &platform, NULL); - cl_device_id device = 0; - clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &device, NULL); - cl_context context = clCreateContext(NULL, 1, &device, NULL, NULL, NULL); - cl_command_queue queue = clCreateCommandQueue(context, device, 0, NULL); - char deviceName[1024]; - clGetDeviceInfo(device, CL_DEVICE_NAME, 1024, deviceName, NULL); - cl_event event = NULL; - - // Compile the kernel - cl_program program = clCreateProgramWithSource(context, 1, &kernelstring, NULL, NULL); - clBuildProgram(program, 0, NULL, "", NULL, NULL); - // Check for compilation errors - size_t logSize; - clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, 0, NULL, &logSize); - char* messages = (char*)malloc((1+logSize)*sizeof(char)); - clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, logSize, messages, NULL); - messages[logSize] = '\0'; - if (logSize > 10) { printf(">>> Compiler message: %s\n", messages); } - free(messages); - - // Prepare OpenCL memory objects - cl_mem bufA = clCreateBuffer(context, CL_MEM_READ_ONLY, M*K*sizeof(float), NULL, NULL); - cl_mem bufB = clCreateBuffer(context, CL_MEM_READ_ONLY, K*N*sizeof(float), NULL, NULL); - cl_mem bufC = clCreateBuffer(context, CL_MEM_READ_WRITE, M*N*sizeof(float), NULL, NULL); - - // Copy matrices to the GPU - clEnqueueWriteBuffer(queue, bufA, CL_TRUE, 0, M*K*sizeof(float), A, 0, NULL, NULL); - clEnqueueWriteBuffer(queue, bufB, CL_TRUE, 0, K*N*sizeof(float), B, 0, NULL, NULL); - clEnqueueWriteBuffer(queue, bufC, CL_TRUE, 0, M*N*sizeof(float), C, 0, NULL, NULL); - - // Configure the myGEMM kernel and set its arguments - cl_kernel kernel = clCreateKernel(program, "myGEMM1", NULL); - clSetKernelArg(kernel, 0, sizeof(int), (void*)&M); - clSetKernelArg(kernel, 1, sizeof(int), (void*)&N); - clSetKernelArg(kernel, 2, sizeof(int), (void*)&K); - clSetKernelArg(kernel, 3, sizeof(cl_mem), (void*)&bufA); - clSetKernelArg(kernel, 4, sizeof(cl_mem), (void*)&bufB); - clSetKernelArg(kernel, 5, sizeof(cl_mem), (void*)&bufC); - - // Start the timed loop - printf(">>> Starting %d myGEMM runs...\n", NUM_RUNS); - gettimeofday(&Tvalue, &dummy); - double starttime = (double)Tvalue.tv_sec + 1.0e-6*((double)Tvalue.tv_usec); - for (int r=0; r>> Done: took %.3lf seconds per run, %.1lf GFLOPS\n", runtime, gflop/runtime); - - // Copy the output matrix C back to the CPU memory - clEnqueueReadBuffer(queue, bufC, CL_TRUE, 0, M*N*sizeof(float), C, 0, NULL, NULL); - for(int i=0;i -// Dun Liang . -// -// This file is subject to the terms and conditions defined in -// file 'LICENSE.txt', which is part of this source code package. 
-// *************************************************************** - -#include "var.h" -#include "opencl_matmul_helper.h" -#include "opencl_matmul_op.h" -#include "common.h" -using namespace std; - -namespace jittor { - -#ifndef JIT - -OpenclMatmulOp::OpenclMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b) - : a(a), b(b), trans_a(trans_a), trans_b(trans_b) { - // TODO: support int8 * int8 - ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same"; - // TODO: support diffrent input type - ASSERT(a->dtype().dsize() == b->dtype().dsize()) << "type of two inputs should be the same"; - c = create_output(nullptr, a->dtype()); -} - -void OpenclMatmulOp::infer_shape() { - ASSERTop(a->shape.size(),==,2); - ASSERTop(b->shape.size(),==,2); - int n = a->shape[0], m = a->shape[1]; - int m_ = b->shape[0], k = b->shape[1]; - if (trans_a) { - swap(n, m); - } - if (trans_b) { - swap(m_, k); - } - ASSERTop(m,==,m_); - c->set_shape({n, k}); -} - -void OpenclMatmulOp::jit_prepare(JK& jk) { - jk << _CS("[T:") << a->dtype(); - jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); - jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N'); - jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D'); - jk << ']'; -} - -#else // JIT -#pragma clang diagnostic ignored "-Wtautological-compare" -void OpenclMatmulOp::jit_run() { - const auto& as = a->shape; - const auto& bs = b->shape; - auto n = as[0]; - auto m = as[1]; - auto k = bs[1]; - if ('@Trans_a'=='T') { - n = as[1]; - m = as[0]; - } - if ('@Trans_b'=='T') { - k = bs[0]; - } - // a: [n,m], b: [m,k], c: [n,k] - myclblas(a->ptr(),b->ptr(),c->ptr(),k,m,n); -} -#endif // JIT - -} // jittor diff --git a/opencl/ops/opencl_matmul_op.h b/opencl/ops/opencl_matmul_op.h deleted file mode 100644 index fd506193..00000000 --- a/opencl/ops/opencl_matmul_op.h +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once -#include "op.h" - -namespace jittor { - -struct OpenclMatmulOp : Op { - Var* a, * b, * c; - bool trans_a, trans_b; - OpenclMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b); - - const char* name() const override { return "opencl_matmul"; } - void infer_shape() override; - DECLARE_jit_run; -}; - -} // jittor - diff --git a/opencl/src/clGEMM.cc b/opencl/src/clGEMM.cc deleted file mode 100644 index e874b067..00000000 --- a/opencl/src/clGEMM.cc +++ /dev/null @@ -1,427 +0,0 @@ - -// ================================================================================================= -// Project: -// Exploring the performance of general matrix-multiplication on an NVIDIA Tesla K40m GPU. -// -// File information: -// Institution.... SURFsara -// Author......... Cedric Nugteren -// Changed at..... 2014-11-17 -// License........ MIT license -// Tab-size....... 4 spaces -// Line length.... 
100 characters -// -// ================================================================================================= - -#include -#include "opencl_matmul_helper.h" -#include -#include -// Set the locations of the OpenCL kernel files -#define CL_INCLUDE_FILE "settings.h" -#define CL_KERNEL_FILE "kernels.cl" -// Determine the location where to output the PTX code -#define CL_PTX_FILE "bin/myGEMM.cl.ptx" - -// Define OpenCL compiler options, such as "-cl-nv-maxrregcount=127" -#define COMPILER_OPTIONS "" - -namespace jittor{ -// Forward declaration of the OpenCL error checking function -void checkError(cl_int error, int line); - -// ================================================================================================= -// ================================================================================================= - -// Matrix-multiplication using a custom OpenCL SGEMM kernel. This function also copies the input -// matrices to the GPU, runs SGEMM, and copies the output matrix back to the CPU. -void myclblas(float* A, float* B, float* C, - int K, int M, int N) { - - // In case of myGEMM10, compute matrix sizes K, M, N as rounded-up to form complete tiles - #if KERNEL == 10 - int K_XL = CEIL_DIV(K, TSK) * TSK; - int M_XL = CEIL_DIV(M, TSM) * TSM; - int N_XL = CEIL_DIV(N, TSN) * TSN; - #else - int K_XL = K; - int M_XL = M; - int N_XL = N; - #endif - - // Define OpenCL variables - cl_int err; - cl_platform_id platform = 0; - cl_device_id device = 0; - cl_device_id devices[MAX_NUM_DEVICES]; - cl_uint numDevices = 0; - cl_context_properties props[3] = {CL_CONTEXT_PLATFORM, 0, 0}; - cl_context context = 0; - cl_command_queue queue = 0; - cl_event event = NULL; - cl_program program = NULL; - char deviceName[MAX_DEVICE_NAME]; - - // Configure the OpenCL environment - err = clGetPlatformIDs(1, &platform, NULL); - err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 0, NULL, &numDevices); - err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, numDevices, devices, NULL); - device = devices[CURRENT_DEVICE]; - props[1] = (cl_context_properties)platform; - context = clCreateContext(props, 1, &device, NULL, NULL, &err); - queue = clCreateCommandQueue(context, device, 0, &err); - err = clGetDeviceInfo(device, CL_DEVICE_NAME, MAX_DEVICE_NAME, deviceName, NULL); - checkError(err,__LINE__); - //printf("## %d devices, running on %d: '%s'\n", numDevices, CURRENT_DEVICE, deviceName); - - // Read the kernel file from disk - long sizeHeader, sizeSource; - char* header = readKernelFile(CL_INCLUDE_FILE, &sizeHeader); - char* source = readKernelFile(CL_KERNEL_FILE, &sizeSource); - long size = 2 + sizeHeader + sizeSource; - char* code = (char*)malloc(size*sizeof(char)); - for (int c=0; c 10) { printf("## Compiler message: %s\n", messages); } - free(messages); - - // Retrieve the PTX code from the OpenCL compiler and output it to disk - size_t binSize; - err = clGetProgramInfo(program, CL_PROGRAM_BINARY_SIZES, sizeof(size_t), &binSize, NULL); - checkError(err,__LINE__); - unsigned char *bin = (unsigned char *)malloc(binSize); - err = clGetProgramInfo(program, CL_PROGRAM_BINARIES, sizeof(unsigned char *), &bin, NULL); - checkError(err,__LINE__); - FILE* file = fopen(CL_PTX_FILE, "wb"); - fwrite(bin, sizeof(char), binSize, file); - fclose(file); - free(bin); - - // Prepare OpenCL memory objects - cl_mem bufA = clCreateBuffer(context, CL_MEM_READ_ONLY, M*K*sizeof(*A), NULL, &err); - cl_mem bufB = clCreateBuffer(context, CL_MEM_READ_ONLY, K*N*sizeof(*B), NULL, &err); - cl_mem bufB_TR = clCreateBuffer(context, 
CL_MEM_READ_ONLY, N*K*sizeof(*B), NULL, &err); - cl_mem bufC = clCreateBuffer(context, CL_MEM_READ_WRITE, M*N*sizeof(*C), NULL, &err); - checkError(err,__LINE__); - - // Copy matrices to the GPU (also C to erase the results of the previous run) - err = clEnqueueWriteBuffer(queue, bufA, CL_TRUE, 0, M*K*sizeof(*A), A, 0, NULL, NULL); - err = clEnqueueWriteBuffer(queue, bufB, CL_TRUE, 0, K*N*sizeof(*B), B, 0, NULL, NULL); - err = clEnqueueWriteBuffer(queue, bufC, CL_TRUE, 0, M*N*sizeof(*C), C, 0, NULL, NULL); - checkError(err,__LINE__); - - // Create extra objects for rounded-up sizes (only needed in case of myGEMM10) - cl_mem bufA_XL = clCreateBuffer(context, CL_MEM_READ_ONLY, M_XL*K_XL*sizeof(*A), NULL, &err); - cl_mem bufB_TR_XL = clCreateBuffer(context, CL_MEM_READ_ONLY, N_XL*K_XL*sizeof(*B), NULL, &err); - cl_mem bufC_XL = clCreateBuffer(context, CL_MEM_READ_WRITE, M_XL*N_XL*sizeof(*C), NULL, &err); - checkError(err,__LINE__); - - // Configure the myGEMM kernel - char kernelname[100]; - sprintf(kernelname, "myGEMM%d", KERNEL); - cl_kernel kernel1 = clCreateKernel(program, kernelname, &err); - checkError(err,__LINE__); - - // Set the arguments of the myGEMM kernel - #if KERNEL == 10 - err = clSetKernelArg(kernel1, 0, sizeof(int), (void*)&M_XL); - err = clSetKernelArg(kernel1, 1, sizeof(int), (void*)&N_XL); - err = clSetKernelArg(kernel1, 2, sizeof(int), (void*)&K_XL); - err = clSetKernelArg(kernel1, 3, sizeof(cl_mem), (void*)&bufA_XL); - err = clSetKernelArg(kernel1, 4, sizeof(cl_mem), (void*)&bufB_TR_XL); - err = clSetKernelArg(kernel1, 5, sizeof(cl_mem), (void*)&bufC_XL); - #else - err = clSetKernelArg(kernel1, 0, sizeof(int), (void*)&M); - err = clSetKernelArg(kernel1, 1, sizeof(int), (void*)&N); - err = clSetKernelArg(kernel1, 2, sizeof(int), (void*)&K); - err = clSetKernelArg(kernel1, 3, sizeof(cl_mem), (void*)&bufA); - #if KERNEL == 5 || KERNEL == 6 || KERNEL == 7 || KERNEL == 8 || KERNEL == 9 - err = clSetKernelArg(kernel1, 4, sizeof(cl_mem), (void*)&bufB_TR); - #else - err = clSetKernelArg(kernel1, 4, sizeof(cl_mem), (void*)&bufB); - #endif - err = clSetKernelArg(kernel1, 5, sizeof(cl_mem), (void*)&bufC); - #endif - checkError(err,__LINE__); - - // Configure the supporting transpose kernel and set its arguments (only for certain myGEMMs) - #if KERNEL == 5 || KERNEL == 6 || KERNEL == 7 || KERNEL == 8 || KERNEL == 9 || KERNEL == 10 - cl_kernel kernel2 = clCreateKernel(program, "transpose", &err); - checkError(err,__LINE__); - err = clSetKernelArg(kernel2, 0, sizeof(int), (void*)&K); - err = clSetKernelArg(kernel2, 1, sizeof(int), (void*)&N); - err = clSetKernelArg(kernel2, 2, sizeof(cl_mem), (void*)&bufB); - err = clSetKernelArg(kernel2, 3, sizeof(cl_mem), (void*)&bufB_TR); - checkError(err,__LINE__); - const size_t tLocal[2] = { TRANSPOSEX, TRANSPOSEY }; - const size_t tGlobal[2] = { (size_t)K, (size_t)N }; - #endif - - // Configure the supporting padding kernels and set their arguments (only for myGEMM10) - #if KERNEL == 10 - cl_kernel kernel3a = clCreateKernel(program, "paddingAddZeroes", &err); - checkError(err,__LINE__); - err = clSetKernelArg(kernel3a, 0, sizeof(int), (void*)&M); - err = clSetKernelArg(kernel3a, 1, sizeof(int), (void*)&K); - err = clSetKernelArg(kernel3a, 2, sizeof(cl_mem), (void*)&bufA); - err = clSetKernelArg(kernel3a, 3, sizeof(int), (void*)&M_XL); - err = clSetKernelArg(kernel3a, 4, sizeof(int), (void*)&K_XL); - err = clSetKernelArg(kernel3a, 5, sizeof(cl_mem), (void*)&bufA_XL); - checkError(err,__LINE__); - cl_kernel kernel3b = clCreateKernel(program, 
"paddingAddZeroes", &err); - checkError(err,__LINE__); - err = clSetKernelArg(kernel3b, 0, sizeof(int), (void*)&N); - err = clSetKernelArg(kernel3b, 1, sizeof(int), (void*)&K); - err = clSetKernelArg(kernel3b, 2, sizeof(cl_mem), (void*)&bufB_TR); - err = clSetKernelArg(kernel3b, 3, sizeof(int), (void*)&N_XL); - err = clSetKernelArg(kernel3b, 4, sizeof(int), (void*)&K_XL); - err = clSetKernelArg(kernel3b, 5, sizeof(cl_mem), (void*)&bufB_TR_XL); - checkError(err,__LINE__); - cl_kernel kernel3c = clCreateKernel(program, "paddingRemoveZeroes", &err); - checkError(err,__LINE__); - err = clSetKernelArg(kernel3c, 0, sizeof(int), (void*)&M_XL); - err = clSetKernelArg(kernel3c, 1, sizeof(int), (void*)&N_XL); - err = clSetKernelArg(kernel3c, 2, sizeof(cl_mem), (void*)&bufC_XL); - err = clSetKernelArg(kernel3c, 3, sizeof(int), (void*)&M); - err = clSetKernelArg(kernel3c, 4, sizeof(int), (void*)&N); - err = clSetKernelArg(kernel3c, 5, sizeof(cl_mem), (void*)&bufC); - checkError(err,__LINE__); - const size_t pLocal[2] = { PADDINGX, PADDINGY }; - const size_t pAGlobal[2] = { (size_t)M_XL, (size_t)K_XL }; - const size_t pBGlobal[2] = { (size_t)N_XL, (size_t)K_XL }; - const size_t pCGlobal[2] = { (size_t)M, (size_t)N }; - #endif - - // Configure the thread/work-group dimensions of the myGEMM kernel - #if KERNEL == 1 || KERNEL == 2 - const size_t local[2] = { TS, TS }; - const size_t global[2] = { (size_t)M, (size_t)N }; - #elif KERNEL == 3 || KERNEL == 5 - const size_t local[2] = { TS, TS/WPT }; - const size_t global[2] = { (size_t)M, (size_t)(N/WPT) }; - #elif KERNEL == 4 - const size_t local[2] = { TS/WIDTH, TS }; - const size_t global[2] = { (size_t)(M/WIDTH), (size_t)N }; - #elif KERNEL == 6 || KERNEL == 7 || KERNEL == 8 || KERNEL == 9 - const size_t local[2] = { TSM/WPTM, TSN/WPTN }; - const size_t global[2] = { (size_t)(M/WPTM), (size_t)(N/WPTN) }; - #elif KERNEL == 10 - const size_t local[2] = { TSM/WPTM, TSN/WPTN }; - const size_t global[2] = { (size_t)(M_XL/WPTM), (size_t)(N_XL/WPTN) }; - #elif KERNEL == 11 - const size_t local[2] = { THREADSX, THREADSY }; - const size_t global[2] = { (size_t)(M/RX), (size_t)(N/RY) }; - #endif - - // Start the timed loop - // double startTime = opencl_timer(); - for (int r=0; r -// Author......... Cedric Nugteren -// Changed at..... 2014-11-06 -// License........ MIT license -// Tab-size....... 4 spaces -// Line length.... 100 characters -// -// ================================================================================================= -// -// Matrices in column-major format -// A: K columns, M rows -// B: N columns, K rows -// C: N columns, M rows -// -// N -// o-----o -// | | -// K | [B] | -// | | -// o-----o -// K N -// o-------o o-----o -// M | [A] | M | [C] | -// | | | | -// o-------o o-----o -// -// -// C-code for column-major matrix multiplication with alpha=1 and beta=0: -// -// for (int m=0; m -// Author......... Cedric Nugteren -// Changed at..... 2014-11-07 -// License........ MIT license -// Tab-size....... 4 spaces -// Line length.... 100 characters -// -// ================================================================================================= - -// Select a kernel -#define KERNEL 8 - -// Constants for kernels 1 -- 5 -#define TS 32 // The square-root of the 2D tile-size (== work-group dims) - -// Constants for kernels 3, 5 -#define WPT 8 // The amount of work-per-thread, i.e. 
the thread-coarsening factor -#define RTS (TS/WPT) // The reduced tile-size in one dimension - -// Constants for kernels 4, 7 -- 10 -#define WIDTH 4 // The vector-width (in number of floats) - -// Constants for kernel 5 -#define TSDK 16 // The tile-size in dimension K (for kernel 5 only) -#define LPT ((TSDK*WPT)/(TS)) // The amount of loads-per-thread (assume TSN==TSM) - -// Constants for kernels 6 -- 10 -#define TSM 128 // The tile-size in dimension M -#define TSN 128 // The tile-size in dimension N -#define TSK 16 // The tile-size in dimension K -#define WPTM 8 // The amount of work-per-thread in dimension M -#define WPTN 8 // The amount of work-per-thread in dimension N -#define RTSM (TSM/WPTM) // The reduced tile-size in dimension M (== number of threads) -#define RTSN (TSN/WPTN) // The reduced tile-size in dimension N (== number of threads) -#define LPTA ((TSK*WPTM*WPTN)/(TSN)) // The amount of loads-per-thread for A -#define LPTB ((TSK*WPTM*WPTN)/(TSM)) // The amount of loads-per-thread for B - -// Constraints on settings for kernels 6 -- 10 -// Note: TSM/WPTM has to be integer -// Note: TSN/WPTN has to be integer -// Note: TSM/WIDTH has to be integer -// Note: TSN/WIDTH has to be integer -// Note: (TSK*WPTM*WPTN)/(TSN*WIDTH) has to be integer -// Note: (TSK*WPTM*WPTN)/(TSM*WIDTH) has to be integer - -// Constants for kernel 11 (mimicing clBlas) -#define THREADSX 8 -#define THREADSY 8 -#define RX 8 -#define RY 4 -#define RK (RY) - -// Constants for the supporting transpose kernel -#define TRANSPOSEX 16 -#define TRANSPOSEY 16 - -// Constants for the supporting padding kernels -#define PADDINGX 16 -#define PADDINGY 16 - -// Macros for host and kernel code -#define MIN(a,b) ((a) > (b)) ? (b) : (a) -#define MAX(a,b) ((a) > (b)) ? (a) : (b) -#define CEIL_DIV(x,y) (((x) + (y) - 1) / (y)) -#define MOD2(x,y) ((x) % (y)) -#define DIV2(x,y) ((x) / (y)) - -// ================================================================================================= From 93263964da0278992932d2313ecc063dfa43aade Mon Sep 17 00:00:00 2001 From: Exusial Date: Fri, 7 May 2021 23:59:20 +0800 Subject: [PATCH 27/36] add test. 
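Editor's note: this patch adds a Uniform distribution along with its test (diff below). The closed-form quantities it relies on, the entropy and the KL term between two uniforms with nested supports, can be sanity-checked with plain Python; the values here are illustrative only:

import math

low, high   = 0.0, 2.0     # "current" distribution
low2, high2 = -1.0, 3.0    # "old" distribution; its support contains [low, high]

entropy = math.log(high - low)                     # log(2) ~= 0.693
kl      = math.log((high2 - low2) / (high - low))  # log(4/2) ~= 0.693; inf if not nested
print(entropy, kl)
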
--- python/jittor/distributions.py | 24 +++++++++++++++++++++++- python/jittor/test/test_distributions.py | 13 +++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py index 2415fdb5..208c7049 100644 --- a/python/jittor/distributions.py +++ b/python/jittor/distributions.py @@ -105,6 +105,24 @@ def entropy(self): return 0.5+0.5*np.log(2*np.pi)+jt.log(self.sigma) +class Uniform: + def __init__(self,low,high): + self.low = low + self.high = high + assert high > low + + def sample(self,sample_shape): + return jt.uniform(low,high,sample_shape) + + def log_prob(self,x): + if x < low or x >= high: + return math.inf + return -jt.log(self.high - self.low) + + def entropy(self): + return jt.log(self.high - self.low) + + def kl_divergence(cur_dist,old_dist): assert isinstance(cur_dist,type(old_dist)) if isinstance(cur_dist,Normal): @@ -116,4 +134,8 @@ def kl_divergence(cur_dist,old_dist): t[jt.array((old_dist.probs == 0))] = math.inf t[jt.array((cur_dist.probs == 0))] = 0 return t.sum(-1) - \ No newline at end of file + if isinstance(cur_dist,Uniform): + res = jt.log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low)) + if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high: + res = math.inf + return res diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py index 07fb17cc..dcf71529 100644 --- a/python/jittor/test/test_distributions.py +++ b/python/jittor/test/test_distributions.py @@ -70,6 +70,19 @@ def test_categorical(self): # print(jc.log_prob(x),tc.log_prob(x)) assert np.allclose(jc.log_prob(x),tc.log_prob(torch.tensor(x))) assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) + + def test_uniform(self): + for _ in range(4): + low, low2 = np.random.ranint(-1,2), np.random.ranint(-1,2) + leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2) + high, high2 = low + leng, low2 + leng2 + ju, ju2 = jd.Uniform(low,high),jd.Uniform(low2,high2) + tu, tu2 = torch.distributions.Categorical(low,high),torch.distributions.Categorical(low2,high2) + assert np.allclose(ju.entropy().data,tu.entropy().numpy()) + x = np.random.uniform(low,high) + # print(jc.log_prob(x),tc.log_prob(x)) + assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x))) + assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2)) if __name__ == "__main__": From c6347415678464eb3e4ee58e36a64033b873092c Mon Sep 17 00:00:00 2001 From: Exusial Date: Sat, 8 May 2021 00:10:31 +0800 Subject: [PATCH 28/36] merge. --- python/jittor/test/test_distributions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py index e04a3abf..4c21f171 100644 --- a/python/jittor/test/test_distributions.py +++ b/python/jittor/test/test_distributions.py @@ -71,8 +71,9 @@ def test_categorical(self): # print(jc.log_prob(x),tc.log_prob(x)) assert np.allclose(jc.log_prob(x),tc.log_prob(torch.tensor(x))) assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) - + def test_uniform(self): + import torch for _ in range(4): low, low2 = np.random.ranint(-1,2), np.random.ranint(-1,2) leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2) From 62e6f81af41ce3acac44355d0127dbc0cbba1156 Mon Sep 17 00:00:00 2001 From: Exusial Date: Sat, 8 May 2021 09:37:29 +0800 Subject: [PATCH 29/36] fix. 
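Editor's note: among other small fixes, this patch makes Uniform.sample read self.low/self.high instead of the bare names low/high. A minimal, jittor-free sketch of that bug class (the Box class is made up for illustration):

class Box:
    def __init__(self, low, high):
        self.low, self.high = low, high
    def mid_buggy(self):
        return (low + high) / 2           # NameError: looks up module globals, not the instance
    def mid_fixed(self):
        return (self.low + self.high) / 2

print(Box(0, 4).mid_fixed())              # 2.0
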
--- python/jittor/__init__.py | 8 +- python/jittor/__init__.pyi | 164 +++++++++++++++++++++++ python/jittor/distributions.py | 4 +- python/jittor/test/test_distributions.py | 4 +- 4 files changed, 172 insertions(+), 8 deletions(-) create mode 100644 python/jittor/__init__.pyi diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index a8a2a34e..87532387 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -8,7 +8,6 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** - __version__ = '1.2.2.69' from . import lock with lock.lock_scope(): @@ -977,16 +976,17 @@ def load(self, path: str): This method also supports loading a state dict from a pytorch .pth file. .. note:: - 当载入的参数与模型定义不一致时, jittor 会输出错误信? 但是不会抛出异常. + 当载入的参数与模型定义不一致时, jittor 会输出错误信息, 但是不会抛出异常. 若载入参数出现模型定义中没有的参数名, 则会输出如下信息, 并忽略此参数: >>> [w 0205 21:49:39.962762 96 __init__.py:723] load parameter w failed ... - 若载入参数的 shape 与模型定义不一? 则会输出如下信息, 并忽略此参数: + 若载入参数的 shape 与模型定义不一致, 则会输出如下信息, 并忽略此参数: >>> [e 0205 21:49:39.962822 96 __init__.py:739] load parameter w failed: expect the shape of w to be [1000,100,], but got [3,100,100,] - 如载入过程中出现错误, jittor 会输出概要信? 您需要仔细核对错误信? + 如载入过程中出现错误, jittor 会输出概要信息, 您需要仔细核对错误信息 + >>> [w 0205 21:49:39.962906 96 __init__.py:741] load total 100 params, 3 failed ''' self.load_parameters(load(path)) diff --git a/python/jittor/__init__.pyi b/python/jittor/__init__.pyi new file mode 100644 index 00000000..a65f6969 --- /dev/null +++ b/python/jittor/__init__.pyi @@ -0,0 +1,164 @@ +from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload +import builtins +import math +import pickle +def unary(x,op):... +def cast(x,op):... +def bool(x):... +def int8(x):... +def int16(x):... +def int32(x):... +def int64(x):... +def uint8(x):... +def uint16(x):... +def uint32(x):... +def uint64(x):... +def float32(x):... +def float64(x):... +def abs(x):... +def negative(x):... +def logical_not(x):... +def bitwise_not(x):... +def log(x):... +def exp(x):... +def sqrt(x):... +def round(x):... +def floor(x):... +def ceil(x):... +def sin(x):... +def asin(x):... +def arcsin(x):... +def sinh(x):... +def asinh(x):... +def arcsinh(x):... +def tan(x):... +def atan(x):... +def arctan(x):... +def tanh(x):... +def atanh(x):... +def arctanh(x):... +def cos(x):... +def acos(x):... +def arccos(x):... +def cosh(x):... +def acosh(x):... +def arccosh(x):... +def sigmoid(x):... +def erf(x):... +def broadcast(x,shape,dims):... +def broadcast(x,y,dims):... +def broadcast_var(x,y,dims):... +def tape(x):... +def fetch(inputs,func):... +def transpose(x,axes):... +def code(shape,dtype,inputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... +def code(shapes,dtypes,inputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... +def code(inputs,outputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... +def setitem(x,slices,y,op):... +def candidate(x,fail_cond,dtype):... +def getitem(x,slices):... +def random(shape,dtype,type):... +def reindex_reduce(y,op,shape,indexes,overflow_conditions,extras):... +def clone(x):... +def reshape(x,shape):... +def reduce(x,op,dim,keepdims):... +def reduce(x,op,dims,keepdims):... +def max(x,dim,keepdims):... +def max(x,dims,keepdims):... +def max(x,dims_mask,keepdims_mask):... +def reduce_maximum(x,dim,keepdims):... +def reduce_maximum(x,dims,keepdims):... 
+def reduce_maximum(x,dims_mask,keepdims_mask):... +def min(x,dim,keepdims):... +def min(x,dims,keepdims):... +def min(x,dims_mask,keepdims_mask):... +def reduce_minimum(x,dim,keepdims):... +def reduce_minimum(x,dims,keepdims):... +def reduce_minimum(x,dims_mask,keepdims_mask):... +def sum(x,dim,keepdims):... +def sum(x,dims,keepdims):... +def sum(x,dims_mask,keepdims_mask):... +def reduce_add(x,dim,keepdims):... +def reduce_add(x,dims,keepdims):... +def reduce_add(x,dims_mask,keepdims_mask):... +def prod(x,dim,keepdims):... +def prod(x,dims,keepdims):... +def prod(x,dims_mask,keepdims_mask):... +def product(x,dim,keepdims):... +def product(x,dims,keepdims):... +def product(x,dims_mask,keepdims_mask):... +def reduce_multiply(x,dim,keepdims):... +def reduce_multiply(x,dims,keepdims):... +def reduce_multiply(x,dims_mask,keepdims_mask):... +def reduce_logical_and(x,dim,keepdims):... +def reduce_logical_and(x,dims,keepdims):... +def reduce_logical_and(x,dims_mask,keepdims_mask):... +def all_(x,dim,keepdims):... +def all_(x,dims,keepdims):... +def all_(x,dims_mask,keepdims_mask):... +def reduce_logical_or(x,dim,keepdims):... +def reduce_logical_or(x,dims,keepdims):... +def reduce_logical_or(x,dims_mask,keepdims_mask):... +def any_(x,dim,keepdims):... +def any_(x,dims,keepdims):... +def any_(x,dims_mask,keepdims_mask):... +def reduce_logical_xor(x,dim,keepdims):... +def reduce_logical_xor(x,dims,keepdims):... +def reduce_logical_xor(x,dims_mask,keepdims_mask):... +def reduce_bitwise_and(x,dim,keepdims):... +def reduce_bitwise_and(x,dims,keepdims):... +def reduce_bitwise_and(x,dims_mask,keepdims_mask):... +def reduce_bitwise_or(x,dim,keepdims):... +def reduce_bitwise_or(x,dims,keepdims):... +def reduce_bitwise_or(x,dims_mask,keepdims_mask):... +def reduce_bitwise_xor(x,dim,keepdims):... +def reduce_bitwise_xor(x,dims,keepdims):... +def reduce_bitwise_xor(x,dims_mask,keepdims_mask):... +def mean(x,dim,keepdims):... +def mean(x,dims,keepdims):... +def mean(x,dims_mask,keepdims_mask):... +def copy(x):... +def arg_reduce(x,op,dim,keepdims):... +def argsort(x,dim,descending,dtype):... +def ternary(cond,x,y):... +def binary(x,y,p):... +def pow(x,y):... +def maximum(x,y):... +def minimum(x,y):... +def add(x,y):... +def subtract(x,y):... +def multiply(x,y):... +def divide(x,y):... +def floor_divide(x,y):... +def mod(x,y):... +def less(x,y):... +def less_equal(x,y):... +def greater(x,y):... +def greater_equal(x,y):... +def equal(x,y):... +def not_equal(x,y):... +def left_shift(x,y):... +def right_shift(x,y):... +def logical_and(x,y):... +def logical_or(x,y):... +def logical_xor(x,y):... +def bitwise_and(x,y):... +def bitwise_or(x,y):... +def bitwise_xor(x,y):... +def numpy_code(shape,dtype,inputs,forward,backward):... +def numpy_code(shapes,dtypes,inputs,forward,backward):... +def numpy_code(shape,dtype,inputs,forward):... +def numpy_code(shapes,dtypes,inputs,forward):... +def where(cond,dtype):... +def index(shape,dim,dtype):... +def index(shape,dtype):... +def index(a,dim,dtype):... +def index(a,dtype):... +def index_var(a,dim,dtype):... +def index_var(a,dtype):... +def array_(args):... +def array(obj):... +def reindex(x,shape,indexes,overflow_value,overflow_conditions,extras):... +def reindex(x,indexes,overflow_value,overflow_conditions):... +def reindex_var(x,indexes,overflow_value,overflow_conditions):... +def empty(shape,dtype):... 
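Editor's note: the restored stub repeats plain def lines for overloaded ops (for example the three max variants above), which type checkers generally flag as redefinitions. The conventional .pyi form would wrap them with the overload decorator that the header already imports; a hypothetical fragment:

from typing import overload

@overload
def max(x, dim, keepdims): ...
@overload
def max(x, dims, keepdims): ...
@overload
def max(x, dims_mask, keepdims_mask): ...
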
diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py index 29191103..320a5ec2 100644 --- a/python/jittor/distributions.py +++ b/python/jittor/distributions.py @@ -113,10 +113,10 @@ def __init__(self,low,high): assert high > low def sample(self,sample_shape): - return jt.uniform(low,high,sample_shape) + return jt.uniform(self.low,self.high,sample_shape) def log_prob(self,x): - if x < low or x >= high: + if x < self.low or x >= self.high: return math.inf return -jt.log(self.high - self.low) diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py index 4c21f171..544e636e 100644 --- a/python/jittor/test/test_distributions.py +++ b/python/jittor/test/test_distributions.py @@ -75,11 +75,11 @@ def test_categorical(self): def test_uniform(self): import torch for _ in range(4): - low, low2 = np.random.ranint(-1,2), np.random.ranint(-1,2) + low, low2 = np.random.randint(-1,2), np.random.randint(-1,2) leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2) high, high2 = low + leng, low2 + leng2 ju, ju2 = jd.Uniform(low,high),jd.Uniform(low2,high2) - tu, tu2 = torch.distributions.Categorical(low,high),torch.distributions.Categorical(low2,high2) + tu, tu2 = torch.distributions.Uniform(low,high),torch.distributions.Uniform(low2,high2) assert np.allclose(ju.entropy().data,tu.entropy().numpy()) x = np.random.uniform(low,high) # print(jc.log_prob(x),tc.log_prob(x)) From 1b87e7328eb23e956bda9e62e86672b2e46131cf Mon Sep 17 00:00:00 2001 From: Exusial Date: Mon, 10 May 2021 19:52:17 +0800 Subject: [PATCH 30/36] add finfo. --- python/jittor/__init__.py | 15 +++++++++++++++ python/jittor/distributions.py | 23 +++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 87532387..95825d38 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -1256,6 +1256,21 @@ def get_len(var): Var.__module__ = "jittor" Var.__reduce__ = lambda self: (Var, (self.data,)) +class finfo: + def __init__(self,dtype=Var.float): + if dtype == "float32": + self.bits = 32 + self.eps = math.pow(2,-23) + self.max = (math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127) + self.min = -self.max + self.tiny = math.pow(2,-126) + elif dtype == "float64": + self.bits = 64 + self.eps = math.pow(2,-52) + self.max = (math.pow(2,52)-1) / math.pow(2,51) * math.pow(2,1023) + self.min = -self.max + self.tiny = math.pow(2,-1022) + from . import nn from . import attention from . 
import lr_scheduler diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py index 320a5ec2..ad8554ac 100644 --- a/python/jittor/distributions.py +++ b/python/jittor/distributions.py @@ -124,6 +124,29 @@ def entropy(self): return jt.log(self.high - self.low) +class Geometric: + def __init__(self,p=None,logits=None): + assert (p is not None) or (logits is not None) + assert 0 < p and p < 1 + if p is None: + self.prob = jt.pow(e,logits) + self.logits = logits + else: + self.prob = p + self.logits = jt.log(p) + + def sample(self,sample_shape): + tiny = jt.info(self.probs.dtype).tiny + u = jt.clamp(jt.rand(sample_shape),min_v=tiny) + return (jt.log(u) / (jt.log(-self.probs+1))).floor() + + def log_prob(self,x): + pass + + def entropy(self): + pass + + def kl_divergence(cur_dist,old_dist): assert isinstance(cur_dist,type(old_dist)) if isinstance(cur_dist,Normal): From 1ef727eee156cc0aae56d8ff974be685c70e93b6 Mon Sep 17 00:00:00 2001 From: Exusial Date: Mon, 10 May 2021 20:48:52 +0800 Subject: [PATCH 31/36] add gemotric --- python/jittor/distributions.py | 6 ++++-- python/jittor/test/test_distributions.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py index ad8554ac..d207ae8d 100644 --- a/python/jittor/distributions.py +++ b/python/jittor/distributions.py @@ -141,10 +141,10 @@ def sample(self,sample_shape): return (jt.log(u) / (jt.log(-self.probs+1))).floor() def log_prob(self,x): - pass + return x*jt.log(-self.prob+1)+jt.log(self.prob) def entropy(self): - pass + binary_cross_entropy_with_logits(self.prob.self.logits) def kl_divergence(cur_dist,old_dist): @@ -163,3 +163,5 @@ def kl_divergence(cur_dist,old_dist): if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high: res = math.inf return res + if isinstance(cur_dist,Geometric): + return -cur_dist.entropy() - jt.log(-q.probs+1) / cur_dist.probs - old_dist.logits diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py index 544e636e..c4ce2e1a 100644 --- a/python/jittor/test/test_distributions.py +++ b/python/jittor/test/test_distributions.py @@ -85,6 +85,18 @@ def test_uniform(self): # print(jc.log_prob(x),tc.log_prob(x)) assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x))) assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2)) + + def test_uniform(self): + import torch + for _ in range(4): + prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1) + jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2) + tg, tg2 = torch.distributions.Geometric(prob),torch.distributions.Geometric(prob2) + assert np.allclose(jg.entropy().data,tg.entropy().numpy()) + x = np.random.randint(1,10) + # print(jc.log_prob(x),tc.log_prob(x)) + assert np.allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x))) + assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2)) if __name__ == "__main__": unittest.main() \ No newline at end of file From 57c18e75250991c022a0d1607960dcef268cef8c Mon Sep 17 00:00:00 2001 From: Exusial Date: Tue, 11 May 2021 12:03:28 +0800 Subject: [PATCH 32/36] fix onehot. 
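Editor's note: besides the OneHotCategorical.log_prob fix, this patch reworks Geometric to derive logits as -log(1/p - 1) and entropy as binary_cross_entropy_with_logits(logits, p) / p. Those identities can be checked with plain Python; the value of p is illustrative:

import math

p = 0.3
logits = -math.log(1.0 / p - 1.0)                      # inverse sigmoid, so sigmoid(logits) == p
bce = -(p * math.log(p) + (1 - p) * math.log(1 - p))   # binary cross entropy of p against itself
entropy = bce / p                                      # same value torch reports for Geometric(0.3)
print(round(logits, 4), round(entropy, 4))             # -0.8473 2.0362
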
--- python/jittor/distributions.py | 17 +++++++++-------- python/jittor/test/test_distributions.py | 22 +++++++++++++++------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py index d207ae8d..cb566bab 100644 --- a/python/jittor/distributions.py +++ b/python/jittor/distributions.py @@ -10,7 +10,7 @@ import math import numpy as np import jittor as jt - +from jittor.nn import binary_cross_entropy_with_logits def simple_presum(x): src = ''' __inline_static__ @@ -48,7 +48,8 @@ def sample(self, sample_shape=[]): return one_hot def log_prob(self,x): - return jt.log(self.probs)[0,x] + indices = jt.argmax(x,0)[0].int() + return jt.log(self.probs)[0,indices] def entropy(self): min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127) @@ -129,11 +130,11 @@ def __init__(self,p=None,logits=None): assert (p is not None) or (logits is not None) assert 0 < p and p < 1 if p is None: - self.prob = jt.pow(e,logits) + self.prob = jt.sigmoid(logits) self.logits = logits - else: + elif logits is None: self.prob = p - self.logits = jt.log(p) + self.logits = -jt.log(1. / p - 1) def sample(self,sample_shape): tiny = jt.info(self.probs.dtype).tiny @@ -144,7 +145,7 @@ def log_prob(self,x): return x*jt.log(-self.prob+1)+jt.log(self.prob) def entropy(self): - binary_cross_entropy_with_logits(self.prob.self.logits) + return binary_cross_entropy_with_logits(jt.array(self.logits),jt.array(self.prob)) / self.prob def kl_divergence(cur_dist,old_dist): @@ -153,7 +154,7 @@ def kl_divergence(cur_dist,old_dist): vr = (cur_dist.sigma / old_dist.sigma)**2 t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2 return 0.5*(vr+t1-1-jt.log(vr)) - if isinstance(cur_dist,Categorical) or isinstance(cur_dist,OneHotCategorical):# ? 
+ if isinstance(cur_dist,Categorical) or isinstance(cur_dist,OneHotCategorical): t = cur_dist.probs * (cur_dist.logits-old_dist.logits) t[jt.array((old_dist.probs == 0))] = math.inf t[jt.array((cur_dist.probs == 0))] = 0 @@ -164,4 +165,4 @@ def kl_divergence(cur_dist,old_dist): res = math.inf return res if isinstance(cur_dist,Geometric): - return -cur_dist.entropy() - jt.log(-q.probs+1) / cur_dist.probs - old_dist.logits + return -cur_dist.entropy() - jt.log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py index c4ce2e1a..0fdd7b2f 100644 --- a/python/jittor/test/test_distributions.py +++ b/python/jittor/test/test_distributions.py @@ -25,11 +25,21 @@ def test_one_hot(self): x = a.sample().numpy() for i in range(1000): x += a.sample().numpy() - print(x) assert (x > 200).all() y = a.sample([2,3]) y.sync() assert y.shape == [2,3,4] + probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10)) + probs,probs2 = probs / probs.sum(),probs2 / probs2.sum() + import torch + jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)),jd.OneHotCategorical(jt.array(probs2).reshape(1,-1)) + tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2)) + assert np.allclose(jc.entropy().data,tc.entropy().numpy()) + nx = np.random.randint(0,9) + x = np.zeros((10)) + x[nx] = 1 + assert np.allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x))) + assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) def test_cate(self): a = jd.Categorical(jt.array([0.25, 0.25, 0.25, 0.25])) @@ -43,7 +53,7 @@ def test_cate(self): def test_normal(self): import torch - for _ in range(10): + for _ in range(4): mu = np.random.uniform(-1,1) sigma = np.random.uniform(0,2) jn = jd.Normal(mu,sigma) @@ -61,14 +71,13 @@ def test_normal(self): def test_categorical(self): import torch - for _ in range(10): + for _ in range(4): probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10)) probs,probs2 = probs / probs.sum(),probs2 / probs2.sum() jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1)) tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2)) assert np.allclose(jc.entropy().data,tc.entropy().numpy()) x = np.random.randint(0,10) - # print(jc.log_prob(x),tc.log_prob(x)) assert np.allclose(jc.log_prob(x),tc.log_prob(torch.tensor(x))) assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) @@ -82,11 +91,10 @@ def test_uniform(self): tu, tu2 = torch.distributions.Uniform(low,high),torch.distributions.Uniform(low2,high2) assert np.allclose(ju.entropy().data,tu.entropy().numpy()) x = np.random.uniform(low,high) - # print(jc.log_prob(x),tc.log_prob(x)) assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x))) assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2)) - def test_uniform(self): + def test_geometric(self): import torch for _ in range(4): prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1) @@ -94,8 +102,8 @@ def test_uniform(self): tg, tg2 = torch.distributions.Geometric(prob),torch.distributions.Geometric(prob2) assert np.allclose(jg.entropy().data,tg.entropy().numpy()) x = np.random.randint(1,10) - # print(jc.log_prob(x),tc.log_prob(x)) assert np.allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x))) + # 
print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2)) assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2)) if __name__ == "__main__": From 94868a1865c219d31fd2e5c5878ed951688df261 Mon Sep 17 00:00:00 2001 From: Exusial Date: Wed, 12 May 2021 20:33:57 +0800 Subject: [PATCH 33/36] fix. --- python/jittor/compile_extern.py | 136 ++++++----------------- python/jittor/distributions.py | 11 +- python/jittor/test/test_distributions.py | 11 +- 3 files changed, 45 insertions(+), 113 deletions(-) diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 2dda7e57..68d68274 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -8,7 +8,6 @@ from .compiler import * from jittor_utils import run_cmd, get_version, get_int_version from jittor.utils.misc import download_url_to_local -from jittor import ops def search_file(dirs, name): for d in dirs: @@ -21,41 +20,21 @@ def search_file(dirs, name): def install_mkl(root_folder): # origin url is # url = "https://github.com/intel/mkl-dnn/releases/download/v1.0.2/mkldnn_lnx_1.0.2_cpu_gomp.tgz" - if os.environ.get("use_onednn") == "1": - url = "https://cloud.tsinghua.edu.cn/f/cd63e0df3c5c4c52b76d/?dl=1" - filename = "oneDNN-2.2-rc.tar.gz" - fullname = os.path.join(root_folder, filename) - dirname = os.path.join(root_folder, filename.replace(".tar.gz","")) - download_url_to_local(url, filename, root_folder, "fd6e22bb49dedcf0430495098b3dcf1f") + url = "https://cloud.tsinghua.edu.cn/f/da02bf62b55b4aa3b8ee/?dl=1" + filename = "mkldnn_lnx_1.0.2_cpu_gomp.tgz" + fullname = os.path.join(root_folder, filename) + dirname = os.path.join(root_folder, filename.replace(".tgz","")) + + if not os.path.isfile(os.path.join(dirname, "examples", "test")): + LOG.i("Downloading mkl...") + download_url_to_local(url, filename, root_folder, "47187284ede27ad3bd64b5f0e7d5e730") import tarfile + with tarfile.open(fullname, "r") as tar: tar.extractall(root_folder) - import platform - if platform.machine() == "aarch64": - os.system(f"cd {dirname} && mkdir -p build && cd build && export CC=aarch64-linux-gnu-gcc && export CXX=aarch64-linux-gnu-g++ && cmake .. \ - -DCMAKE_SYSTEM_NAME=Linux \ - -DCMAKE_SYSTEM_PROCESSOR=AARCH64 \ - -DCMAKE_LIBRARY_PATH=/usr/aarch64-linux-gnu/lib \ - -D CMAKE_INSTALL_PREFIX={root_folder} && make -j && make install") - else: - os.system(f"cd {dirname} && mkdir -p build && cd build && cmake -D CMAKE_INSTALL_PREFIX={root_folder} .. && make -j && make install") - # TODO add completition test. 
- else: - url = "https://cloud.tsinghua.edu.cn/f/da02bf62b55b4aa3b8ee/?dl=1" - filename = "mkldnn_lnx_1.0.2_cpu_gomp.tgz" - fullname = os.path.join(root_folder, filename) - dirname = os.path.join(root_folder, filename.replace(".tgz","")) - - if not os.path.isfile(os.path.join(dirname, "examples", "test")): - LOG.i("Downloading mkl...") - download_url_to_local(url, filename, root_folder, "47187284ede27ad3bd64b5f0e7d5e730") - import tarfile - - with tarfile.open(fullname, "r") as tar: - tar.extractall(root_folder) - - assert 0 == os.system(f"cd {dirname}/examples && " - f"{cc_path} -std=c++14 cpu_cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test") + + assert 0 == os.system(f"cd {dirname}/examples && " + f"{cc_path} -std=c++14 cpu_cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test") def setup_mkl(): global mkl_ops, use_mkl @@ -64,33 +43,25 @@ def setup_mkl(): if not use_mkl: return mkl_include_path = os.environ.get("mkl_include_path") mkl_lib_path = os.environ.get("mkl_lib_path") + if mkl_lib_path is None or mkl_include_path is None: - if os.environ.get("use_onednn","0")=="1": - LOG.v("setup onednn...") - from pathlib import Path - one_path = os.path.join(str(Path.home()),".cache", "jittor", "one") - make_cache_dir(one_path) - mkl_include_path = os.path.join(one_path,"include") - mkl_lib_path = os.path.join(one_path,"lib") - if not os.path.isdir(mkl_include_path) or not os.path.isdir(mkl_lib_path) or not os.path.isfile(os.path.join(mkl_lib_path, "libmkldnn.so")): - install_mkl(one_path) - else: - mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh") - LOG.v("setup mkl...") - # mkl_path = os.path.join(cache_path, "mkl") - # mkl_path decouple with cc_path - from pathlib import Path - mkl_path = os.path.join(str(Path.home()), ".cache", "jittor", "mkl") - make_cache_dir(mkl_path) - install_mkl(mkl_path) - mkl_home = "" - for name in os.listdir(mkl_path): - if name.startswith("mkldnn_lnx") and os.path.isdir(os.path.join(mkl_path, name)): - mkl_home = os.path.join(mkl_path, name) - break - assert mkl_home!="" - mkl_include_path = os.path.join(mkl_home, "include") - mkl_lib_path = os.path.join(mkl_home, "lib") + mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh") + LOG.v("setup mkl...") + # mkl_path = os.path.join(cache_path, "mkl") + # mkl_path decouple with cc_path + from pathlib import Path + mkl_path = os.path.join(str(Path.home()), ".cache", "jittor", "mkl") + + make_cache_dir(mkl_path) + install_mkl(mkl_path) + mkl_home = "" + for name in os.listdir(mkl_path): + if name.startswith("mkldnn_lnx") and os.path.isdir(os.path.join(mkl_path, name)): + mkl_home = os.path.join(mkl_path, name) + break + assert mkl_home!="" + mkl_include_path = os.path.join(mkl_home, "include") + mkl_lib_path = os.path.join(mkl_home, "lib") mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so") assert os.path.isdir(mkl_include_path) @@ -101,6 +72,7 @@ def setup_mkl(): LOG.v(f"mkl_lib_name: {mkl_lib_name}") # We do not link manualy, link in custom ops # ctypes.CDLL(mkl_lib_name, dlopen_flags) + mkl_op_dir = os.path.join(jittor_path, "extern", "mkl", "ops") mkl_op_files = [os.path.join(mkl_op_dir, name) for name in os.listdir(mkl_op_dir)] mkl_ops = compile_custom_ops(mkl_op_files, @@ -142,15 +114,7 @@ def setup_cub(): cub_home += "/" setup_cuda_lib("cub", link=False, extra_flags=extra_flags) -def setup_cuda_extern(): - ''' - culib = compile_custom_ops(culib_src_files, 
return_module=True, - extra_flags=f" -I'{jt_cuda_include}' -I'{jt_culib_include}' {link_flags} {extra_flags} ") - culib_ops = culib.ops - globals()[lib_name+"_ops"] = culib_ops - globals()[lib_name] = culib - LOG.vv(f"Get {lib_name}_ops: "+str(dir(culib_ops))) - ''' +def setup_cuda_extern(): if not has_cuda: return LOG.vv("setup cuda extern...") cache_path_cuda = os.path.join(cache_path, "cuda") @@ -466,39 +430,6 @@ def inner(self, *args, **kw): if k == "mpi_test": continue setattr(core.Var, k, warper(mpi_ops.__dict__[k])) -def get_pyi(): - f = open(os.path.join(jittor_path,"__init__.pyi"),"w") - # fundamental declaration - f.write("from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload\n") - f.write("import builtins\nimport math\nimport pickle\n") - # for c++ ops - for func_name,func in ops.__dict__.items(): - if func_name == "__doc__" or func_name == "__name__" or func_name == "__loader__" or func_name == "__spec__" or func_name == "__package__": - continue - # print(func_name) - text = func.__doc__ - declarations = re.findall(r"Declaration:\n(.+)\n",text) - # print(declarations) - for decl in declarations: - f.write(f"def {func_name}(") - params = re.findall(r".+ [a-zA-Z_0-9]+\((.+)", decl) - # print(params) - for param in params: - para = param.split(",") - for i,p in enumerate(para): - pa = p.strip().split(" ")[1] - pf = pa.split("=")[0] - # print(pa) - f.write(pf) - if i != len(para) - 1: - f.write(",") - else: - if len(pa.split("=")) > 1: - f.write("):...\n") - else: - f.write(":...\n") - f.close() - setup_mpi() in_mpi = inside_mpi() rank = mpi.world_rank() if in_mpi else 0 @@ -507,5 +438,4 @@ def get_pyi(): setup_cutt() setup_mkl() -setup_cuda_extern() -get_pyi() +setup_cuda_extern() \ No newline at end of file diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py index cb566bab..83cb8a1c 100644 --- a/python/jittor/distributions.py +++ b/python/jittor/distributions.py @@ -48,8 +48,11 @@ def sample(self, sample_shape=[]): return one_hot def log_prob(self,x): - indices = jt.argmax(x,0)[0].int() - return jt.log(self.probs)[0,indices] + if len(x.shape) == 1: + x = x.unsqueeze(0) + logits = self.logits.broadcast(x.shape) + indices = jt.argmax(x, dim=-1)[0] + return logits.gather(1, indices.unsqueeze(-1)).reshape(-1) def entropy(self): min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127) @@ -95,8 +98,8 @@ def __init__(self,mu,sigma): self.mu = mu self.sigma = sigma - def sample(self,sample_shape): - return jt.normal(self.mu, self.sigma, sample_shape) + def sample(self,sample_shape=None): + return jt.normal(jt.array(self.mu), jt.array(self.sigma),size=sample_shape) def log_prob(self, x): var = self.sigma**2 diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py index 0fdd7b2f..55f79708 100644 --- a/python/jittor/test/test_distributions.py +++ b/python/jittor/test/test_distributions.py @@ -35,9 +35,10 @@ def test_one_hot(self): jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)),jd.OneHotCategorical(jt.array(probs2).reshape(1,-1)) tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2)) assert np.allclose(jc.entropy().data,tc.entropy().numpy()) - nx = np.random.randint(0,9) - x = np.zeros((10)) - x[nx] = 1 + x = np.zeros((4,10)) + for _ in range(4): + nx = np.random.randint(0,9) + x[_,nx] = 1 assert np.allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x))) assert 
np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) @@ -60,8 +61,6 @@ def test_normal(self): tn = torch.distributions.Normal(mu,sigma) assert np.allclose(jn.entropy().data,tn.entropy().numpy()) x = np.random.uniform(-1,1) - # print(jn.log_prob(x)) - # print(tn.log_prob(torch.tensor(x))) assert np.allclose(jn.log_prob(x),tn.log_prob(torch.tensor(x))) mu2 = np.random.uniform(-1,1) sigma2 = np.random.uniform(0,2) @@ -77,7 +76,7 @@ def test_categorical(self): jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1)) tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2)) assert np.allclose(jc.entropy().data,tc.entropy().numpy()) - x = np.random.randint(0,10) + x = np.random.randint(0,10,(4)) assert np.allclose(jc.log_prob(x),tc.log_prob(torch.tensor(x))) assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) From c7e0e44fd050f2539904ded5cabb4b94188d7eef Mon Sep 17 00:00:00 2001 From: Exusial Date: Wed, 12 May 2021 20:34:46 +0800 Subject: [PATCH 34/36] fix. --- python/jittor/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 95825d38..09d1e19f 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -8,7 +8,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.2.69' +__version__ = '1.2.2.71' from . import lock with lock.lock_scope(): ori_int = int From 6ff959724df03903fa7be0fd7ee5069ff9f672b2 Mon Sep 17 00:00:00 2001 From: Exusial Date: Wed, 12 May 2021 22:01:48 +0800 Subject: [PATCH 35/36] fix test. --- doc/source/conf.py | 3 +- python/jittor/__init__.py | 9 +- python/jittor/__init__.pyi | 164 ----------------------- python/jittor/compile_extern.py | 2 +- python/jittor/test/test_distributions.py | 17 +-- 5 files changed, 15 insertions(+), 180 deletions(-) delete mode 100644 python/jittor/__init__.pyi diff --git a/doc/source/conf.py b/doc/source/conf.py index 17e10f7b..541e617a 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -59,8 +59,7 @@ # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = [] -locale_dir = ['locale/'] -gettext_compact = False + # -- Options for HTML output ------------------------------------------------- diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index e35a5606..b630779b 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.2.71' +__version__ = '1.2.2.73' from . import lock with lock.lock_scope(): ori_int = int @@ -981,17 +981,16 @@ def load(self, path: str): This method also supports loading a state dict from a pytorch .pth file. .. note:: - 当载入的参数与模型定义不一致时, jittor 会输出错误信息, 但是不会抛出异常. + 当载入的参数与模型定义不一致时, jittor 会输出错误信? 但是不会抛出异常. 若载入参数出现模型定义中没有的参数名, 则会输出如下信息, 并忽略此参数: >>> [w 0205 21:49:39.962762 96 __init__.py:723] load parameter w failed ... - 若载入参数的 shape 与模型定义不一致, 则会输出如下信息, 并忽略此参数: + 若载入参数的 shape 与模型定义不一? 
则会输出如下信息, 并忽略此参数: >>> [e 0205 21:49:39.962822 96 __init__.py:739] load parameter w failed: expect the shape of w to be [1000,100,], but got [3,100,100,] - 如载入过程中出现错误, jittor 会输出概要信息, 您需要仔细核对错误信息 - + 如载入过程中出现错误, jittor 会输出概要信? 您需要仔细核对错误信? >>> [w 0205 21:49:39.962906 96 __init__.py:741] load total 100 params, 3 failed ''' self.load_parameters(load(path)) diff --git a/python/jittor/__init__.pyi b/python/jittor/__init__.pyi deleted file mode 100644 index a65f6969..00000000 --- a/python/jittor/__init__.pyi +++ /dev/null @@ -1,164 +0,0 @@ -from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload -import builtins -import math -import pickle -def unary(x,op):... -def cast(x,op):... -def bool(x):... -def int8(x):... -def int16(x):... -def int32(x):... -def int64(x):... -def uint8(x):... -def uint16(x):... -def uint32(x):... -def uint64(x):... -def float32(x):... -def float64(x):... -def abs(x):... -def negative(x):... -def logical_not(x):... -def bitwise_not(x):... -def log(x):... -def exp(x):... -def sqrt(x):... -def round(x):... -def floor(x):... -def ceil(x):... -def sin(x):... -def asin(x):... -def arcsin(x):... -def sinh(x):... -def asinh(x):... -def arcsinh(x):... -def tan(x):... -def atan(x):... -def arctan(x):... -def tanh(x):... -def atanh(x):... -def arctanh(x):... -def cos(x):... -def acos(x):... -def arccos(x):... -def cosh(x):... -def acosh(x):... -def arccosh(x):... -def sigmoid(x):... -def erf(x):... -def broadcast(x,shape,dims):... -def broadcast(x,y,dims):... -def broadcast_var(x,y,dims):... -def tape(x):... -def fetch(inputs,func):... -def transpose(x,axes):... -def code(shape,dtype,inputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... -def code(shapes,dtypes,inputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... -def code(inputs,outputs,cpu_src,cpu_grad_src,cpu_header,cuda_src,cuda_grad_src,cuda_header):... -def setitem(x,slices,y,op):... -def candidate(x,fail_cond,dtype):... -def getitem(x,slices):... -def random(shape,dtype,type):... -def reindex_reduce(y,op,shape,indexes,overflow_conditions,extras):... -def clone(x):... -def reshape(x,shape):... -def reduce(x,op,dim,keepdims):... -def reduce(x,op,dims,keepdims):... -def max(x,dim,keepdims):... -def max(x,dims,keepdims):... -def max(x,dims_mask,keepdims_mask):... -def reduce_maximum(x,dim,keepdims):... -def reduce_maximum(x,dims,keepdims):... -def reduce_maximum(x,dims_mask,keepdims_mask):... -def min(x,dim,keepdims):... -def min(x,dims,keepdims):... -def min(x,dims_mask,keepdims_mask):... -def reduce_minimum(x,dim,keepdims):... -def reduce_minimum(x,dims,keepdims):... -def reduce_minimum(x,dims_mask,keepdims_mask):... -def sum(x,dim,keepdims):... -def sum(x,dims,keepdims):... -def sum(x,dims_mask,keepdims_mask):... -def reduce_add(x,dim,keepdims):... -def reduce_add(x,dims,keepdims):... -def reduce_add(x,dims_mask,keepdims_mask):... -def prod(x,dim,keepdims):... -def prod(x,dims,keepdims):... -def prod(x,dims_mask,keepdims_mask):... -def product(x,dim,keepdims):... -def product(x,dims,keepdims):... -def product(x,dims_mask,keepdims_mask):... -def reduce_multiply(x,dim,keepdims):... -def reduce_multiply(x,dims,keepdims):... -def reduce_multiply(x,dims_mask,keepdims_mask):... -def reduce_logical_and(x,dim,keepdims):... -def reduce_logical_and(x,dims,keepdims):... -def reduce_logical_and(x,dims_mask,keepdims_mask):... -def all_(x,dim,keepdims):... -def all_(x,dims,keepdims):... -def all_(x,dims_mask,keepdims_mask):... 
-def reduce_logical_or(x,dim,keepdims):... -def reduce_logical_or(x,dims,keepdims):... -def reduce_logical_or(x,dims_mask,keepdims_mask):... -def any_(x,dim,keepdims):... -def any_(x,dims,keepdims):... -def any_(x,dims_mask,keepdims_mask):... -def reduce_logical_xor(x,dim,keepdims):... -def reduce_logical_xor(x,dims,keepdims):... -def reduce_logical_xor(x,dims_mask,keepdims_mask):... -def reduce_bitwise_and(x,dim,keepdims):... -def reduce_bitwise_and(x,dims,keepdims):... -def reduce_bitwise_and(x,dims_mask,keepdims_mask):... -def reduce_bitwise_or(x,dim,keepdims):... -def reduce_bitwise_or(x,dims,keepdims):... -def reduce_bitwise_or(x,dims_mask,keepdims_mask):... -def reduce_bitwise_xor(x,dim,keepdims):... -def reduce_bitwise_xor(x,dims,keepdims):... -def reduce_bitwise_xor(x,dims_mask,keepdims_mask):... -def mean(x,dim,keepdims):... -def mean(x,dims,keepdims):... -def mean(x,dims_mask,keepdims_mask):... -def copy(x):... -def arg_reduce(x,op,dim,keepdims):... -def argsort(x,dim,descending,dtype):... -def ternary(cond,x,y):... -def binary(x,y,p):... -def pow(x,y):... -def maximum(x,y):... -def minimum(x,y):... -def add(x,y):... -def subtract(x,y):... -def multiply(x,y):... -def divide(x,y):... -def floor_divide(x,y):... -def mod(x,y):... -def less(x,y):... -def less_equal(x,y):... -def greater(x,y):... -def greater_equal(x,y):... -def equal(x,y):... -def not_equal(x,y):... -def left_shift(x,y):... -def right_shift(x,y):... -def logical_and(x,y):... -def logical_or(x,y):... -def logical_xor(x,y):... -def bitwise_and(x,y):... -def bitwise_or(x,y):... -def bitwise_xor(x,y):... -def numpy_code(shape,dtype,inputs,forward,backward):... -def numpy_code(shapes,dtypes,inputs,forward,backward):... -def numpy_code(shape,dtype,inputs,forward):... -def numpy_code(shapes,dtypes,inputs,forward):... -def where(cond,dtype):... -def index(shape,dim,dtype):... -def index(shape,dtype):... -def index(a,dim,dtype):... -def index(a,dtype):... -def index_var(a,dim,dtype):... -def index_var(a,dtype):... -def array_(args):... -def array(obj):... -def reindex(x,shape,indexes,overflow_value,overflow_conditions,extras):... -def reindex(x,indexes,overflow_value,overflow_conditions):... -def reindex_var(x,indexes,overflow_value,overflow_conditions):... -def empty(shape,dtype):... diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 68d68274..37fda96b 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -438,4 +438,4 @@ def inner(self, *args, **kw): setup_cutt() setup_mkl() -setup_cuda_extern() \ No newline at end of file +setup_cuda_extern() diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py index 55f79708..cca6c0a7 100644 --- a/python/jittor/test/test_distributions.py +++ b/python/jittor/test/test_distributions.py @@ -1,4 +1,3 @@ - # *************************************************************** # Copyright (c) 2021 Jittor. All Rights Reserved. 
# Maintainers: @@ -32,14 +31,16 @@ def test_one_hot(self): probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10)) probs,probs2 = probs / probs.sum(),probs2 / probs2.sum() import torch + tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs).to(torch.float32)),torch.distributions.OneHotCategorical(torch.tensor(probs2).to(torch.float32)) jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)),jd.OneHotCategorical(jt.array(probs2).reshape(1,-1)) - tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2)) - assert np.allclose(jc.entropy().data,tc.entropy().numpy()) + # print(jc.probs,tc.probs) + # print(jc.logits,tc.logits) + assert np.allclose(jc.entropy().data,tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy()) x = np.zeros((4,10)) for _ in range(4): nx = np.random.randint(0,9) x[_,nx] = 1 - assert np.allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x))) + assert np.allclose(tc.log_prob(torch.tensor(x).to(torch.float32)),jc.log_prob(jt.array(x))) assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) def test_cate(self): @@ -73,11 +74,11 @@ def test_categorical(self): for _ in range(4): probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10)) probs,probs2 = probs / probs.sum(),probs2 / probs2.sum() - jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1)) tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2)) - assert np.allclose(jc.entropy().data,tc.entropy().numpy()) + jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1)) + assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy()) x = np.random.randint(0,10,(4)) - assert np.allclose(jc.log_prob(x),tc.log_prob(torch.tensor(x))) + assert np.allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x))) assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) def test_uniform(self): @@ -106,4 +107,4 @@ def test_geometric(self): assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2)) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 3feaea4f09a197181a24744d9feceb0c15542236 Mon Sep 17 00:00:00 2001 From: Exusial Date: Fri, 28 May 2021 09:20:57 +0800 Subject: [PATCH 36/36] update. --- python/jittor/__init__.py | 24 +++-------- python/jittor/distributions.py | 52 ++++++++++++++++++++++++ python/jittor/test/test_distributions.py | 26 ++++++++++-- 3 files changed, 79 insertions(+), 23 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index b630779b..ce025e66 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -981,16 +981,17 @@ def load(self, path: str): This method also supports loading a state dict from a pytorch .pth file. .. note:: - 当载入的参数与模型定义不一致时, jittor 会输出错误信? 但是不会抛出异常. + 当载入的参数与模型定义不一致时, jittor 会输出错误信息, 但是不会抛出异常. 若载入参数出现模型定义中没有的参数名, 则会输出如下信息, 并忽略此参数: >>> [w 0205 21:49:39.962762 96 __init__.py:723] load parameter w failed ... - 若载入参数的 shape 与模型定义不一? 则会输出如下信息, 并忽略此参数: + 若载入参数的 shape 与模型定义不一致, 则会输出如下信息, 并忽略此参数: >>> [e 0205 21:49:39.962822 96 __init__.py:739] load parameter w failed: expect the shape of w to be [1000,100,], but got [3,100,100,] - 如载入过程中出现错误, jittor 会输出概要信? 您需要仔细核对错误信? 
+ 如载入过程中出现错误, jittor 会输出概要信息, 您需要仔细核对错误信息 + >>> [w 0205 21:49:39.962906 96 __init__.py:741] load total 100 params, 3 failed ''' self.load_parameters(load(path)) @@ -1260,21 +1261,6 @@ def get_len(var): Var.__module__ = "jittor" Var.__reduce__ = lambda self: (Var, (self.data,)) -class finfo: - def __init__(self,dtype=Var.float): - if dtype == "float32": - self.bits = 32 - self.eps = math.pow(2,-23) - self.max = (math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127) - self.min = -self.max - self.tiny = math.pow(2,-126) - elif dtype == "float64": - self.bits = 64 - self.eps = math.pow(2,-52) - self.max = (math.pow(2,52)-1) / math.pow(2,51) * math.pow(2,1023) - self.min = -self.max - self.tiny = math.pow(2,-1022) - from . import nn from . import attention from . import lr_scheduler @@ -1284,4 +1270,4 @@ def __init__(self,dtype=Var.float): from . import numpy2cupy from .contrib import concat from .misc import * -from . import sparse +from . import sparse \ No newline at end of file diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py index 83cb8a1c..07a02169 100644 --- a/python/jittor/distributions.py +++ b/python/jittor/distributions.py @@ -26,6 +26,18 @@ def simple_presum(x): cpu_src=src, cuda_src=src) +def lgamma(x): + header = '''#include''' + src = ''' + @alias(a, in0) + @alias(b, out0) + for (int i=0;i s: + p = la * p / (k + 1) + s = s + p + k += 1 + res[i] = k + return res + + +class Poisson: + def __init__(self, la): + self.la = la + + def sample(self, sample_shape): + return Poisson_sample(self.la,sample_shape) + + def log_prob(self,x): + # todo: add lgamma. + return jt.log(self.la)* x - self.la - lgamma(x + 1) + + def kl_divergence(cur_dist,old_dist): assert isinstance(cur_dist,type(old_dist)) if isinstance(cur_dist,Normal): @@ -169,3 +219,5 @@ def kl_divergence(cur_dist,old_dist): return res if isinstance(cur_dist,Geometric): return -cur_dist.entropy() - jt.log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits + if isinstance(cur_dist,Poisson): + return cur_dist.la * (jt.log(cur_dist.la) - jt.log(old_dist.la)) - (cur_dist.la - old_dist.la) diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py index cca6c0a7..17df62cb 100644 --- a/python/jittor/test/test_distributions.py +++ b/python/jittor/test/test_distributions.py @@ -18,6 +18,12 @@ def test_presum(self): a = jt.array([[1,2,3,4]]) b = jd.simple_presum(a) assert (b.data == [[0,1,3,6,10]]).all() + + def test_lgamma(self): + import torch + ta = np.random.uniform(2,3,(1)) + a = jt.array(ta).float32() + assert np.allclose(jd.lgamma(a).data, torch.lgamma(torch.tensor(ta)).numpy()),(jd.lgamma(a).data, torch.lgamma(torch.tensor(ta)).numpy()) def test_one_hot(self): a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25])) @@ -89,7 +95,7 @@ def test_uniform(self): high, high2 = low + leng, low2 + leng2 ju, ju2 = jd.Uniform(low,high),jd.Uniform(low2,high2) tu, tu2 = torch.distributions.Uniform(low,high),torch.distributions.Uniform(low2,high2) - assert np.allclose(ju.entropy().data,tu.entropy().numpy()) + assert np.allclose(ju.entropy().data,tu.entropy().numpy()),(ju.entropy().data,tu.entropy().numpy()) x = np.random.uniform(low,high) assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x))) assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2)) @@ -100,11 +106,23 @@ def test_geometric(self): prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1) jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2) tg, tg2 = 
torch.distributions.Geometric(prob),torch.distributions.Geometric(prob2) - assert np.allclose(jg.entropy().data,tg.entropy().numpy()) + assert np.allclose(jg.entropy().data,tg.entropy().numpy()),(jg.entropy().data,tg.entropy().numpy()) x = np.random.randint(1,10) - assert np.allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x))) + assert np.allclose(jg.log_prob(jt.array(x)),tg.log_prob(torch.tensor(x))) # print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2)) - assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2)) + assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2)),(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2)) + + def test_poisson(self): + import torch + for _ in range(4): + prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1) + jp, jp2 = jd.Poisson(prob),jd.Poisson(prob2) + tp, tp2 = torch.distributions.Poisson(prob),torch.distributions.Poisson(prob2) + x = np.random.randint(1,10) + assert np.allclose(jp.log_prob(jt.array(x).float32()),tp.log_prob(torch.tensor(x))) + # print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2)) + assert np.allclose(jd.kl_divergence(jp,jp2),torch.distributions.kl_divergence(tp,tp2)),(jd.kl_divergence(jp,jp2),torch.distributions.kl_divergence(tp,tp2)) + if __name__ == "__main__": unittest.main()
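
The Poisson branch added to kl_divergence above relies on the closed form
KL(Poisson(l1) || Poisson(l2)) = l1*(log l1 - log l2) - (l1 - l2), which is what
test_poisson cross-checks against torch. A standalone numpy sketch of that
formula (the rates below are arbitrary example values, not taken from the patch):

    import numpy as np

    lam1, lam2 = 0.7, 0.3
    kl = lam1 * (np.log(lam1) - np.log(lam2)) - (lam1 - lam2)
    print(kl)  # KL divergence between Poisson(0.7) and Poisson(0.3)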