Improve JAX compatibility with ics and bcs. #1492

Draft: wants to merge 2 commits into master.

Changes from all commits
14 changes: 9 additions & 5 deletions deepxde/data/fpde.py
@@ -7,7 +7,6 @@
from .pde import PDE
from .. import backend as bkd
from .. import config
from ..backend import is_tensor, backend_name
from ..utils import array_ops_compat, run_if_all_none


@@ -100,7 +99,7 @@ def __init__(
def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
bcs_start = np.cumsum([0] + self.num_bcs)
# do not cache int_mat when alpha is a learnable parameter
if is_tensor(self.alpha):
if bkd.is_tensor(self.alpha):
int_mat = self.get_int_matrix(True)
else:
if self.int_mat_train is not None:
@@ -119,9 +118,14 @@ def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
loss_fn(bkd.zeros_like(fi), fi) for fi in f
]

outputs_fpde = outputs
if bkd.backend_name == "jax":
# JAX requires pure functions
outputs_fpde = (outputs, aux[0])

for i, bc in enumerate(self.bcs):
beg, end = bcs_start[i], bcs_start[i + 1]
error = bc.error(self.train_x, inputs, outputs, beg, end)
error = bc.error(self.train_x, inputs, outputs_fpde, beg, end)
losses.append(
loss_fn(bkd.zeros_like(error), error)
)
@@ -138,7 +142,7 @@ def losses_test(self, targets, outputs, loss_fn, inputs, model, aux=None):

def train_next_batch(self, batch_size=None):
# do not cache train data when alpha is a learnable parameter
if not is_tensor(self.alpha) or backend_name == "tensorflow.compat.v1":
if not bkd.is_tensor(self.alpha) or bkd.backend_name == "tensorflow.compat.v1":
if self.train_x is not None:
return self.train_x, self.train_y
if self.disc.meshtype == "static":
@@ -168,7 +172,7 @@ def train_next_batch(self, batch_size=None):

def test(self):
# do not cache test data when alpha is a learnable parameter
if not is_tensor(self.alpha) or backend_name == "tensorflow.compat.v1":
if not bkd.is_tensor(self.alpha) or bkd.backend_name == "tensorflow.compat.v1":
if self.test_x is not None:
return self.test_x, self.test_y

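The same idiom recurs in the losses_* methods of ide.py, pde.py, and pde_operator.py below: when the JAX backend is active, outputs is passed to bc.error together with aux[0], so the boundary/initial conditions receive a pure function alongside the evaluated array. A minimal sketch of that shared pattern, assuming aux[0] holds the network's apply function (the helper name is hypothetical and not part of this diff):

    from deepxde import backend as bkd

    def outputs_for_bc(outputs, aux):
        # Under JAX, BC/IC error() expects an (array, pure_fn) pair so that
        # derivatives can be taken functionally; other backends keep the
        # plain outputs tensor.
        if bkd.backend_name == "jax":
            return (outputs, aux[0])
        return outputs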
7 changes: 6 additions & 1 deletion deepxde/data/ide.py
@@ -57,9 +57,14 @@ def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
f = [fi[bcs_start[-1] :] for fi in f]
losses = [loss_fn(bkd.zeros_like(fi), fi) for fi in f]

outputs_ide = outputs
if bkd.backend_name == "jax":
# JAX requires pure functions
outputs_ide = (outputs, aux[0])

for i, bc in enumerate(self.bcs):
beg, end = bcs_start[i], bcs_start[i + 1]
error = bc.error(self.train_x, inputs, outputs, beg, end)
error = bc.error(self.train_x, inputs, outputs_ide, beg, end)
losses.append(loss_fn(bkd.zeros_like(error), error))
return losses

8 changes: 3 additions & 5 deletions deepxde/data/pde.py
@@ -3,7 +3,6 @@
from .data import Data
from .. import backend as bkd
from .. import config
from ..backend import backend_name
from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0


@@ -128,9 +127,8 @@ def __init__(
self.test()

def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
outputs_pde = outputs
elif backend_name == "jax":
outputs_pde = outputs
if bkd.backend_name == "jax":
# JAX requires pure functions
outputs_pde = (outputs, aux[0])

@@ -166,7 +164,7 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
for i, bc in enumerate(self.bcs):
beg, end = bcs_start[i], bcs_start[i + 1]
# The same BC points are used for training and testing.
error = bc.error(self.train_x, inputs, outputs, beg, end)
error = bc.error(self.train_x, inputs, outputs_pde, beg, end)
losses.append(loss_fn[len(error_f) + i](bkd.zeros_like(error), error))
return losses

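Context for the "JAX requires pure functions" comment: unlike the TensorFlow and PyTorch backends, JAX takes derivatives by transforming a function rather than by replaying a recorded graph, so an output array on its own cannot be differentiated; the function that produced it has to be carried along, which is what aux[0] provides. A standalone illustration in plain JAX (for orientation only, independent of DeepXDE):

    import jax
    import jax.numpy as jnp

    def net(x):                          # stand-in for the network's apply function
        return jnp.sin(x)

    x = jnp.linspace(0.0, 1.0, 5)
    y = net(x)                           # "outputs": an array with no graph attached
    dy_dx = jax.vmap(jax.grad(net))(x)   # differentiation needs the function itself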
8 changes: 7 additions & 1 deletion deepxde/data/pde_operator.py
@@ -80,13 +80,19 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
bcs_start = np.cumsum([0] + self.num_bcs)
error_f = [fi[bcs_start[-1] :] for fi in f]
losses = [loss_fn(bkd.zeros_like(error), error) for error in error_f]

outputs_pdeoperator = outputs
if bkd.backend_name == "jax":
# JAX requires pure functions
outputs_pdeoperator = (outputs, aux[0])

for i, bc in enumerate(self.pde.bcs):
beg, end = bcs_start[i], bcs_start[i + 1]
# The same BC points are used for training and testing.
error = bc.error(
self.train_x[1],
inputs[1],
outputs,
outputs_pdeoperator,
beg,
end,
aux_var=self.train_aux_vars,
26 changes: 21 additions & 5 deletions deepxde/icbc/boundary_conditions.py
@@ -52,9 +52,11 @@ def collocation_points(self, X):
return self.filter(X)

def normal_derivative(self, X, inputs, outputs, beg, end):
dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end]
dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)
if backend_name == "jax":
dydx = dydx[0]
n = self.boundary_normal(X, beg, end, None)
return bkd.sum(dydx * n, 1, keepdims=True)
return bkd.sum(dydx[beg:end] * n, 1, keepdims=True)

@abstractmethod
def error(self, X, inputs, outputs, beg, end, aux_var=None):
@@ -77,6 +79,8 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):
"DirichletBC function should return an array of shape N by 1 for each "
"component. Use argument 'component' for different output components."
)
if backend_name == "jax":
outputs = outputs[0]
return outputs[beg:end, self.component : self.component + 1] - values


@@ -100,9 +104,11 @@ def __init__(self, geom, func, on_boundary, component=0):
self.func = func

def error(self, X, inputs, outputs, beg, end, aux_var=None):
return self.normal_derivative(X, inputs, outputs, beg, end) - self.func(
X[beg:end], outputs[beg:end]
)
normal_derivative = self.normal_derivative(X, inputs, outputs, beg, end)
if backend_name == "jax":
outputs = outputs[0]
values = self.func(X[beg:end], outputs[beg:end])
return normal_derivative - values


class PeriodicBC(BC):
@@ -125,10 +131,14 @@ def collocation_points(self, X):
def error(self, X, inputs, outputs, beg, end, aux_var=None):
mid = beg + (end - beg) // 2
if self.derivative_order == 0:
if backend_name == "jax":
outputs = outputs[0]
yleft = outputs[beg:mid, self.component : self.component + 1]
yright = outputs[mid:end, self.component : self.component + 1]
else:
dydx = grad.jacobian(outputs, inputs, i=self.component, j=self.component_x)
if backend_name == "jax":
dydx = dydx[0]
yleft = dydx[beg:mid]
yright = dydx[mid:end]
return yleft - yright
@@ -158,6 +168,8 @@ def __init__(self, geom, func, on_boundary):
self.func = func

def error(self, X, inputs, outputs, beg, end, aux_var=None):
# The user-defined func is responsible for handling compatibility with the
# desired backend; this allows self.func to include dde.grad components.
return self.func(inputs, outputs, X)[beg:end]


@@ -210,6 +222,8 @@ def collocation_points(self, X):
return self.points

def error(self, X, inputs, outputs, beg, end, aux_var=None):
if backend_name == "jax":
outputs = outputs[0]
if self.batch_size is not None:
if isinstance(self.component, numbers.Number):
return (
@@ -260,6 +274,8 @@ def collocation_points(self, X):
return self.points

def error(self, X, inputs, outputs, beg, end, aux_var=None):
# The user-defined func is responsible for handling compatibility with the
# desired backend.
return self.func(inputs, outputs, X)[beg:end] - self.values


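As the new comments on OperatorBC and PointSetOperatorBC note, a user-supplied func sees the raw outputs object, so under JAX it receives the (array, function) pair directly, and results from dde.grad arrive as (value, function) pairs as well, exactly as normal_derivative and PeriodicBC handle them above. A hedged sketch of a backend-aware func (the operator itself, dy/dx0 + y, is made up for illustration):

    import deepxde as dde
    from deepxde import backend as bkd

    def operator_func(inputs, outputs, X):
        dydx = dde.grad.jacobian(outputs, inputs, i=0, j=0)
        y = outputs
        if bkd.backend_name == "jax":
            dydx = dydx[0]  # dde.grad returns (value, fn) under JAX
            y = y[0]        # outputs arrives as (array, apply_fn)
        return dydx + y[:, 0:1]

    # used as: dde.icbc.OperatorBC(geom, operator_func, on_boundary)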
2 changes: 2 additions & 0 deletions deepxde/icbc/initial_conditions.py
@@ -33,4 +33,6 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):
"IC function should return an array of shape N by 1 for each component."
"Use argument 'component' for different output components."
)
if bkd.backend_name == "jax":
outputs = outputs[0]
return outputs[beg:end, self.component : self.component + 1] - values
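
Taken together, these changes are meant to let ICs and BCs declared in the usual way run under the JAX backend, with the (outputs, apply_fn) unpacking kept inside the library. A small illustrative setup (the problem itself is invented; only the IC/BC declarations matter here), assuming the backend is selected with DDE_BACKEND=jax:

    import numpy as np
    import deepxde as dde  # run with DDE_BACKEND=jax

    geom = dde.geometry.Interval(-1, 1)
    timedomain = dde.geometry.TimeDomain(0, 1)
    geomtime = dde.geometry.GeometryXTime(geom, timedomain)

    bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda x, on_boundary: on_boundary)
    ic = dde.icbc.IC(
        geomtime, lambda x: np.sin(np.pi * x[:, 0:1]), lambda x, on_initial: on_initial
    )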