Skip to content

Commit b84b770

Browse files
authored
DatView: Fix zero() (#727)
1 parent 5f18075 commit b84b770

File tree

2 files changed

+79
-18
lines changed

2 files changed

+79
-18
lines changed

pyop2/types/dat.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,7 @@ def __init__(self, dat, index):
681681
if not (0 <= i < d):
682682
raise ex.IndexValueError("Can't create DatView with index %s for Dat with shape %s" % (index, dat.dim))
683683
self.index = index
684+
self._idx = (slice(None), *index)
684685
self._parent = dat
685686
# Point at underlying data
686687
super(DatView, self).__init__(dat.dataset,
@@ -720,41 +721,37 @@ def halo_valid(self):
720721
def halo_valid(self, value):
721722
self._parent.halo_valid = value
722723

724+
@property
725+
def dat_version(self):
726+
return self._parent.dat_version
727+
728+
@property
729+
def _data(self):
730+
return self._parent._data[self._idx]
731+
723732
@property
724733
def data(self):
725-
full = self._parent.data
726-
idx = (slice(None), *self.index)
727-
return full[idx]
734+
return self._parent.data[self._idx]
728735

729736
@property
730737
def data_ro(self):
731-
full = self._parent.data_ro
732-
idx = (slice(None), *self.index)
733-
return full[idx]
738+
return self._parent.data_ro[self._idx]
734739

735740
@property
736741
def data_wo(self):
737-
full = self._parent.data_wo
738-
idx = (slice(None), *self.index)
739-
return full[idx]
742+
return self._parent.data_wo[self._idx]
740743

741744
@property
742745
def data_with_halos(self):
743-
full = self._parent.data_with_halos
744-
idx = (slice(None), *self.index)
745-
return full[idx]
746+
return self._parent.data_with_halos[self._idx]
746747

747748
@property
748749
def data_ro_with_halos(self):
749-
full = self._parent.data_ro_with_halos
750-
idx = (slice(None), *self.index)
751-
return full[idx]
750+
return self._parent.data_ro_with_halos[self._idx]
752751

753752
@property
754753
def data_wo_with_halos(self):
755-
full = self._parent.data_wo_with_halos
756-
idx = (slice(None), *self.index)
757-
return full[idx]
754+
return self._parent.data_wo_with_halos[self._idx]
758755

759756

760757
class Dat(AbstractDat, VecAccessMixin):

test/unit/test_dats.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ def mdat(d1):
5555
return op2.MixedDat([d1, d1])
5656

5757

58+
@pytest.fixture(scope='module')
59+
def s2(s):
60+
return op2.DataSet(s, 2)
61+
62+
63+
@pytest.fixture
64+
def vdat(s2):
65+
return op2.Dat(s2, np.zeros(2 * nelems), dtype=np.float64)
66+
67+
5868
class TestDat:
5969

6070
"""
@@ -254,6 +264,60 @@ def test_accessing_data_with_halos_increments_dat_version(self, d1):
254264
assert d1.dat_version == 1
255265

256266

267+
class TestDatView():
268+
269+
def test_dat_view_assign(self, vdat):
270+
vdat.data[:, 0] = 3
271+
vdat.data[:, 1] = 4
272+
comp = op2.DatView(vdat, 1)
273+
comp.data[:] = 7
274+
assert not vdat.halo_valid
275+
assert not comp.halo_valid
276+
277+
expected = np.zeros_like(vdat.data)
278+
expected[:, 0] = 3
279+
expected[:, 1] = 7
280+
assert all(comp.data == expected[:, 1])
281+
assert all(vdat.data[:, 0] == expected[:, 0])
282+
assert all(vdat.data[:, 1] == expected[:, 1])
283+
284+
def test_dat_view_zero(self, vdat):
285+
vdat.data[:, 0] = 3
286+
vdat.data[:, 1] = 4
287+
comp = op2.DatView(vdat, 1)
288+
comp.zero()
289+
assert vdat.halo_valid
290+
assert comp.halo_valid
291+
292+
expected = np.zeros_like(vdat.data)
293+
expected[:, 0] = 3
294+
expected[:, 1] = 0
295+
assert all(comp.data == expected[:, 1])
296+
assert all(vdat.data[:, 0] == expected[:, 0])
297+
assert all(vdat.data[:, 1] == expected[:, 1])
298+
299+
def test_dat_view_halo_valid(self, vdat):
300+
"""Check halo validity for DatView"""
301+
comp = op2.DatView(vdat, 1)
302+
assert vdat.halo_valid
303+
assert comp.halo_valid
304+
assert vdat.dat_version == 0
305+
assert comp.dat_version == 0
306+
307+
comp.data_ro_with_halos
308+
assert vdat.halo_valid
309+
assert comp.halo_valid
310+
assert vdat.dat_version == 0
311+
assert comp.dat_version == 0
312+
313+
# accessing comp.data_with_halos should mark the parent halo as dirty
314+
comp.data_with_halos
315+
assert not vdat.halo_valid
316+
assert not comp.halo_valid
317+
assert vdat.dat_version == 1
318+
assert comp.dat_version == 1
319+
320+
257321
if __name__ == '__main__':
258322
import os
259323
pytest.main(os.path.abspath(__file__))

0 commit comments

Comments
 (0)