Skip to content

Commit 4b26d65

Browse files
authored
Merge pull request #130 from Jittor/zwy2
add misc & linalg doc
2 parents b79e6e1 + 843b983 commit 4b26d65

File tree

6 files changed

+93
-13
lines changed

6 files changed

+93
-13
lines changed

doc/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
jittor.dataset
2626
jittor.transform
2727
jittor.mpi
28+
jittor.linalg
2829

2930

3031
.. toctree::

doc/source/jittor.linalg.md

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
jittor.linalg
2+
=====================
3+
4+
这里是Jittor的线性代数函数的API文档,您可以通过`from jittor import linalg`来获取该模块。
5+
6+
```eval_rst
7+
.. automodule:: jittor.linalg
8+
:members:
9+
:undoc-members:
10+
```

doc/source/jittor.md

+10
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,13 @@ jittor
4242
:members:
4343
:undoc-members:
4444
```
45+
46+
## jittor.Misc
47+
48+
这里是Jittor的基础算子模块的API文档,该API可以通过`jittor.misc.XXX`或者`jittor.XXX`直接访问。
49+
50+
```eval_rst
51+
.. automodule:: jittor.misc
52+
:members:
53+
:undoc-members:
54+
```

python/jittor/nn.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,16 @@ def cross_entropy_loss(output, target, ignore_index=None):
167167
def mse_loss(output, target):
168168
return (output-target).sqr().mean()
169169

170-
def bce_loss(output, target, size_average=True):
170+
def bce_loss(output, target, weight=None, size_average=True):
171+
loss = - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20)))
172+
173+
if weight is not None:
174+
loss *= weight
175+
171176
if size_average:
172-
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
177+
return loss.mean()
173178
else:
174-
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).sum()
179+
return loss.sum()
175180

176181
def l1_loss(output, target):
177182
return (output-target).abs().mean()
@@ -189,10 +194,11 @@ def execute(self, output, target):
189194
return mse_loss(output, target)
190195

191196
class BCELoss(Module):
192-
def __init__(self):
193-
pass
194-
def execute(self, output, target, size_average=True):
195-
return bce_loss(output, target, size_average)
197+
def __init__(self, weight=None, size_average=True):
198+
self.weight = weight
199+
self.size_average = size_average
200+
def execute(self, output, target):
201+
return bce_loss(output, target, self.weight, self.size_average)
196202

197203
class L1Loss(Module):
198204
def __init__(self):
@@ -201,14 +207,17 @@ def execute(self, output, target):
201207
return l1_loss(output, target)
202208

203209
class BCEWithLogitsLoss(Module):
204-
def __init__(self):
210+
def __init__(self, weight=None, size_average=True):
205211
self.sigmoid = Sigmoid()
206-
self.bce = BCELoss()
207-
def execute(self, output, target, size_average=True):
212+
self.bce = BCELoss(weight, size_average)
213+
def execute(self, output, target):
208214
output = self.sigmoid(output)
209-
output = self.bce(output, target, size_average)
215+
output = self.bce(output, target)
210216
return output
211217

218+
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
219+
return BCEWithLogitsLoss(weight, size_average)(input, target)
220+
212221
def softmax(x, dim = None):
213222
if dim is None:
214223
x = (x - x.max()).exp()

python/jittor/test/test_loss.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_cross_entropy_loss(self):
4949
jt_y=jt_loss(jt.array(output), jt.array(target))
5050
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
5151
assert np.allclose(jt_y.numpy(), tc_y.numpy())
52-
52+
5353
def test_bce_loss(self):
5454
jt_loss=jnn.BCELoss()
5555
tc_loss=tnn.BCELoss()
@@ -60,6 +60,13 @@ def test_bce_loss(self):
6060
jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target))
6161
tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target))
6262
assert np.allclose(jt_y.numpy(), tc_y.numpy())
63+
64+
weight=np.random.randn(100).astype(np.float32)
65+
jt_loss=jnn.BCELoss(weight=jt.array(weight), size_average=False)
66+
tc_loss=tnn.BCELoss(weight=torch.Tensor(weight), size_average=False)
67+
jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target))
68+
tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target))
69+
assert np.allclose(jt_y.numpy(), tc_y.numpy())
6370

6471
def test_bce_with_logits_loss(self):
6572
jt_loss=jnn.BCEWithLogitsLoss()

python/jittor/utils/pytorch_converter.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,32 @@
7878
'extras': {},
7979
'delete': ['inplace'],
8080
},
81+
'relu': {
82+
'pytorch': {
83+
'args': 'input',
84+
},
85+
'jittor': {
86+
'module': 'nn',
87+
'name': 'relu',
88+
'args': 'x'
89+
},
90+
'links': {'input': 'x'},
91+
'extras': {},
92+
'delete': [],
93+
},
94+
'binary_cross_entropy_with_logits': {
95+
'pytorch': {
96+
'args': 'input, target, weight, size_average=True',
97+
},
98+
'jittor': {
99+
'module': 'nn',
100+
'name': 'binary_cross_entropy_with_logits',
101+
'args': 'input, target, weight, size_average=True'
102+
},
103+
'links': {},
104+
'extras': {},
105+
'delete': [],
106+
},
81107
'ReLU6': {
82108
'pytorch': {
83109
'args': 'inplace=False',
@@ -274,6 +300,23 @@
274300
'links': {},
275301
'extras': {},
276302
},
303+
'clamp': {
304+
'pytorch': {
305+
'prefix': ['torch'],
306+
'args_prefix': 'input, min, max, out=None',
307+
'args': 'min, max, out=None',
308+
},
309+
'jittor': {
310+
'prefix': 'jt',
311+
'module': '',
312+
'name': 'clamp',
313+
'args_prefix': 'x, min_v, max_v',
314+
'args': 'min_v, max_v'
315+
},
316+
'links': {'min': 'min_v', 'max': 'max_v'},
317+
'extras': {},
318+
'delete': ['out'],
319+
},
277320
'permute': {
278321
'pytorch': {
279322
'prefix': [],
@@ -354,7 +397,7 @@ def pjmap_append(pytorch_func_name, pytorch_args, jittor_func_module, jittor_fun
354397
# ***************************************************************
355398
# torch.nn
356399
# ***************************************************************
357-
'Parameter', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict',
400+
'Parameter', 'ModuleDict', 'ParameterList', 'ParameterDict',
358401
'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
359402
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
360403
'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',

0 commit comments

Comments
 (0)