@@ -83,7 +83,7 @@ def assemble(expr, *args, **kwargs):
8383 zero_bc_nodes : bool
8484 If `True`, set the boundary condition nodes in the
8585 output tensor to zero rather than to the values prescribed by the
86- boundary condition. Default is `False `.
86+ boundary condition. Default is `True `.
8787 diagonal : bool
8888 If assembling a matrix is it diagonal?
8989 weight : float
@@ -143,7 +143,6 @@ def get_assembler(form, *args, **kwargs):
143143
144144 """
145145 is_base_form_preprocessed = kwargs .pop ('is_base_form_preprocessed' , False )
146- bcs = kwargs .get ('bcs' , None )
147146 fc_params = kwargs .get ('form_compiler_parameters' , None )
148147 if isinstance (form , ufl .form .BaseForm ) and not is_base_form_preprocessed :
149148 mat_type = kwargs .get ('mat_type' , None )
@@ -155,8 +154,13 @@ def get_assembler(form, *args, **kwargs):
155154 if len (form .arguments ()) == 0 :
156155 return ZeroFormAssembler (form , form_compiler_parameters = fc_params )
157156 elif len (form .arguments ()) == 1 or diagonal :
158- return OneFormAssembler (form , * args , bcs = bcs , form_compiler_parameters = fc_params , needs_zeroing = kwargs .get ('needs_zeroing' , True ),
159- zero_bc_nodes = kwargs .get ('zero_bc_nodes' , False ), diagonal = diagonal )
157+ return OneFormAssembler (form , * args ,
158+ bcs = kwargs .get ("bcs" , None ),
159+ form_compiler_parameters = fc_params ,
160+ needs_zeroing = kwargs .get ("needs_zeroing" , True ),
161+ zero_bc_nodes = kwargs .get ("zero_bc_nodes" , True ),
162+ diagonal = diagonal ,
163+ weight = kwargs .get ("weight" , 1.0 ))
160164 elif len (form .arguments ()) == 2 :
161165 return TwoFormAssembler (form , * args , ** kwargs )
162166 else :
@@ -308,7 +312,7 @@ def __init__(self,
308312 sub_mat_type = None ,
309313 options_prefix = None ,
310314 appctx = None ,
311- zero_bc_nodes = False ,
315+ zero_bc_nodes = True ,
312316 diagonal = False ,
313317 weight = 1.0 ,
314318 allocation_integral_types = None ):
@@ -381,6 +385,12 @@ def visitor(e, *operands):
381385 visited = {}
382386 result = BaseFormAssembler .base_form_postorder_traversal (self ._form , visitor , visited )
383387
388+ # Apply BCs after assembly
389+ rank = len (self ._form .arguments ())
390+ if rank == 1 and not isinstance (result , ufl .ZeroBaseForm ):
391+ for bc in self ._bcs :
392+ bc .zero (result )
393+
384394 if tensor :
385395 BaseFormAssembler .update_tensor (result , tensor )
386396 return tensor
@@ -405,8 +415,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
405415 if rank == 0 :
406416 assembler = ZeroFormAssembler (form , form_compiler_parameters = self ._form_compiler_params )
407417 elif rank == 1 or (rank == 2 and self ._diagonal ):
408- assembler = OneFormAssembler (form , bcs = self . _bcs , form_compiler_parameters = self ._form_compiler_params ,
409- zero_bc_nodes = self ._zero_bc_nodes , diagonal = self ._diagonal )
418+ assembler = OneFormAssembler (form , form_compiler_parameters = self ._form_compiler_params ,
419+ zero_bc_nodes = self ._zero_bc_nodes , diagonal = self ._diagonal , weight = self . _weight )
410420 elif rank == 2 :
411421 assembler = TwoFormAssembler (form , bcs = self ._bcs , form_compiler_parameters = self ._form_compiler_params ,
412422 mat_type = self ._mat_type , sub_mat_type = self ._sub_mat_type ,
@@ -577,10 +587,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
577587 @staticmethod
578588 def update_tensor (assembled_base_form , tensor ):
579589 if isinstance (tensor , (firedrake .Function , firedrake .Cofunction )):
580- assembled_base_form .dat .copy (tensor .dat )
590+ if isinstance (assembled_base_form , ufl .ZeroBaseForm ):
591+ tensor .dat .zero ()
592+ else :
593+ assembled_base_form .dat .copy (tensor .dat )
581594 elif isinstance (tensor , matrix .MatrixBase ):
582- # Uses the PETSc copy method.
583- assembled_base_form .petscmat .copy (tensor .petscmat )
595+ if isinstance (assembled_base_form , ufl .ZeroBaseForm ):
596+ tensor .petscmat .zeroEntries ()
597+ else :
598+ assembled_base_form .petscmat .copy (tensor .petscmat )
584599 else :
585600 raise NotImplementedError ("Cannot update tensor of type %s" % type (tensor ))
586601
@@ -807,9 +822,9 @@ def restructure_base_form(expr, visited=None):
807822 return ufl .action (expr , ustar )
808823
809824 # -- Case (6) -- #
810- if isinstance (expr , ufl .FormSum ) and all (isinstance ( c , ufl .core . base_form_operator . BaseFormOperator ) for c in expr .components ()):
811- # Return ufl.Sum
812- return sum ([ c for c in expr .components ()] )
825+ if isinstance (expr , ufl .FormSum ) and all (ufl .duals . is_dual ( a . function_space ()) for a in expr .arguments ()):
826+ # Return ufl.Sum if we are assembling a FormSum with Coarguments (a primal expression)
827+ return sum (w * c for w , c in zip ( expr .weights (), expr . components ()) )
813828 return expr
814829
815830 @staticmethod
@@ -1138,7 +1153,7 @@ class OneFormAssembler(ParloopFormAssembler):
11381153
11391154 Parameters
11401155 ----------
1141- form : ufl.Form or slate.TensorBasehe
1156+ form : ufl.Form or slate.TensorBase
11421157 1-form.
11431158
11441159 Notes
@@ -1149,14 +1164,15 @@ class OneFormAssembler(ParloopFormAssembler):
11491164
11501165 @classmethod
11511166 def _cache_key (cls , form , bcs = None , form_compiler_parameters = None , needs_zeroing = True ,
1152- zero_bc_nodes = False , diagonal = False ):
1167+ zero_bc_nodes = True , diagonal = False , weight = 1.0 ):
11531168 bcs = solving ._extract_bcs (bcs )
1154- return tuple (bcs ), tuplify (form_compiler_parameters ), needs_zeroing , zero_bc_nodes , diagonal
1169+ return tuple (bcs ), tuplify (form_compiler_parameters ), needs_zeroing , zero_bc_nodes , diagonal , weight
11551170
11561171 @FormAssembler ._skip_if_initialised
11571172 def __init__ (self , form , bcs = None , form_compiler_parameters = None , needs_zeroing = True ,
1158- zero_bc_nodes = False , diagonal = False ):
1173+ zero_bc_nodes = True , diagonal = False , weight = 1.0 ):
11591174 super ().__init__ (form , bcs = bcs , form_compiler_parameters = form_compiler_parameters , needs_zeroing = needs_zeroing )
1175+ self ._weight = weight
11601176 self ._diagonal = diagonal
11611177 self ._zero_bc_nodes = zero_bc_nodes
11621178 if self ._diagonal and any (isinstance (bc , EquationBCSplit ) for bc in self ._bcs ):
@@ -1185,23 +1201,21 @@ def _apply_bc(self, tensor, bc):
11851201 elif isinstance (bc , EquationBCSplit ):
11861202 bc .zero (tensor )
11871203 type (self )(bc .f , bcs = bc .bcs , form_compiler_parameters = self ._form_compiler_params , needs_zeroing = False ,
1188- zero_bc_nodes = self ._zero_bc_nodes , diagonal = self ._diagonal ).assemble (tensor = tensor )
1204+ zero_bc_nodes = self ._zero_bc_nodes , diagonal = self ._diagonal , weight = self . _weight ).assemble (tensor = tensor )
11891205 else :
11901206 raise AssertionError
11911207
11921208 def _apply_dirichlet_bc (self , tensor , bc ):
1193- if not self ._zero_bc_nodes :
1194- tensor_func = tensor .riesz_representation (riesz_map = "l2" )
1195- if self ._diagonal :
1196- bc .set (tensor_func , 1 )
1197- else :
1198- bc .apply (tensor_func )
1199- tensor .assign (tensor_func .riesz_representation (riesz_map = "l2" ))
1209+ if self ._diagonal :
1210+ bc .set (tensor , self ._weight )
1211+ elif not self ._zero_bc_nodes :
1212+ # NOTE this only works if tensor is a Function and not a Cofunction
1213+ bc .apply (tensor )
12001214 else :
12011215 bc .zero (tensor )
12021216
12031217 def _check_tensor (self , tensor ):
1204- if tensor .function_space () != self ._form .arguments ()[0 ].function_space ():
1218+ if tensor .function_space () != self ._form .arguments ()[0 ].function_space (). dual () :
12051219 raise ValueError ("Form's argument does not match provided result tensor" )
12061220
12071221 @staticmethod
@@ -2127,14 +2141,13 @@ def iter_active_coefficients(form, kinfo):
21272141
21282142 @staticmethod
21292143 def iter_constants (form , kinfo ):
2130- """Yield the form constants"""
2144+ """Yield the form constants referenced in ``kinfo``. """
21312145 if isinstance (form , slate .TensorBase ):
2132- for const in form .constants ():
2133- yield const
2146+ all_constants = form .constants ()
21342147 else :
21352148 all_constants = extract_firedrake_constants (form )
2136- for constant_index in kinfo .constant_numbers :
2137- yield all_constants [constant_index ]
2149+ for constant_index in kinfo .constant_numbers :
2150+ yield all_constants [constant_index ]
21382151
21392152 @staticmethod
21402153 def index_function_spaces (form , indices ):
0 commit comments