Skip to content

Commit

Permalink
Added docs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Smantii committed Mar 25, 2024
1 parent 4349c94 commit 6d5bc47
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 27 deletions.
61 changes: 43 additions & 18 deletions src/dctkit/dec/cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,20 +396,22 @@ def sym(c: Cochain) -> Cochain:
return scalar_mul(add(c, transpose(c)), 0.5)


def convolution(c: Cochain, kernel: Cochain, kernel_window: float):
# FIXME: fix the docs
# only implemented for scalar valued cochains
# NOTE: both c and kernel must be 0-cochains
n = len(c.coeffs)
def convolution(c: Cochain, kernel: Cochain, kernel_window: float) -> Cochain:
""" Compute the convolution between two scalar 0-cochains.
Args:
c: a scalar 0-cochain.
kernel: the scalar 0-cochain kernel.
kernel_window: the kernel window.
Returns:
the convolution rho*kernel.
"""
# we build a kernel matrix K by rolling the kernel vector k + 1 times, where
# k is the number of the kernel final zeros. In this way we can express the
# convolution between c and k as (SK^T)^T@c_coeffs where S is the hodge star
# k is the kernel window. In this way we can express the
# convolution between c and k as SK @c_coeffs where S is the hodge star
n = len(c.coeffs)
K = jnp.zeros((n, n), dtype=dt.float_dtype)
# define first row of kernel matrix
# FIXME: add docs on this
# FIXME: continue from here
# kernel_row = kernel.coeffs
buffer = jnp.empty((n, n*2 - 1))

# generate a wider array that we want a slice into
Expand All @@ -420,25 +422,48 @@ def convolution(c: Cochain, kernel: Cochain, kernel_window: float):
K_full_roll = jnp.roll(rolled[:, :n], shift=1, axis=0)
K_non_zero = K_full_roll[:n - kernel_window + 1]
K = K.at[:n - kernel_window + 1, :].set(K_non_zero)

kernel_coch = Cochain(c.dim, c.is_primal, c.complex, K)

# apply hodge star
star_kernel = star(kernel_coch)
conv = Cochain(c.dim, c.is_primal, c.complex, star_kernel.coeffs@c.coeffs)
return conv


def constant_sub(k: float, c: Cochain) -> Cochain:
"""Compute the cochain subtraction between a constant cochain and another cochain.
Args:
k: a constant.
c: a cochain.
Returns:
the resulting subtraction
"""
return Cochain(c.dim, c.is_primal, c.complex, k - c.coeffs)


def abs(c: Cochain):
def abs(c: Cochain) -> Cochain:
""" Compute the absolute value of a cochain.
Args:
c: a cochain.
Returns:
its absolute value.
"""
return Cochain(c.dim, c.is_primal, c.complex, jnp.abs(c.coeffs))


def maximum(c_1: Cochain, c_2: Cochain):
return Cochain(c_1.dim, c_1.is_primal, c_1.complex,
jnp.maximum(c_1.coeffs, c_2.coeffs))
def maximum(c_1: Cochain, c_2: Cochain) -> Cochain:
""" Compute the component-wise maximum between two cochains.
Args:
c_1: a cochain.
c_2: a cochain.
def inv(c: Cochain):
return Cochain(c.dim, c.is_primal, c.complex, 1/c.coeffs)
Returns:
the component-wise maximum
"""
return Cochain(c_1.dim, c_1.is_primal, c_1.complex,
jnp.maximum(c_1.coeffs, c_2.coeffs))
22 changes: 13 additions & 9 deletions src/dctkit/dec/flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from dctkit.dec import cochain as C
from jax import Array, vmap
from functools import partial
from typing import Callable, Dict
from typing import Callable, Dict, Optional


def flat(c: C.CochainP0 | C.CochainD0, weights: Array, edges: C.CochainP1V |
C.CochainD1V, I: Callable = None,
I_args: Dict = {}) -> C.CochainP1 | C.CochainD1:
C.CochainD1V, weighted_I: Optional[Callable] = None,
weighted_I_args: Optional[Dict] = {}) -> C.CochainP1 | C.CochainD1:
"""Applies the flat to a vector/tensor-valued cochain representing a discrete
vector/tensor field to obtain a scalar-valued cochain over primal/dual edges.
Expand All @@ -21,15 +21,19 @@ def flat(c: C.CochainP0 | C.CochainD0, weights: Array, edges: C.CochainP1V |
interpolation scheme chosen for the input discrete vector/tensor field.
edges: vector-valued cochain collecting the primal/dual edges over which the
discrete vector/tensor field should be integrated.
weighted_I: interpolation function (callable) taking in input the cochain c
and providing in output a 1-cochain of the same type (primal/dual). If
it is None, then an interpolation function is built as W^[email protected].
weighted_I_args: additional keyword arguments for weighted_I
Returns:
a primal/dual scalar/vector-valued cochain defined over primal/dual edges.
"""
if I is None:
# contract over the simplices of the input cochain (last axis of weights, first axis
# of input cochain coeffs)
def I(x): return jnp.tensordot(weights.T, x.coeffs, axes=1)
I_args = {}
weighted_v = I(c, **I_args)
if weighted_I is None:
# contract over the simplices of the input cochain (last axis of weights,
# first axis of input cochain coeffs)
def weighted_I(x): return jnp.tensordot(weights.T, x.coeffs, axes=1)
weighted_I_args = {}
weighted_v = weighted_I(c, **weighted_I_args)
# contract input vector/tensors with edge vectors (last indices of both
# coefficient matrices), for each target primal/dual edge
contract = partial(jnp.tensordot, axes=([-1,], [-1,]))
Expand Down

0 comments on commit 6d5bc47

Please sign in to comment.