Skip to content

Commit

Permalink
Simplify create_diagonal_array
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Feb 10, 2025
1 parent 5dd8a8a commit 8e5542b
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tjax/_src/math_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,5 @@ def create_diagonal_array(m: T) -> T:
retval = xp.zeros((*pre, n ** 2), dtype=m.dtype)
for index in np.ndindex(*pre):
target_index = (*index, slice(None, None, n + 1))
xpx.at(retval)[target_index].set(m[*index, :])
retval = xpx.at(retval)[target_index].set(m[*index, :])
return xp.reshape(retval, (*m.shape, n))

0 comments on commit 8e5542b

Please sign in to comment.