Skip to content

Commit 05c2089

Browse files
committed
revert to representing coordinates as floats
1 parent de032ae commit 05c2089

2 files changed

Lines changed: 13 additions & 56 deletions

File tree

autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py

Lines changed: 13 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,6 @@
1515

1616
@register_pytree_node_class
1717
class CoordinateArrayTriangles(AbstractCoordinateArray):
18-
def __init__(
19-
self,
20-
coordinates: np.ndarray,
21-
mask: np.ndarray,
22-
side_length: float = 1.0,
23-
x_offset: float = 0.0,
24-
y_offset: float = 0.0,
25-
flipped: bool = False,
26-
):
27-
super().__init__(
28-
coordinates=coordinates,
29-
side_length=side_length,
30-
x_offset=x_offset,
31-
y_offset=y_offset,
32-
flipped=flipped,
33-
)
34-
self.mask = mask
35-
3618
@property
3719
def numpy(self):
3820
return jax.numpy
@@ -57,19 +39,13 @@ def for_limits_and_scale(
5739
coordinates.append([x, y])
5840

5941
return cls(
60-
coordinates=np.array(coordinates, dtype=np.int32),
42+
coordinates=np.array(coordinates),
6143
side_length=scale,
62-
mask=np.full(
63-
len(coordinates),
64-
False,
65-
dtype=bool,
66-
),
6744
)
6845

6946
def tree_flatten(self):
7047
return (
7148
self.coordinates,
72-
self.mask,
7349
self.side_length,
7450
self.x_offset,
7551
self.y_offset,
@@ -102,7 +78,7 @@ def centres(self) -> np.ndarray:
10278
centres = self.scaling_factors * self.coordinates + np.array(
10379
[self.x_offset, self.y_offset]
10480
)
105-
return self.numpy.where(self.mask[:, None], np.nan, centres)
81+
return centres
10682

10783
@cached_property
10884
def flip_mask(self) -> np.ndarray:
@@ -138,29 +114,22 @@ def up_sample(self) -> "CoordinateArrayTriangles":
138114

139115
n = coordinates.shape[0]
140116

141-
shift0 = np.zeros((n, 2), dtype=np.int32)
142-
shift3 = np.tile(np.array([0, 1], dtype=np.int32), (n, 1))
143-
shift1 = np.stack(
144-
[np.ones(n, dtype=np.int32), np.where(flip_mask, 1, 0)], axis=1
145-
)
146-
shift2 = np.stack(
147-
[-np.ones(n, dtype=np.int32), np.where(flip_mask, 1, 0)], axis=1
148-
)
117+
shift0 = np.zeros((n, 2))
118+
shift3 = np.tile(np.array([0, 1]), (n, 1))
119+
shift1 = np.stack([np.ones(n), np.where(flip_mask, 1, 0)], axis=1)
120+
shift2 = np.stack([-np.ones(n), np.where(flip_mask, 1, 0)], axis=1)
149121
shifts = np.stack([shift0, shift1, shift2, shift3], axis=1)
150122

151123
coordinates_expanded = coordinates[:, None, :]
152124
new_coordinates = coordinates_expanded + shifts
153125
new_coordinates = new_coordinates.reshape(-1, 2)
154126

155-
new_mask = np.repeat(self.mask, 4)
156-
157127
return CoordinateArrayTriangles(
158128
coordinates=new_coordinates,
159129
side_length=self.side_length / 2,
160130
flipped=True,
161131
y_offset=self.y_offset + -0.25 * HEIGHT_FACTOR * self.side_length,
162132
x_offset=self.x_offset,
163-
mask=new_mask,
164133
)
165134

166135
def neighborhood(self) -> "CoordinateArrayTriangles":
@@ -169,17 +138,16 @@ def neighborhood(self) -> "CoordinateArrayTriangles":
169138
170139
Ensures that the new triangles are unique and adjusts the mask accordingly.
171140
"""
172-
coordinates = self.coordinates.astype(np.int32)
141+
coordinates = self.coordinates
173142
flip_mask = self.flip_mask
174-
mask = self.mask
175143

176-
shift0 = np.zeros((coordinates.shape[0], 2), dtype=np.int32)
177-
shift1 = np.tile(np.array([1, 0], dtype=np.int32), (coordinates.shape[0], 1))
178-
shift2 = np.tile(np.array([-1, 0], dtype=np.int32), (coordinates.shape[0], 1))
144+
shift0 = np.zeros((coordinates.shape[0], 2))
145+
shift1 = np.tile(np.array([1, 0]), (coordinates.shape[0], 1))
146+
shift2 = np.tile(np.array([-1, 0]), (coordinates.shape[0], 1))
179147
shift3 = np.where(
180148
flip_mask[:, None],
181-
np.tile(np.array([0, 1], dtype=np.int32), (coordinates.shape[0], 1)),
182-
np.tile(np.array([0, -1], dtype=np.int32), (coordinates.shape[0], 1)),
149+
np.tile(np.array([0, 1]), (coordinates.shape[0], 1)),
150+
np.tile(np.array([0, -1]), (coordinates.shape[0], 1)),
183151
)
184152

185153
shifts = np.stack([shift0, shift1, shift2, shift3], axis=1)
@@ -188,24 +156,17 @@ def neighborhood(self) -> "CoordinateArrayTriangles":
188156
new_coordinates = coordinates_expanded + shifts
189157
new_coordinates = new_coordinates.reshape(-1, 2)
190158

191-
new_mask_flat = np.repeat(mask, 4)
192159
expected_size = 4 * coordinates.shape[0]
193-
fill_value = np.iinfo(np.int32).max
194160
unique_coords, indices = np.unique(
195161
new_coordinates,
196162
axis=0,
197163
size=expected_size,
198-
fill_value=fill_value,
164+
fill_value=np.nan,
199165
return_index=True,
200166
)
201-
new_mask = np.ones(expected_size, dtype=bool)
202-
valid_indices = ~(unique_coords == fill_value).all(axis=1)
203-
new_mask = new_mask.at[valid_indices].set(new_mask_flat[indices[valid_indices]])
204-
unique_coords = unique_coords.astype(np.int32)
205167

206168
return CoordinateArrayTriangles(
207169
coordinates=unique_coords,
208-
mask=new_mask,
209170
side_length=self.side_length,
210171
flipped=self.flipped,
211172
y_offset=self.y_offset,
@@ -277,6 +238,3 @@ def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles":
277238

278239
def containing_indices(self, shape: np.ndarray) -> np.ndarray:
279240
raise NotImplementedError("JAX ArrayTriangles are used for this method.")
280-
281-
def __len__(self):
282-
return self.numpy.count_nonzero(~self.mask)

test_autoarray/structures/triangles/coordinate/test_coordinate_jax.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def one_triangle():
2222
return CoordinateArrayTriangles(
2323
coordinates=np.array([[0, 0]]),
2424
side_length=1.0,
25-
mask=np.array([False]),
2625
)
2726

2827

0 commit comments

Comments
 (0)