From 6d5bc47b53d256af7d234350d80e02649fbb3503 Mon Sep 17 00:00:00 2001 From: Simone Manti Date: Mon, 25 Mar 2024 14:07:09 +0100 Subject: [PATCH] Added docs. --- src/dctkit/dec/cochain.py | 61 +++++++++++++++++++++++++++------------ src/dctkit/dec/flat.py | 22 ++++++++------ 2 files changed, 56 insertions(+), 27 deletions(-) diff --git a/src/dctkit/dec/cochain.py b/src/dctkit/dec/cochain.py index fe09cdc..510ed96 100644 --- a/src/dctkit/dec/cochain.py +++ b/src/dctkit/dec/cochain.py @@ -396,20 +396,22 @@ def sym(c: Cochain) -> Cochain: return scalar_mul(add(c, transpose(c)), 0.5) -def convolution(c: Cochain, kernel: Cochain, kernel_window: float): - # FIXME: fix the docs - # only implemented for scalar valued cochains - # NOTE: both c and kernel must be 0-cochains - n = len(c.coeffs) +def convolution(c: Cochain, kernel: Cochain, kernel_window: float) -> Cochain: + """ Compute the convolution between two scalar 0-cochains. + + Args: + c: a scalar 0-cochain. + kernel: the scalar 0-cochain kernel. + kernel_window: the kernel window. + Returns: + the convolution rho*kernel. + """ # we build a kernel matrix K by rolling the kernel vector k + 1 times, where - # k is the number of the kernel final zeros. In this way we can express the - # convolution between c and k as (SK^T)^T@c_coeffs where S is the hodge star + # k is the kernel window. In this way we can express the + # convolution between c and k as SK @c_coeffs where S is the hodge star + n = len(c.coeffs) K = jnp.zeros((n, n), dtype=dt.float_dtype) - # define first row of kernel matrix - # FIXME: add docs on this - # FIXME: continue from here - # kernel_row = kernel.coeffs buffer = jnp.empty((n, n*2 - 1)) # generate a wider array that we want a slice into @@ -420,25 +422,48 @@ def convolution(c: Cochain, kernel: Cochain, kernel_window: float): K_full_roll = jnp.roll(rolled[:, :n], shift=1, axis=0) K_non_zero = K_full_roll[:n - kernel_window + 1] K = K.at[:n - kernel_window + 1, :].set(K_non_zero) - kernel_coch = Cochain(c.dim, c.is_primal, c.complex, K) + + # apply hodge star star_kernel = star(kernel_coch) conv = Cochain(c.dim, c.is_primal, c.complex, star_kernel.coeffs@c.coeffs) return conv def constant_sub(k: float, c: Cochain) -> Cochain: + """Compute the cochain subtraction between a constant cochain and another cochain. + + Args: + k: a constant. + c: a cochain. + + Returns: + the resulting subtraction + """ return Cochain(c.dim, c.is_primal, c.complex, k - c.coeffs) -def abs(c: Cochain): +def abs(c: Cochain) -> Cochain: + """ Compute the absolute value of a cochain. + + Args: + c: a cochain. + + Returns: + its absolute value. + """ return Cochain(c.dim, c.is_primal, c.complex, jnp.abs(c.coeffs)) -def maximum(c_1: Cochain, c_2: Cochain): - return Cochain(c_1.dim, c_1.is_primal, c_1.complex, - jnp.maximum(c_1.coeffs, c_2.coeffs)) +def maximum(c_1: Cochain, c_2: Cochain) -> Cochain: + """ Compute the component-wise maximum between two cochains. + Args: + c_1: a cochain. + c_2: a cochain. -def inv(c: Cochain): - return Cochain(c.dim, c.is_primal, c.complex, 1/c.coeffs) + Returns: + the component-wise maximum + """ + return Cochain(c_1.dim, c_1.is_primal, c_1.complex, + jnp.maximum(c_1.coeffs, c_2.coeffs)) diff --git a/src/dctkit/dec/flat.py b/src/dctkit/dec/flat.py index 11d73b0..045478f 100644 --- a/src/dctkit/dec/flat.py +++ b/src/dctkit/dec/flat.py @@ -2,12 +2,12 @@ from dctkit.dec import cochain as C from jax import Array, vmap from functools import partial -from typing import Callable, Dict +from typing import Callable, Dict, Optional def flat(c: C.CochainP0 | C.CochainD0, weights: Array, edges: C.CochainP1V | - C.CochainD1V, I: Callable = None, - I_args: Dict = {}) -> C.CochainP1 | C.CochainD1: + C.CochainD1V, weighted_I: Optional[Callable] = None, + weighted_I_args: Optional[Dict] = {}) -> C.CochainP1 | C.CochainD1: """Applies the flat to a vector/tensor-valued cochain representing a discrete vector/tensor field to obtain a scalar-valued cochain over primal/dual edges. @@ -21,15 +21,19 @@ def flat(c: C.CochainP0 | C.CochainD0, weights: Array, edges: C.CochainP1V | interpolation scheme chosen for the input discrete vector/tensor field. edges: vector-valued cochain collecting the primal/dual edges over which the discrete vector/tensor field should be integrated. + weighted_I: interpolation function (callable) taking in input the cochain c + and providing in output a 1-cochain of the same type (primal/dual). If + it is None, then an interpolation function is built as W^T@c.coeffs. + weighted_I_args: additional keyword arguments for weighted_I Returns: a primal/dual scalar/vector-valued cochain defined over primal/dual edges. """ - if I is None: - # contract over the simplices of the input cochain (last axis of weights, first axis - # of input cochain coeffs) - def I(x): return jnp.tensordot(weights.T, x.coeffs, axes=1) - I_args = {} - weighted_v = I(c, **I_args) + if weighted_I is None: + # contract over the simplices of the input cochain (last axis of weights, + # first axis of input cochain coeffs) + def weighted_I(x): return jnp.tensordot(weights.T, x.coeffs, axes=1) + weighted_I_args = {} + weighted_v = weighted_I(c, **weighted_I_args) # contract input vector/tensors with edge vectors (last indices of both # coefficient matrices), for each target primal/dual edge contract = partial(jnp.tensordot, axes=([-1,], [-1,]))