Skip to content

Commit

Permalink
Fixed convolution operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
Smantii committed Mar 25, 2024
1 parent e2c45d9 commit 4349c94
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/dctkit/dec/cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,16 +413,16 @@ def convolution(c: Cochain, kernel: Cochain, kernel_window: float):
buffer = jnp.empty((n, n*2 - 1))

# generate a wider array that we want a slice into
buffer = buffer.at[:, :n].set(kernel.coeffs[:n])
buffer = buffer.at[:, n:].set(kernel.coeffs[:n-1])
buffer = buffer.at[:, :n].set(kernel.coeffs[:n].T)
buffer = buffer.at[:, n:].set(kernel.coeffs[:n-1].T)

rolled = buffer.reshape(-1)[n-1:-1].reshape(n, -1)
K_full_roll = jnp.roll(rolled[:, :n], shift=1, axis=0)
K_non_zero = K_full_roll[:n - kernel_window + 1]
K = K.at[:n - kernel_window + 1, :].set(K_non_zero)

kernel_coch = Cochain(c.dim, c.is_primal, c.complex, K)
star_kernel = transpose(star(transpose(kernel_coch)))
star_kernel = star(kernel_coch)
conv = Cochain(c.dim, c.is_primal, c.complex, star_kernel.coeffs@c.coeffs)
return conv

Expand Down
17 changes: 17 additions & 0 deletions tests/test_cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,23 @@ def test_codifferential(setup_test):
assert np.allclose(inner_all[i], cod_inner_all[i])


def test_convolution(setup_test):
mesh_1, _ = util.generate_line_mesh(11, 1.)
S_1 = util.build_complex_from_mesh(mesh_1)
S_1.get_hodge_star()
n_1 = S_1.S[1].shape[0]
vD0 = np.arange(n_1, dtype=dctkit.float_dtype)
cD0 = C.CochainD0(complex=S_1, coeffs=vD0)
kernel = 4*np.arange(1, 4)
kernel_coeffs = np.zeros_like(vD0)
kernel_coeffs[:len(kernel)] = kernel
kernel_coch = C.CochainD0(complex=S_1, coeffs=kernel_coeffs)
conv = C.convolution(cD0, kernel_coch, len(kernel))
conv_true = np.array([3.2, 5.6, 8., 10.4, 12.8, 15.2,
17.6, 20., 0., 0.]).reshape(-1, 1)
assert np.allclose(conv.coeffs, conv_true)


def test_coboundary_closure(setup_test):
mesh_2, _ = util.generate_square_mesh(1.0)
S_2 = util.build_complex_from_mesh(mesh_2, is_well_centered=False)
Expand Down

0 comments on commit 4349c94

Please sign in to comment.