1515
1616@register_pytree_node_class
1717class CoordinateArrayTriangles (AbstractCoordinateArray ):
18- def __init__ (
19- self ,
20- coordinates : np .ndarray ,
21- mask : np .ndarray ,
22- side_length : float = 1.0 ,
23- x_offset : float = 0.0 ,
24- y_offset : float = 0.0 ,
25- flipped : bool = False ,
26- ):
27- super ().__init__ (
28- coordinates = coordinates ,
29- side_length = side_length ,
30- x_offset = x_offset ,
31- y_offset = y_offset ,
32- flipped = flipped ,
33- )
34- self .mask = mask
35-
3618 @property
3719 def numpy (self ):
3820 return jax .numpy
@@ -57,19 +39,13 @@ def for_limits_and_scale(
5739 coordinates .append ([x , y ])
5840
5941 return cls (
60- coordinates = np .array (coordinates , dtype = np . int32 ),
42+ coordinates = np .array (coordinates ),
6143 side_length = scale ,
62- mask = np .full (
63- len (coordinates ),
64- False ,
65- dtype = bool ,
66- ),
6744 )
6845
6946 def tree_flatten (self ):
7047 return (
7148 self .coordinates ,
72- self .mask ,
7349 self .side_length ,
7450 self .x_offset ,
7551 self .y_offset ,
@@ -102,7 +78,7 @@ def centres(self) -> np.ndarray:
10278 centres = self .scaling_factors * self .coordinates + np .array (
10379 [self .x_offset , self .y_offset ]
10480 )
105- return self . numpy . where ( self . mask [:, None ], np . nan , centres )
81+ return centres
10682
10783 @cached_property
10884 def flip_mask (self ) -> np .ndarray :
@@ -138,29 +114,22 @@ def up_sample(self) -> "CoordinateArrayTriangles":
138114
139115 n = coordinates .shape [0 ]
140116
141- shift0 = np .zeros ((n , 2 ), dtype = np .int32 )
142- shift3 = np .tile (np .array ([0 , 1 ], dtype = np .int32 ), (n , 1 ))
143- shift1 = np .stack (
144- [np .ones (n , dtype = np .int32 ), np .where (flip_mask , 1 , 0 )], axis = 1
145- )
146- shift2 = np .stack (
147- [- np .ones (n , dtype = np .int32 ), np .where (flip_mask , 1 , 0 )], axis = 1
148- )
117+ shift0 = np .zeros ((n , 2 ))
118+ shift3 = np .tile (np .array ([0 , 1 ]), (n , 1 ))
119+ shift1 = np .stack ([np .ones (n ), np .where (flip_mask , 1 , 0 )], axis = 1 )
120+ shift2 = np .stack ([- np .ones (n ), np .where (flip_mask , 1 , 0 )], axis = 1 )
149121 shifts = np .stack ([shift0 , shift1 , shift2 , shift3 ], axis = 1 )
150122
151123 coordinates_expanded = coordinates [:, None , :]
152124 new_coordinates = coordinates_expanded + shifts
153125 new_coordinates = new_coordinates .reshape (- 1 , 2 )
154126
155- new_mask = np .repeat (self .mask , 4 )
156-
157127 return CoordinateArrayTriangles (
158128 coordinates = new_coordinates ,
159129 side_length = self .side_length / 2 ,
160130 flipped = True ,
161131 y_offset = self .y_offset + - 0.25 * HEIGHT_FACTOR * self .side_length ,
162132 x_offset = self .x_offset ,
163- mask = new_mask ,
164133 )
165134
166135 def neighborhood (self ) -> "CoordinateArrayTriangles" :
@@ -169,17 +138,16 @@ def neighborhood(self) -> "CoordinateArrayTriangles":
169138
170139 Ensures that the new triangles are unique and adjusts the mask accordingly.
171140 """
172- coordinates = self .coordinates . astype ( np . int32 )
141+ coordinates = self .coordinates
173142 flip_mask = self .flip_mask
174- mask = self .mask
175143
176- shift0 = np .zeros ((coordinates .shape [0 ], 2 ), dtype = np . int32 )
177- shift1 = np .tile (np .array ([1 , 0 ], dtype = np . int32 ), (coordinates .shape [0 ], 1 ))
178- shift2 = np .tile (np .array ([- 1 , 0 ], dtype = np . int32 ), (coordinates .shape [0 ], 1 ))
144+ shift0 = np .zeros ((coordinates .shape [0 ], 2 ))
145+ shift1 = np .tile (np .array ([1 , 0 ]), (coordinates .shape [0 ], 1 ))
146+ shift2 = np .tile (np .array ([- 1 , 0 ]), (coordinates .shape [0 ], 1 ))
179147 shift3 = np .where (
180148 flip_mask [:, None ],
181- np .tile (np .array ([0 , 1 ], dtype = np . int32 ), (coordinates .shape [0 ], 1 )),
182- np .tile (np .array ([0 , - 1 ], dtype = np . int32 ), (coordinates .shape [0 ], 1 )),
149+ np .tile (np .array ([0 , 1 ]), (coordinates .shape [0 ], 1 )),
150+ np .tile (np .array ([0 , - 1 ]), (coordinates .shape [0 ], 1 )),
183151 )
184152
185153 shifts = np .stack ([shift0 , shift1 , shift2 , shift3 ], axis = 1 )
@@ -188,24 +156,17 @@ def neighborhood(self) -> "CoordinateArrayTriangles":
188156 new_coordinates = coordinates_expanded + shifts
189157 new_coordinates = new_coordinates .reshape (- 1 , 2 )
190158
191- new_mask_flat = np .repeat (mask , 4 )
192159 expected_size = 4 * coordinates .shape [0 ]
193- fill_value = np .iinfo (np .int32 ).max
194160 unique_coords , indices = np .unique (
195161 new_coordinates ,
196162 axis = 0 ,
197163 size = expected_size ,
198- fill_value = fill_value ,
164+ fill_value = np . nan ,
199165 return_index = True ,
200166 )
201- new_mask = np .ones (expected_size , dtype = bool )
202- valid_indices = ~ (unique_coords == fill_value ).all (axis = 1 )
203- new_mask = new_mask .at [valid_indices ].set (new_mask_flat [indices [valid_indices ]])
204- unique_coords = unique_coords .astype (np .int32 )
205167
206168 return CoordinateArrayTriangles (
207169 coordinates = unique_coords ,
208- mask = new_mask ,
209170 side_length = self .side_length ,
210171 flipped = self .flipped ,
211172 y_offset = self .y_offset ,
@@ -277,6 +238,3 @@ def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles":
277238
278239 def containing_indices (self , shape : np .ndarray ) -> np .ndarray :
279240 raise NotImplementedError ("JAX ArrayTriangles are used for this method." )
280-
281- def __len__ (self ):
282- return self .numpy .count_nonzero (~ self .mask )
0 commit comments