-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
56 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,12 +2,12 @@ | |
from dctkit.dec import cochain as C | ||
from jax import Array, vmap | ||
from functools import partial | ||
from typing import Callable, Dict | ||
from typing import Callable, Dict, Optional | ||
|
||
|
||
def flat(c: C.CochainP0 | C.CochainD0, weights: Array, edges: C.CochainP1V | | ||
C.CochainD1V, I: Callable = None, | ||
I_args: Dict = {}) -> C.CochainP1 | C.CochainD1: | ||
C.CochainD1V, weighted_I: Optional[Callable] = None, | ||
weighted_I_args: Optional[Dict] = {}) -> C.CochainP1 | C.CochainD1: | ||
"""Applies the flat to a vector/tensor-valued cochain representing a discrete | ||
vector/tensor field to obtain a scalar-valued cochain over primal/dual edges. | ||
|
@@ -21,15 +21,19 @@ def flat(c: C.CochainP0 | C.CochainD0, weights: Array, edges: C.CochainP1V | | |
interpolation scheme chosen for the input discrete vector/tensor field. | ||
edges: vector-valued cochain collecting the primal/dual edges over which the | ||
discrete vector/tensor field should be integrated. | ||
weighted_I: interpolation function (callable) taking in input the cochain c | ||
and providing in output a 1-cochain of the same type (primal/dual). If | ||
it is None, then an interpolation function is built as W^[email protected]. | ||
weighted_I_args: additional keyword arguments for weighted_I | ||
Returns: | ||
a primal/dual scalar/vector-valued cochain defined over primal/dual edges. | ||
""" | ||
if I is None: | ||
# contract over the simplices of the input cochain (last axis of weights, first axis | ||
# of input cochain coeffs) | ||
def I(x): return jnp.tensordot(weights.T, x.coeffs, axes=1) | ||
I_args = {} | ||
weighted_v = I(c, **I_args) | ||
if weighted_I is None: | ||
# contract over the simplices of the input cochain (last axis of weights, | ||
# first axis of input cochain coeffs) | ||
def weighted_I(x): return jnp.tensordot(weights.T, x.coeffs, axes=1) | ||
weighted_I_args = {} | ||
weighted_v = weighted_I(c, **weighted_I_args) | ||
# contract input vector/tensors with edge vectors (last indices of both | ||
# coefficient matrices), for each target primal/dual edge | ||
contract = partial(jnp.tensordot, axes=([-1,], [-1,])) | ||
|