Skip to content

Commit

Permalink
Cleanup of flat functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
alucantonio committed Mar 1, 2024
1 parent 59fd38c commit 85f2e0f
Show file tree
Hide file tree
Showing 11 changed files with 81 additions and 168 deletions.
2 changes: 1 addition & 1 deletion src/dctkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import enum
import jax.numpy as jnp
from jax.config import config as cfg
from jax import config as cfg

if sys.version_info[:2] >= (3, 8):
# TODO: Import directly (no need for conditional) when `python_requires = >= 3.8`
Expand Down
2 changes: 1 addition & 1 deletion src/dctkit/dec/cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def star(c: Cochain) -> Cochain:
return star_c


def inner_product(c1: Cochain, c2: Cochain) -> Array:
def inner(c1: Cochain, c2: Cochain) -> Array:
"""Computes the inner product between two cochains.
Args:
Expand Down
102 changes: 14 additions & 88 deletions src/dctkit/dec/vector.py → src/dctkit/dec/flat.py
Original file line number Diff line number Diff line change
@@ -1,115 +1,64 @@
import jax.numpy as jnp
import numpy.typing as npt
import dctkit as dt
from dctkit.dec import cochain as C
from dctkit.mesh import simplex as spx
from jax import Array
import numpy as np


class ScalarField():
def __init__(self, S: spx.SimplicialComplex, field: callable):
self.S = S
self.field = field


class DiscreteTensorField():
"""Discrete tensor fields class.
Args:
S: the simplicial complex where the discrete vector field is defined.
is_primal: True if the discrete vector field is primal, False otherwise.
coeffs: array of the coefficients of the discrete vector fields.
rank: rank of the tensor.
"""

def __init__(self, S: spx.SimplicialComplex, is_primal: bool,
coeffs: npt.NDArray | Array, rank: int):
self.S = S
self.is_primal = is_primal
self.coeffs = coeffs
self.rank = rank


class DiscreteVectorField(DiscreteTensorField):
"""Inherited class for discrete vector fields."""

def __init__(self, S: spx.SimplicialComplex, is_primal: bool,
coeffs: npt.NDArray | Array):
super().__init__(S, is_primal, coeffs, 1)


class DiscreteVectorFieldD(DiscreteVectorField):
"""Inherited class for dual discrete vector fields."""

def __init__(self, S: spx.SimplicialComplex, coeffs: npt.NDArray | Array):
super().__init__(S, False, coeffs)


class DiscreteTensorFieldD(DiscreteTensorField):
"""Inherited class for dual discrete tensor fields."""

def __init__(self, S: spx.SimplicialComplex, coeffs: npt.NDArray | Array,
rank: int):
super().__init__(S, False, coeffs, rank)


def flat_DPD(v: DiscreteTensorFieldD) -> C.CochainD1:
def flat_DPD(c: C.CochainD0V | C.CochainD0T) -> C.CochainD1:
"""Implements the flat DPD operator for dual discrete vector fields.
Args:
v: a dual discrete vector field.
Returns:
the dual 1-cochain resulting from the application of the flat operator.
"""
dedges = v.S.dual_edges_vectors[:, :v.coeffs.shape[0]]
flat_matrix = v.S.flat_DPD_weights
dedges = c.complex.dual_edges_vectors[:, :c.coeffs.shape[0]]
flat_matrix = c.complex.flat_DPD_weights
# multiply weights of each dual edge by the vectors associated to the dual nodes
# belonging to the edge
weighted_v = v.coeffs @ flat_matrix
if v.rank == 1:
weighted_v = c.coeffs @ flat_matrix
if c.coeffs.ndim == 2:
# vector field case
# perform dot product row-wise with the edge vectors
# of the dual edges (see definition of DPD in Hirani, pag. 54).
weighted_v_T = weighted_v.T
coch_coeffs = jnp.einsum("ij, ij -> i", weighted_v_T, dedges)
elif v.rank == 2:
elif c.coeffs.ndim == 3:
# tensor field case
# apply each matrix (rows of the multiarray weighted_v_T fixing the first axis)
# to the edge vector of the corresponding dual edge
weighted_v_T = jnp.transpose(weighted_v, axes=(2, 0, 1))
coch_coeffs = jnp.einsum("ijk, ik -> ij", weighted_v_T, dedges)
return C.CochainD1(v.S, coch_coeffs)
return C.CochainD1(c.complex, coch_coeffs)


def flat_DPP(v: DiscreteTensorFieldD) -> C.CochainP1:
def flat_DPP(c: C.CochainD0V | C.CochainD0T) -> C.CochainP1:
"""Implements the flat DPP operator for dual discrete vector fields.
Args:
v: a dual discrete vector field.
Returns:
the primal 1-cochain resulting from the application of the flat operator.
"""
primal_edges = v.S.primal_edges_vectors[:, :v.coeffs.shape[0]]
flat_matrix = v.S.flat_DPP_weights
primal_edges = c.complex.primal_edges_vectors[:, :c.coeffs.shape[0]]
flat_matrix = c.complex.flat_DPP_weights
# multiply weights of each primal edge by the vectors associated to the dual nodes
# belonging to the corresponding dual edge
weighted_v = v.coeffs @ flat_matrix
if v.rank == 1:
weighted_v = c.coeffs @ flat_matrix
if c.coeffs.ndim == 2:
# vector field case
# perform dot product row-wise with the edge vectors
# of the dual edges (see definition of DPD in Hirani, pag. 54).
weighted_v_T = weighted_v.T
coch_coeffs = jnp.einsum("ij, ij -> i", weighted_v_T,
primal_edges)
elif v.rank == 2:
elif c.coeffs.ndim == 3:
# tensor field case
# apply each matrix (rows of the multiarray weighted_v_T fixing the first axis)
# to the edge vector of the corresponding dual edge
weighted_v_T = jnp.transpose(weighted_v, axes=(2, 0, 1))
coch_coeffs = jnp.einsum("ijk, ik -> ij", weighted_v_T,
primal_edges)
return C.CochainP1(v.S, coch_coeffs)
return C.CochainP1(c.complex, coch_coeffs)


def flat_PDP(c: C.CochainP0) -> C.CochainP1:
Expand Down Expand Up @@ -139,26 +88,3 @@ def flat_PDD(c: C.CochainD0, scheme: str) -> C.CochainD1:
flat_c_coeffs = flat_c_coeffs.at[1:-1].set(0.5 * (dual_volumes[1:-1] *
(c.coeffs[:-1] + c.coeffs[1:]).T).T)
return C.CochainD1(c.complex, flat_c_coeffs)


def upwind_interpolation(c: C.CochainD0) -> ScalarField:
S = c.complex
circ = S.circ[1][:, 0]

def field(x):
# find the indices such that circ[idx] <= x <= circ[idx+1] pointwise
idx = np.searchsorted(circ, x)
return c.coeffs[idx-1]

return ScalarField(c.complex, field)


def upwind_integration(s: ScalarField) -> C.CochainD1:
circ = s.S.node_coords[:, 0]
dual_volumes = s.S.dual_volumes[0]
coeffs = s.field(circ)*dual_volumes
return C.CochainD1(s.S, coeffs)


def flat_PDD_2(c: C.CochainD0) -> C.CochainD1:
return upwind_integration(upwind_interpolation(c))
4 changes: 2 additions & 2 deletions src/dctkit/physics/elastica.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def energy(self, theta: npt.NDArray, B: float, theta_0: float, F: float) -> Arra

# potential of the applied load
A_coch = C.scalar_mul(self.ones_coch, A)
load = C.inner_product(C.sin(theta_coch), A_coch)
load = C.inner(C.sin(theta_coch), A_coch)

energy = 0.5*C.inner_product(moment, curvature) - load
energy = 0.5*C.inner(moment, curvature) - load

return energy

Expand Down
11 changes: 6 additions & 5 deletions src/dctkit/physics/elasticity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy.typing as npt
from dctkit.mesh.simplex import SimplicialComplex
import dctkit.dec.cochain as C
import dctkit.dec.vector as V
import dctkit.dec.flat as V
from jax import Array
import jax.numpy as jnp
from typing import Tuple, Dict
Expand Down Expand Up @@ -123,7 +123,7 @@ def force_balance_residual_primal(self, node_coords: C.CochainP0, f: C.CochainP2
"""
strain = self.get_infinitesimal_strain(node_coords=node_coords.coeffs)
stress = self.get_stress(strain=strain)
stress_tensor = V.DiscreteTensorFieldD(S=self.S, coeffs=stress.T, rank=2)
stress_tensor = C.CochainD0T(complex=self.S, coeffs=stress.T)
stress_integrated = V.flat_DPD(stress_tensor)
forces = C.star(stress_integrated)
# set tractions on given sub-portions of the boundary
Expand Down Expand Up @@ -153,7 +153,8 @@ def get_dual_balance(self, node_coords: C.CochainP0,
"""
strain = self.get_infinitesimal_strain(node_coords=node_coords.coeffs)
stress = self.get_stress(strain=strain)
stress_tensor = V.DiscreteTensorFieldD(S=self.S, coeffs=stress.T, rank=2)
stress_tensor = C.CochainD0T(complex=self.S, coeffs=stress.T)
print(stress.ndim)
# compute forces on dual edges
stress_integrated = V.flat_DPP(stress_tensor)
forces = C.star(stress_integrated)
Expand Down Expand Up @@ -211,8 +212,8 @@ def elasticity_energy(self, node_coords: C.CochainP0, f: C.CochainP0) -> float:
ref_node_coords = C.CochainP0(self.S, self.S.node_coords)
displacement = C.sub(node_coords, ref_node_coords)
elastic_energy = 0.5 * \
C.inner_product(strain_cochain, stress_cochain) - \
C.inner_product(displacement, f)
C.inner(strain_cochain, stress_cochain) - \
C.inner(displacement, f)
return elastic_energy

def obj_linear_elasticity_primal(self, node_coords: npt.NDArray | Array,
Expand Down
4 changes: 2 additions & 2 deletions src/dctkit/physics/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def energy_poisson(x: npt.NDArray, f: npt.NDArray, S, k: float, boundary_values:
f_coch = C.CochainP0(S, f)
u = C.CochainP0(S, x)
du = C.coboundary(u)
norm_grad = k/2*C.inner_product(du, du)
bound_term = -C.inner_product(u, f_coch)
norm_grad = k/2*C.inner(du, du)
bound_term = -C.inner(u, f_coch)
penalty = 0.5*gamma*np.sum((x[pos] - value)**2)
energy = norm_grad + bound_term + penalty
return energy
Expand Down
90 changes: 45 additions & 45 deletions tests/test_cochain.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,8 @@ def test_inner_product(setup_test):
cP1_1 = C.CochainP1(complex=S_1, coeffs=vP1_1)
cP1_2 = C.CochainP1(complex=S_1, coeffs=vP1_2)

inner_productP0 = C.inner_product(cP0_1, cP0_2)
inner_productP1 = C.inner_product(cP1_1, cP1_2)
inner_productP0 = C.inner(cP0_1, cP0_2)
inner_productP1 = C.inner(cP1_1, cP1_2)
inner_product_all = [inner_productP0, inner_productP1]
inner_productP0_true = np.dot(vP0_1, S_1.hodge_star[0]*vP0_2)
inner_productP1_true = np.dot(vP1_1, S_1.hodge_star[1]*vP1_2)
Expand All @@ -366,9 +366,9 @@ def test_inner_product(setup_test):
cP2_1 = C.CochainP2(complex=S_2, coeffs=vP2_1)
cP2_2 = C.CochainP2(complex=S_2, coeffs=vP2_2)

inner_productP0 = C.inner_product(cP0_1, cP0_2)
inner_productP1 = C.inner_product(cP1_1, cP1_2)
inner_productP2 = C.inner_product(cP2_1, cP2_2)
inner_productP0 = C.inner(cP0_1, cP0_2)
inner_productP1 = C.inner(cP1_1, cP1_2)
inner_productP2 = C.inner(cP2_1, cP2_2)
inner_product_all = [inner_productP0, inner_productP1, inner_productP2]
inner_productP0_true = np.dot(vP0_1, S_2.hodge_star[0]*vP0_2)
inner_productP1_true = np.dot(vP1_1, S_2.hodge_star[1]*vP1_2)
Expand Down Expand Up @@ -399,10 +399,10 @@ def test_inner_product(setup_test):
cP3_1 = C.CochainP3(complex=S_3, coeffs=vP3_1)
cP3_2 = C.CochainP3(complex=S_3, coeffs=vP3_2)

inner_productP0 = C.inner_product(cP0_1, cP0_2)
inner_productP1 = C.inner_product(cP1_1, cP1_2)
inner_productP2 = C.inner_product(cP2_1, cP2_2)
inner_productP3 = C.inner_product(cP3_1, cP3_2)
inner_productP0 = C.inner(cP0_1, cP0_2)
inner_productP1 = C.inner(cP1_1, cP1_2)
inner_productP2 = C.inner(cP2_1, cP2_2)
inner_productP3 = C.inner(cP3_1, cP3_2)
inner_product_all = [inner_productP0,
inner_productP1, inner_productP2, inner_productP3]
inner_productP0_true = np.dot(vP0_1, S_3.hodge_star[0]*vP0_2)
Expand Down Expand Up @@ -438,12 +438,12 @@ def test_inner_product(setup_test):
cD0_1 = C.star(cP2_1)
cD0_2 = C.star(cP2_2)

inner_product_P0 = C.inner_product(cP0_1, cP0_2)
inner_product_P1 = C.inner_product(cP1_1, cP1_2)
inner_product_P2 = C.inner_product(cP2_1, cP2_2)
inner_product_D2 = C.inner_product(cD2_1, cD2_2)
inner_product_D1 = C.inner_product(cD1_1, cD1_2)
inner_product_D0 = C.inner_product(cD0_1, cD0_2)
inner_product_P0 = C.inner(cP0_1, cP0_2)
inner_product_P1 = C.inner(cP1_1, cP1_2)
inner_product_P2 = C.inner(cP2_1, cP2_2)
inner_product_D2 = C.inner(cD2_1, cD2_2)
inner_product_D1 = C.inner(cD1_1, cD1_2)
inner_product_D0 = C.inner(cD0_1, cD0_2)
inner_product_P_all = np.array(
[inner_product_P0, inner_product_P1, inner_product_P2])
inner_product_D_all = np.array(
Expand Down Expand Up @@ -474,12 +474,12 @@ def test_inner_product(setup_test):
cD0_1 = C.star(cP2_1)
cD0_2 = C.star(cP2_2)

inner_product_P0 = C.inner_product(cP0_1, cP0_2)
inner_product_P1 = C.inner_product(cP1_1, cP1_2)
inner_product_P2 = C.inner_product(cP2_1, cP2_2)
inner_product_D2 = C.inner_product(cD2_1, cD2_2)
inner_product_D1 = C.inner_product(cD1_1, cD1_2)
inner_product_D0 = C.inner_product(cD0_1, cD0_2)
inner_product_P0 = C.inner(cP0_1, cP0_2)
inner_product_P1 = C.inner(cP1_1, cP1_2)
inner_product_P2 = C.inner(cP2_1, cP2_2)
inner_product_D2 = C.inner(cD2_1, cD2_2)
inner_product_D1 = C.inner(cD1_1, cD1_2)
inner_product_D0 = C.inner(cD0_1, cD0_2)
inner_product_P_all = np.array(
[inner_product_P0, inner_product_P1, inner_product_P2])
inner_product_D_all = np.array(
Expand Down Expand Up @@ -516,11 +516,11 @@ def test_codifferential(setup_test):
cD0 = C.CochainD0(complex=S_1, coeffs=vD0)
cD1 = C.CochainD1(complex=S_1, coeffs=vD1)

innerP0P1 = C.inner_product(C.coboundary(cP0), cP1)
innerD0D1 = C.inner_product(C.coboundary(cD0), cD1)
innerP0P1 = C.inner(C.coboundary(cP0), cP1)
innerD0D1 = C.inner(C.coboundary(cD0), cD1)
inner_all = [innerP0P1, innerD0D1]
cod_innerP0P1 = C.inner_product(cP0, C.codifferential(cP1))
cod_innerD0D1 = C.inner_product(cD0, C.codifferential(cD1))
cod_innerP0P1 = C.inner(cP0, C.codifferential(cP1))
cod_innerD0D1 = C.inner(cD0, C.codifferential(cD1))
cod_inner_all = [cod_innerP0P1, cod_innerD0D1]

for i in range(2):
Expand All @@ -544,15 +544,15 @@ def test_codifferential(setup_test):
cD1 = C.CochainD1(complex=S_2, coeffs=vD1)
cD2 = C.CochainD2(complex=S_2, coeffs=vD2)

innerP0P1 = C.inner_product(C.coboundary(cP0), cP1)
innerP1P2 = C.inner_product(C.coboundary(cP1), cP2)
innerD0D1 = C.inner_product(C.coboundary(cD0), cD1)
innerD1D2 = C.inner_product(C.coboundary(cD1), cD2)
innerP0P1 = C.inner(C.coboundary(cP0), cP1)
innerP1P2 = C.inner(C.coboundary(cP1), cP2)
innerD0D1 = C.inner(C.coboundary(cD0), cD1)
innerD1D2 = C.inner(C.coboundary(cD1), cD2)
inner_all = [innerP0P1, innerP1P2, innerD0D1, innerD1D2]
cod_innerP0P1 = C.inner_product(cP0, C.codifferential(cP1))
cod_innerP1P2 = C.inner_product(cP1, C.codifferential(cP2))
cod_innerD0D1 = C.inner_product(cD0, C.codifferential(cD1))
cod_innerD1D2 = C.inner_product(cD1, C.codifferential(cD2))
cod_innerP0P1 = C.inner(cP0, C.codifferential(cP1))
cod_innerP1P2 = C.inner(cP1, C.codifferential(cP2))
cod_innerD0D1 = C.inner(cD0, C.codifferential(cD1))
cod_innerD1D2 = C.inner(cD1, C.codifferential(cD2))
cod_inner_all = [cod_innerP0P1, cod_innerP1P2, cod_innerD0D1, cod_innerD1D2]

for i in range(4):
Expand Down Expand Up @@ -581,19 +581,19 @@ def test_codifferential(setup_test):
cD2 = C.CochainD2(complex=S_3, coeffs=vD2)
cD3 = C.CochainD3(complex=S_3, coeffs=vD3)

innerP0P1 = C.inner_product(C.coboundary(cP0), cP1)
innerP1P2 = C.inner_product(C.coboundary(cP1), cP2)
innerP2P3 = C.inner_product(C.coboundary(cP2), cP3)
innerD0D1 = C.inner_product(C.coboundary(cD0), cD1)
innerD1D2 = C.inner_product(C.coboundary(cD1), cD2)
innerD2D3 = C.inner_product(C.coboundary(cD2), cD3)
innerP0P1 = C.inner(C.coboundary(cP0), cP1)
innerP1P2 = C.inner(C.coboundary(cP1), cP2)
innerP2P3 = C.inner(C.coboundary(cP2), cP3)
innerD0D1 = C.inner(C.coboundary(cD0), cD1)
innerD1D2 = C.inner(C.coboundary(cD1), cD2)
innerD2D3 = C.inner(C.coboundary(cD2), cD3)
inner_all = [innerP0P1, innerP1P2, innerP2P3, innerD0D1, innerD1D2, innerD2D3]
cod_innerP0P1 = C.inner_product(cP0, C.codifferential(cP1))
cod_innerP1P2 = C.inner_product(cP1, C.codifferential(cP2))
cod_innerP2P3 = C.inner_product(cP2, C.codifferential(cP3))
cod_innerD0D1 = C.inner_product(cD0, C.codifferential(cD1))
cod_innerD1D2 = C.inner_product(cD1, C.codifferential(cD2))
cod_innerD2D3 = C.inner_product(cD2, C.codifferential(cD3))
cod_innerP0P1 = C.inner(cP0, C.codifferential(cP1))
cod_innerP1P2 = C.inner(cP1, C.codifferential(cP2))
cod_innerP2P3 = C.inner(cP2, C.codifferential(cP3))
cod_innerD0D1 = C.inner(cD0, C.codifferential(cD1))
cod_innerD1D2 = C.inner(cD1, C.codifferential(cD2))
cod_innerD2D3 = C.inner(cD2, C.codifferential(cD3))
cod_inner_all = [cod_innerP0P1, cod_innerP1P2,
cod_innerP2P3, cod_innerD0D1, cod_innerD1D2, cod_innerD2D3]

Expand Down
4 changes: 2 additions & 2 deletions tests/test_optctrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def energy_poisson(u: npt.NDArray, f: npt.NDArray) -> float:
f_coch = C.CochainP0(S, f*np.ones(dim_0, dtype=dt.float_dtype))
u_coch = C.CochainP0(S, u)
du = C.coboundary(u_coch)
norm_grad = k/2.*C.inner_product(du, du)
bound_term = -C.inner_product(u_coch, f_coch)
norm_grad = k/2.*C.inner(du, du)
bound_term = -C.inner(u_coch, f_coch)
penalty = 0.5*gamma*dt.backend.sum((u[pos] - value)**2)
energy = norm_grad + bound_term + penalty
return energy
Expand Down
Loading

0 comments on commit 85f2e0f

Please sign in to comment.