diff --git a/deepxde/data/fpde.py b/deepxde/data/fpde.py
index 673739966..5e84cc28e 100644
--- a/deepxde/data/fpde.py
+++ b/deepxde/data/fpde.py
@@ -7,7 +7,6 @@
 from .pde import PDE
 from .. import backend as bkd
 from .. import config
-from ..backend import is_tensor, backend_name
 from ..utils import array_ops_compat, run_if_all_none
@@ -100,7 +99,7 @@ def __init__(
     def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
         bcs_start = np.cumsum([0] + self.num_bcs)
         # do not cache int_mat when alpha is a learnable parameter
-        if is_tensor(self.alpha):
+        if bkd.is_tensor(self.alpha):
             int_mat = self.get_int_matrix(True)
         else:
             if self.int_mat_train is not None:
@@ -119,9 +118,14 @@ def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
             loss_fn(bkd.zeros_like(fi), fi) for fi in f
         ]
+        outputs_fpde = outputs
+        if bkd.backend_name == "jax":
+            # JAX requires pure functions
+            outputs_fpde = (outputs, aux[0])
+
         for i, bc in enumerate(self.bcs):
             beg, end = bcs_start[i], bcs_start[i + 1]
-            error = bc.error(self.train_x, inputs, outputs, beg, end)
+            error = bc.error(self.train_x, inputs, outputs_fpde, beg, end)
             losses.append(
                 loss_fn(bkd.zeros_like(error), error)
             )
@@ -138,7 +142,7 @@ def losses_test(self, targets, outputs, loss_fn, inputs, model, aux=None):
     def train_next_batch(self, batch_size=None):
         # do not cache train data when alpha is a learnable parameter
-        if not is_tensor(self.alpha) or backend_name == "tensorflow.compat.v1":
+        if not bkd.is_tensor(self.alpha) or bkd.backend_name == "tensorflow.compat.v1":
             if self.train_x is not None:
                 return self.train_x, self.train_y
         if self.disc.meshtype == "static":
@@ -168,7 +172,7 @@ def train_next_batch(self, batch_size=None):
     def test(self):
         # do not cache test data when alpha is a learnable parameter
-        if not is_tensor(self.alpha) or backend_name == "tensorflow.compat.v1":
+        if not bkd.is_tensor(self.alpha) or bkd.backend_name == "tensorflow.compat.v1":
             if self.test_x is not None:
                 return self.test_x, self.test_y
diff --git a/deepxde/data/ide.py b/deepxde/data/ide.py
index 6f01b0f92..0b6864da8 100644
--- a/deepxde/data/ide.py
+++ b/deepxde/data/ide.py
@@ -57,9 +57,14 @@ def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
         f = [fi[bcs_start[-1] :] for fi in f]
         losses = [loss_fn(bkd.zeros_like(fi), fi) for fi in f]
+        outputs_ide = outputs
+        if bkd.backend_name == "jax":
+            # JAX requires pure functions
+            outputs_ide = (outputs, aux[0])
+
         for i, bc in enumerate(self.bcs):
             beg, end = bcs_start[i], bcs_start[i + 1]
-            error = bc.error(self.train_x, inputs, outputs, beg, end)
+            error = bc.error(self.train_x, inputs, outputs_ide, beg, end)
             losses.append(loss_fn(bkd.zeros_like(error), error))
         return losses
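Both hunks above repeat the same idiom: under the JAX backend, boundary conditions receive `(outputs, aux[0])` instead of the bare output array, where `aux[0]` is the network's apply function. JAX's autodiff transforms operate on pure functions rather than on tensors recorded on a tape, so a BC that needs a derivative must be handed the function itself. A minimal standalone sketch in plain JAX (not DeepXDE API; `net_apply` is a stand-in name):

```python
import jax
import jax.numpy as jnp

def net_apply(x):
    # Stand-in for the trained network: a pure function of its inputs.
    return jnp.sin(x)

x = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)
outputs = net_apply(x)               # precomputed output values
outputs_pack = (outputs, net_apply)  # values + function, like (outputs, aux[0])

values, fn = outputs_pack
# Value-based errors (e.g. DirichletBC) only need `values`, but any
# derivative-based error must re-trace the pure function:
dydx = jax.vmap(jax.jacfwd(fn))(x)   # shape (5, 1, 1): d(net)/dx per point
```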
diff --git a/deepxde/data/pde.py b/deepxde/data/pde.py
index acbcc6e25..fcf08214b 100644
--- a/deepxde/data/pde.py
+++ b/deepxde/data/pde.py
@@ -3,7 +3,6 @@
 from .data import Data
 from .. import backend as bkd
 from .. import config
-from ..backend import backend_name
 from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0
@@ -128,9 +127,8 @@ def __init__(
         self.test()

     def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
-        if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
-            outputs_pde = outputs
-        elif backend_name == "jax":
+        outputs_pde = outputs
+        if bkd.backend_name == "jax":
             # JAX requires pure functions
             outputs_pde = (outputs, aux[0])
@@ -166,7 +164,7 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
         for i, bc in enumerate(self.bcs):
             beg, end = bcs_start[i], bcs_start[i + 1]
             # The same BC points are used for training and testing.
-            error = bc.error(self.train_x, inputs, outputs, beg, end)
+            error = bc.error(self.train_x, inputs, outputs_pde, beg, end)
             losses.append(loss_fn[len(error_f) + i](bkd.zeros_like(error), error))
         return losses
diff --git a/deepxde/data/pde_operator.py b/deepxde/data/pde_operator.py
index 5de5a6f46..b1cd6d4bf 100644
--- a/deepxde/data/pde_operator.py
+++ b/deepxde/data/pde_operator.py
@@ -80,13 +80,19 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
         bcs_start = np.cumsum([0] + self.num_bcs)
         error_f = [fi[bcs_start[-1] :] for fi in f]
         losses = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
+
+        outputs_pdeoperator = outputs
+        if bkd.backend_name == "jax":
+            # JAX requires pure functions
+            outputs_pdeoperator = (outputs, aux[0])
+
         for i, bc in enumerate(self.pde.bcs):
             beg, end = bcs_start[i], bcs_start[i + 1]
             # The same BC points are used for training and testing.
             error = bc.error(
                 self.train_x[1],
                 inputs[1],
-                outputs,
+                outputs_pdeoperator,
                 beg,
                 end,
                 aux_var=self.train_aux_vars,
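pde.py shows the simplification the other three data classes mirror: rather than enumerating every non-JAX backend in an if/elif chain, `outputs_pde` defaults to the raw outputs and only JAX is special-cased. The same packing could be factored into a small helper; a hypothetical sketch (`pack_outputs` is not part of DeepXDE, only an illustration of the shared pattern):

```python
from deepxde import backend as bkd

def pack_outputs(outputs, aux):
    """Hypothetical helper: return what BC.error expects for the active backend."""
    if bkd.backend_name == "jax":
        # JAX requires pure functions: aux[0] is the network apply function.
        return (outputs, aux[0])
    return outputs

# Usage inside a losses() method would then read:
#     error = bc.error(self.train_x, inputs, pack_outputs(outputs, aux), beg, end)
```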
diff --git a/deepxde/icbc/boundary_conditions.py b/deepxde/icbc/boundary_conditions.py
index e1f863a08..5e82d88fe 100644
--- a/deepxde/icbc/boundary_conditions.py
+++ b/deepxde/icbc/boundary_conditions.py
@@ -52,9 +52,11 @@ def collocation_points(self, X):
         return self.filter(X)

     def normal_derivative(self, X, inputs, outputs, beg, end):
-        dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end]
+        dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)
+        if backend_name == "jax":
+            dydx = dydx[0]
         n = self.boundary_normal(X, beg, end, None)
-        return bkd.sum(dydx * n, 1, keepdims=True)
+        return bkd.sum(dydx[beg:end] * n, 1, keepdims=True)

     @abstractmethod
     def error(self, X, inputs, outputs, beg, end, aux_var=None):
@@ -77,6 +79,8 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):
                 "DirichletBC function should return an array of shape N by 1 for each "
                 "component. Use argument 'component' for different output components."
             )
+        if backend_name == "jax":
+            outputs = outputs[0]
         return outputs[beg:end, self.component : self.component + 1] - values
@@ -100,9 +104,11 @@ def __init__(self, geom, func, on_boundary, component=0):
         self.func = func

     def error(self, X, inputs, outputs, beg, end, aux_var=None):
-        return self.normal_derivative(X, inputs, outputs, beg, end) - self.func(
-            X[beg:end], outputs[beg:end]
-        )
+        normal_derivative = self.normal_derivative(X, inputs, outputs, beg, end)
+        if backend_name == "jax":
+            outputs = outputs[0]
+        values = self.func(X[beg:end], outputs[beg:end])
+        return normal_derivative - values


 class PeriodicBC(BC):
@@ -125,10 +131,14 @@ def collocation_points(self, X):
     def error(self, X, inputs, outputs, beg, end, aux_var=None):
         mid = beg + (end - beg) // 2
         if self.derivative_order == 0:
+            if backend_name == "jax":
+                outputs = outputs[0]
             yleft = outputs[beg:mid, self.component : self.component + 1]
             yright = outputs[mid:end, self.component : self.component + 1]
         else:
             dydx = grad.jacobian(outputs, inputs, i=self.component, j=self.component_x)
+            if backend_name == "jax":
+                dydx = dydx[0]
             yleft = dydx[beg:mid]
             yright = dydx[mid:end]
         return yleft - yright
@@ -158,6 +168,8 @@ def __init__(self, geom, func, on_boundary):
         self.func = func

     def error(self, X, inputs, outputs, beg, end, aux_var=None):
+        # The user-defined func is responsible for handling compatibility with
+        # the desired backend; this allows self.func to include dde.grad components.
         return self.func(inputs, outputs, X)[beg:end]
@@ -210,6 +222,8 @@ def collocation_points(self, X):
         return self.points

     def error(self, X, inputs, outputs, beg, end, aux_var=None):
+        if backend_name == "jax":
+            outputs = outputs[0]
         if self.batch_size is not None:
             if isinstance(self.component, numbers.Number):
                 return (
@@ -260,6 +274,8 @@ def collocation_points(self, X):
         return self.points

     def error(self, X, inputs, outputs, beg, end, aux_var=None):
+        # The user-defined func is responsible for handling compatibility with
+        # the desired backend.
         return self.func(inputs, outputs, X)[beg:end] - self.values
diff --git a/deepxde/icbc/initial_conditions.py b/deepxde/icbc/initial_conditions.py
index a7ca57f2f..cfef7b44a 100644
--- a/deepxde/icbc/initial_conditions.py
+++ b/deepxde/icbc/initial_conditions.py
@@ -33,4 +33,6 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):
                 "IC function should return an array of shape N by 1 for each component."
                 "Use argument 'component' for different output components."
             )
+        if bkd.backend_name == "jax":
+            outputs = outputs[0]
         return outputs[beg:end, self.component : self.component + 1] - values
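The new comments on OperatorBC and PointSetOperatorBC put the burden of backend compatibility on the user-supplied func: under JAX it receives `outputs` as a `(values, net_apply)` tuple, and, as the normal_derivative hunk shows, `dde.grad.jacobian` likewise returns a pair whose first element is the value. A hedged sketch of what a JAX-aware func might look like (a made-up Robin-type condition for illustration, not code from this PR):

```python
import deepxde as dde

def robin_func(inputs, outputs, X):
    # Residual of dy/dx + y = 0 on the boundary; usable as
    # dde.icbc.OperatorBC(geom, robin_func, on_boundary).
    dydx = dde.grad.jacobian(outputs, inputs, i=0, j=0)
    if dde.backend.backend_name == "jax":
        dydx = dydx[0]        # jacobian returns (value, fn) under JAX
        outputs = outputs[0]  # unwrap the (values, net_apply) tuple
    return dydx + outputs[:, 0:1]
```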