88)
99
1010
11+ # @torch.jit.script
12+ def _merge_sorted (
13+ old_nd : Tensor ,
14+ new_nd : Tensor ,
15+ old_values : Tensor ,
16+ new_values : Tensor ,
17+ insertion_positions : Tensor ,
18+ ) -> tuple [Tensor , Tensor ]:
19+ """Merges two sorted sequences of sparse indices/values and return them in coalesced
20+ order.
21+
22+ All input args must be on the same device.
23+
24+ Args:
25+ old_nd (Tensor): [S x n_old] tensor of N-D indices
26+ new_nd (Tensor): [S x n_new] tensor of N-D indices
27+ old_values (Tensor): [n_old, ...] tensor of values
28+ new_values (Tensor): [n_new, ...] tensor of values
29+ insertion_positions (Tensor): [n_new] tensor of insertion positions in
30+ old_linear for each element in new_linear
31+
32+ Returns:
33+ merged_nd: [S, n_old+n_new]
34+ merged_values: [n_old+n_new, ...]
35+ """
36+ device = old_nd .device
37+ n_old , n_new = old_nd .size (1 ), new_nd .size (1 )
38+ n_total = n_old + n_new
39+
40+ # determine final positions of new values
41+ # account for previous insertions to get final positions of new rows
42+ new_positions = insertion_positions + torch .arange (
43+ n_new , device = device , dtype = insertion_positions .dtype
44+ )
45+
46+ # determine final positions of old values by counting how many new values are
47+ # inserted before each old value
48+ hist = torch .bincount (insertion_positions , minlength = n_old + 1 )
49+ old_shift = torch .cumsum (hist [:- 1 ], 0 )
50+ old_positions = torch .arange (n_old , device = device ) + old_shift
51+
52+ # allocate output tensors
53+ merged_nd = old_nd .new_empty (old_nd .size (0 ), n_total )
54+ merged_values = old_values .new_empty ((n_total ,) + old_values .shape [1 :])
55+
56+ # insert values
57+ merged_nd [:, old_positions ] = old_nd
58+ merged_nd [:, new_positions ] = new_nd
59+ merged_values [old_positions ] = old_values
60+ merged_values [new_positions ] = new_values
61+
62+ return merged_nd , merged_values
63+
64+
65+ # @torch.jit.script
1166def scatter_to_sparse_tensor (
1267 sparse_tensor : Tensor ,
1368 index_tensor : Tensor ,
@@ -21,12 +76,13 @@ def scatter_to_sparse_tensor(
2176 sparse tensor.
2277
2378 Args:
24- sparse_tensor (Tensor): Sparse tensor of dimension ..., M; where ... are
25- S leading sparse dimensions and M is the dense dimension
26- index_tensor (Tensor): Long tensor of dimension ..., S; where ... are
27- leading batch dimensions.
28- values (Tensor): Tensor of dimension ..., M; where ... are leading
29- batch dimensions and M is the dense dimension
79+ sparse_tensor (Tensor): Sparse tensor of dimension
80+ [s0, s1, s2, ..., d0, d1, d2, ...]; where s0, s1, ... are
81+ S leading sparse dimensions and d0, d1, d2, ... are D dense dimensions.
82+ index_tensor (Tensor): Long tensor of dimension [b0, b1, b2, ..., S]; where
83+ b0, b1, b2, ... are B leading batch dimensions.
84+ values (Tensor): Tensor of dimension [b0, b1, b2, ... d0, d1, d2, ...], where
85+ dimensions are as above.
3086 check_all_specified (bool): If True, this function will throw a ValueError
3187 if any of the indices specified in index_tensor are not already present
3288 in sparse_tensor. Default: False.
@@ -105,8 +161,15 @@ def scatter_to_sparse_tensor(
105161 index_tensor = torch .cat (index_tensor .unbind ())
106162 values = torch .cat (values .unbind ())
107163
108- assert index_tensor .shape [:- 1 ] == values .shape [:- 1 ]
109- assert sparse_tensor .dense_dim () == values .ndim - 1
164+ dense_dim = sparse_tensor .dense_dim ()
165+ sparse_dim = sparse_tensor .sparse_dim ()
166+ values_batch_dims = values .shape [:- dense_dim ] if dense_dim else values .shape
167+ if index_tensor .shape [:- 1 ] != values_batch_dims :
168+ raise ValueError (
169+ "Expected matching batch dims for `index_tensor` and `values`, but got "
170+ f"batch dims { index_tensor .shape [:- 1 ]} and "
171+ f"{ values_batch_dims } , respectively."
172+ )
110173
111174 sparse_tensor = sparse_tensor .coalesce ()
112175 sparse_tensor_values = sparse_tensor .values ()
@@ -151,19 +214,24 @@ def scatter_to_sparse_tensor(
151214 new_values = values [~ is_specified_mask ]
152215
153216 # Get sparse shape info for linearization
154- sparse_dim = sparse_tensor .sparse_dim ()
155217 sparse_sizes = torch .tensor (
156218 sparse_tensor .shape [:sparse_dim ], device = sparse_tensor .device
157219 )
158220
221+ if (new_nd_indices >= sparse_sizes .unsqueeze (0 )).any ():
222+ raise ValueError (
223+ "`index_tensor` has indices that are out of bounds of the original "
224+ f"sparse tensor's sparse shape ({ sparse_sizes } )."
225+ )
226+
159227 # Obtain linearized versions of all indices for sorting
160228 old_indices_nd = sparse_tensor .indices ()
161229 linear_offsets = _make_linear_offsets (sparse_sizes )
162230 new_indices_lin : Tensor = (new_nd_indices * linear_offsets ).sum (- 1 )
163231
164232 # Find duplicate linear indices
165- unique_new_indices_lin , inverse = new_indices_lin .unique (
166- sorted = True , return_inverse = True
233+ unique_new_indices_lin , inverse = torch .unique (
234+ new_indices_lin , sorted = True , return_inverse = True
167235 )
168236
169237 # Use inverse of indices unique to write to new values tensor and tensor of
@@ -198,56 +266,3 @@ def scatter_to_sparse_tensor(
198266 device = sparse_tensor .device ,
199267 is_coalesced = True ,
200268 )
201-
202-
203- def _merge_sorted (
204- old_nd : Tensor ,
205- new_nd : Tensor ,
206- old_values : Tensor ,
207- new_values : Tensor ,
208- insertion_positions : Tensor ,
209- ) -> tuple [Tensor , Tensor ]:
210- """Merges two sorted sequences of sparse indices/values and return them in coalesced
211- order.
212-
213- All input args must be on the same device.
214-
215- Args:
216- old_nd (Tensor): [S x n_old] tensor of N-D indices
217- new_nd (Tensor): [S x n_new] tensor of N-D indices
218- old_values (Tensor): [n_old, ...] tensor of values
219- new_values (Tensor): [n_new, ...] tensor of values
220- insertion_positions (Tensor): [n_new] tensor of insertion positions in
221- old_linear for each element in new_linear
222-
223- Returns:
224- merged_nd: [S, n_old+n_new]
225- merged_values: [n_old+n_new, ...]
226- """
227- device = old_nd .device
228- n_old , n_new = old_nd .size (1 ), new_nd .size (1 )
229- n_total = n_old + n_new
230-
231- # determine final positions of new values
232- # account for previous insertions to get final positions of new rows
233- new_positions = insertion_positions + torch .arange (
234- n_new , device = device , dtype = insertion_positions .dtype
235- )
236-
237- # determine final positions of old values by counting how many new values are
238- # inserted before each old value
239- hist = torch .bincount (insertion_positions , minlength = n_old + 1 )
240- old_shift = torch .cumsum (hist [:- 1 ], 0 )
241- old_positions = torch .arange (n_old , device = device ) + old_shift
242-
243- # allocate output tensors
244- merged_nd = old_nd .new_empty (old_nd .size (0 ), n_total )
245- merged_values = old_values .new_empty ((n_total ,) + old_values .shape [1 :])
246-
247- # insert values
248- merged_nd [:, old_positions ] = old_nd
249- merged_nd [:, new_positions ] = new_nd
250- merged_values [old_positions ] = old_values
251- merged_values [new_positions ] = new_values
252-
253- return merged_nd , merged_values
0 commit comments