Skip to content

Commit 2dc161a

Browse files
committed
fix float mod
1 parent 4b26d65 commit 2dc161a

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

python/jittor/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.1.7.7'
10+
__version__ = '1.1.7.8'
1111
from . import lock
1212
with lock.lock_scope():
1313
from . import compiler

python/jittor/test/test_binary_op.py

+14
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,20 @@ def test_grad(self):
126126
for jd, nd in zip(jgrads, grads):
127127
assert (np.abs(jd.data-nd)<1e-4).all(), f"\n{jd.data}\n{nd}"
128128

129+
def test_mod_float(self):
130+
a = jt.random((10,))
131+
b = jt.random((10,))
132+
c = a % b
133+
assert np.allclose(c.data, a.data % b.data)
134+
a = jt.random((10,), 'float64')
135+
b = jt.random((10,), 'float64')
136+
c = a % b
137+
assert np.allclose(c.data, a.data % b.data)
138+
a = jt.random((10,)) * 1000
139+
b = (jt.random((10,)) * 10).int() + 1
140+
c = a % b
141+
assert np.allclose(c.data, a.data % b.data), (c.data, a.data%b.data)
142+
129143

130144
class TestBinaryOpCuda(TestBinaryOp, test_cuda(2)):
131145
pass

src/ops/binary_op_defs.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,18 @@ namespace jittor {
1212
#define pow(T,a,b) ::pow(a,b)
1313
#define maximum(T,a,b) ::max(T(a), T(b))
1414
#define minimum(T,a,b) ::min(T(a), T(b))
15+
#define mod(T,a,b) @if(@strcmp(@Tx,float32)==0,::fmodf(T(a),T(b)),@if(@strcmp(@Tx,float64)==0,::fmod(T(a),T(b)),((a)%(b))))
1516
#else // JIT_cpu
1617
#define pow(T,a,b) std::pow(a,b)
1718
#define maximum(T,a,b) std::max(T(a), T(b))
1819
#define minimum(T,a,b) std::min(T(a), T(b))
20+
#define mod(T,a,b) @if(@strcmp(@Tx,float32)==0 || @strcmp(@Tx,float64)==0,std::fmod((T)a,(T)b),((a)%(b)))
1921
#endif
2022
#define add(T,a,b) ((a)+(b))
2123
#define subtract(T,a,b) ((a)-(b))
2224
#define multiply(T,a,b) ((a)*(b))
2325
#define divide(T,a,b) (T((T(a))/(T(b))))
2426
#define floor_divide(T,a,b) (T((T(a))/(T(b))))
25-
#define mod(T,a,b) ((a)%(b))
2627
#define less(T,a,b) ((a)<(b))
2728
#define less_equal(T,a,b) ((a)<=(b))
2829
#define greater(T,a,b) ((a)>(b))

0 commit comments

Comments
 (0)