From 019a9a2bad864a9a611d4cbe59d9d88b9d335a6a Mon Sep 17 00:00:00 2001
From: PatReis
Date: Mon, 19 Feb 2024 12:50:03 +0100
Subject: [PATCH] Removed unused layers so that models build fully, to prevent errors in the new PyTorch trainer.

---
 changelog.md                          |  1 +
 kgcnn/layers/aggr.py                  |  8 ++-
 kgcnn/layers/attention.py             | 88 ++++++++++++++++++++++-----
 kgcnn/layers/conv.py                  |  1 -
 kgcnn/literature/AttentiveFP/_make.py |  2 +-
 kgcnn/literature/DimeNetPP/_layers.py |  2 +-
 kgcnn/literature/HamNet/_layers.py    |  2 +-
 kgcnn/literature/MAT/_layers.py       |  1 +
 kgcnn/literature/MEGAN/_make.py       |  1 +
 kgcnn/literature/MXMNet/_layers.py    |  2 +-
 training/hyper/hyper_esol.py          | 14 +++--
 training/hyper/hyper_mp_jdft2d.py     | 25 +++++---
 training/train_graph.py               |  5 +-
 13 files changed, 114 insertions(+), 38 deletions(-)

diff --git a/changelog.md b/changelog.md
index 1f963dfd..cc3af875 100644
--- a/changelog.md
+++ b/changelog.md
@@ -4,6 +4,7 @@ v4.0.1
 * Added further benchmark results for kgcnn version 4.
 * Fix error in ``kgcnn.layers.geom.PositionEncodingBasisLayer``
 * Fix error in ``kgcnn.literature.GCN.make_model_weighted``
+* Fix error in ``kgcnn.literature.AttentiveFP.make_model``
 * Had to change serialization for activation functions since with keras>=3.0.2 custom strings are not allowed also causing
   clashes with built-in functions. We catch defaults to be at least backward compatible as possible and changed to serialization dictionary. Adapted all hyperparameter.
 * Renamed leaky_relu and swish in ``kgcnn.ops.activ`` to leaky_relu2 and swish2.
diff --git a/kgcnn/layers/aggr.py b/kgcnn/layers/aggr.py
index a3c9c41d..a0723982 100644
--- a/kgcnn/layers/aggr.py
+++ b/kgcnn/layers/aggr.py
@@ -191,8 +191,9 @@ def __init__(self, pooling_method: str = "scatter_sum", pooling_index: int = glo
         self.normalize_by_weights = normalize_by_weights
         self.pooling_index = pooling_index
         self.pooling_method = pooling_method
-        self.to_aggregate = Aggregate(pooling_method=pooling_method)
-        self.to_aggregate_weights = Aggregate(pooling_method="scatter_sum")
+        # to_aggregate is already created by the super class.
+        if self.normalize_by_weights:
+            self.to_aggregate_weights = Aggregate(pooling_method="scatter_sum")
         self.axis_indices = axis_indices

     def build(self, input_shape):
@@ -201,7 +202,8 @@ def build(self, input_shape):
         node_shape, edges_shape, edge_index_shape, weights_shape = [list(x) for x in input_shape]
         edge_index_shape.pop(self.axis_indices)
         self.to_aggregate.build([tuple(x) for x in [edges_shape, edge_index_shape, node_shape]])
-        self.to_aggregate_weights.build([tuple(x) for x in [weights_shape, edge_index_shape, node_shape]])
+        if self.normalize_by_weights:
+            self.to_aggregate_weights.build([tuple(x) for x in [weights_shape, edge_index_shape, node_shape]])
         self.built = True

     def compute_output_shape(self, input_shape):
diff --git a/kgcnn/layers/attention.py b/kgcnn/layers/attention.py
index fc35c859..70ac891b 100644
--- a/kgcnn/layers/attention.py
+++ b/kgcnn/layers/attention.py
@@ -247,7 +247,11 @@ def get_config(self):
         return config


-class MultiHeadGATV2Layer(AttentionHeadGATV2):  # noqa
+class MultiHeadGATV2Layer(Layer):  # noqa
+    r"""Single layer for multiple attention heads from :obj:`AttentionHeadGATV2` .
+
+    Uses concatenation or averaging of the heads for the final output.
+    """

     def __init__(self,
                  units: int,
                  num_heads: int,
@@ -255,39 +259,85 @@
                  activation: str = "kgcnn>leaky_relu2",
                  use_bias: bool = True,
                  concat_heads: bool = True,
+                 use_edge_features=False,
+                 use_final_activation=True,
+                 has_self_loops=True,
+                 kernel_regularizer=None,
+                 bias_regularizer=None,
+                 activity_regularizer=None,
+                 kernel_constraint=None,
+                 bias_constraint=None,
+                 kernel_initializer='glorot_uniform',
+                 bias_initializer='zeros',
+                 normalize_softmax: bool = False,
                  **kwargs):
-        super(MultiHeadGATV2Layer, self).__init__(
-            units=units,
-            activation=activation,
-            use_bias=use_bias,
-            **kwargs
-        )
+        r"""Initialize layer.
+
+        Args:
+            units (int): Units for the linear trafo of node features before attention.
+            num_heads (int): Number of attention heads.
+            concat_heads (bool): Whether to concatenate the heads or average them.
+            use_edge_features (bool): Append edge features to attention computation. Default is False.
+            use_final_activation (bool): Whether to apply the final activation to the output.
+            has_self_loops (bool): If the graph has self-loops. Not used here. Default is True.
+            activation (str): Activation. Default is "kgcnn>leaky_relu2".
+            use_bias (bool): Use bias. Default is True.
+            kernel_regularizer: Kernel regularization. Default is None.
+            bias_regularizer: Bias regularization. Default is None.
+            activity_regularizer: Activity regularization. Default is None.
+            kernel_constraint: Kernel constraints. Default is None.
+            bias_constraint: Bias constraints. Default is None.
+            kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
+            bias_initializer: Initializer for bias. Default is 'zeros'.
+        """
+        super(MultiHeadGATV2Layer, self).__init__(**kwargs)
         # Changes in keras serialization behaviour for activations in 3.0.2.
         # Keep string at least for default. Also renames to prevent clashes with keras leaky_relu.
if activation in ["kgcnn>leaky_relu", "kgcnn>leaky_relu2"]: activation = {"class_name": "function", "config": "kgcnn>leaky_relu2"} self.num_heads = num_heads self.concat_heads = concat_heads + self.use_edge_features = use_edge_features + self.use_final_activation = use_final_activation + self.has_self_loops = has_self_loops + self.units = int(units) + self.normalize_softmax = normalize_softmax + self.use_bias = use_bias + kernel_args = {"kernel_regularizer": kernel_regularizer, + "activity_regularizer": activity_regularizer, "bias_regularizer": bias_regularizer, + "kernel_constraint": kernel_constraint, "bias_constraint": bias_constraint, + "kernel_initializer": kernel_initializer, "bias_initializer": bias_initializer} self.head_layers = [] for _ in range(num_heads): - lay_linear = Dense(units, activation=activation, use_bias=use_bias) - lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias) - lay_alpha = Dense(1, activation='linear', use_bias=False) + lay_linear = Dense(units, activation=activation, use_bias=use_bias, **kernel_args) + lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias, **kernel_args) + lay_alpha = Dense(1, activation='linear', use_bias=False, **kernel_args) self.head_layers.append((lay_linear, lay_alpha_activation, lay_alpha)) self.lay_concat_alphas = Concatenate(axis=-2) - self.lay_concat_embeddings = Concatenate(axis=-2) - self.lay_pool_attention = AggregateLocalEdgesAttention() - # self.lay_pool = AggregateLocalEdges() + + # self.lay_linear_trafo = Dense(units, activation="linear", use_bias=use_bias, **kernel_args) + # self.lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias, **kernel_args) + # self.lay_alpha = Dense(1, activation="linear", use_bias=False, **kernel_args) + self.lay_gather_in = GatherNodesIngoing() + self.lay_gather_out = GatherNodesOutgoing() + self.lay_concat = Concatenate(axis=-1) + self.lay_pool_attention = AggregateLocalEdgesAttention(normalize_softmax=normalize_softmax) + if self.use_final_activation: + self.lay_final_activ = Activation(activation=activation) if self.concat_heads: self.lay_combine_heads = Concatenate(axis=-1) else: self.lay_combine_heads = Average() - def __call__(self, inputs, **kwargs): + def build(self, input_shape): + """Build layer.""" + super(MultiHeadGATV2Layer, self).build(input_shape) + + def call(self, inputs, **kwargs): node, edge, edge_index = inputs # "a_ij" is a single-channel edge attention logits tensor. 
"a_ijs" is consequently the list which @@ -338,6 +388,16 @@ def __call__(self, inputs, **kwargs): def get_config(self): """Update layer config.""" config = super(MultiHeadGATV2Layer, self).get_config() + config.update({"use_edge_features": self.use_edge_features, "use_bias": self.use_bias, + "units": self.units, "has_self_loops": self.has_self_loops, + "normalize_softmax": self.normalize_softmax, + "use_final_activation": self.use_final_activation}) + if self.num_heads > 0: + conf_sub = self.head_layers[0][0].get_config() + for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint", + "bias_constraint", "kernel_initializer", "bias_initializer", "activation"]: + if x in conf_sub: + config.update({x: conf_sub[x]}) config.update({ 'num_heads': self.num_heads, 'concat_heads': self.concat_heads diff --git a/kgcnn/layers/conv.py b/kgcnn/layers/conv.py index 13456d43..932ce727 100644 --- a/kgcnn/layers/conv.py +++ b/kgcnn/layers/conv.py @@ -181,7 +181,6 @@ def call(self, inputs, **kwargs): Returns: Tensor: Updated node features. """ - # print(inputs) node, edge, disjoint_indices = inputs x = self.lay_dense1(edge, **kwargs) x = self.lay_dense2(x, **kwargs) diff --git a/kgcnn/literature/AttentiveFP/_make.py b/kgcnn/literature/AttentiveFP/_make.py index f6d02da0..886cab8d 100644 --- a/kgcnn/literature/AttentiveFP/_make.py +++ b/kgcnn/literature/AttentiveFP/_make.py @@ -119,7 +119,7 @@ def make_model(inputs: list = None, model_inputs, input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, - mask_assignment=[0, 0, 1], + mask_assignment=[0, 1, 1], index_assignment=[None, None, 0] ) diff --git a/kgcnn/literature/DimeNetPP/_layers.py b/kgcnn/literature/DimeNetPP/_layers.py index f601c5d5..15e52886 100644 --- a/kgcnn/literature/DimeNetPP/_layers.py +++ b/kgcnn/literature/DimeNetPP/_layers.py @@ -129,7 +129,7 @@ def call(self, inputs, **kwargs): # Transform via 2D spherical basis sbf = self.dense_sbf1(sbf, **kwargs) sbf = self.dense_sbf2(sbf, **kwargs) - x_kj = self.lay_mult1([x_kj, sbf], **kwargs) + x_kj = self.lay_mult2([x_kj, sbf], **kwargs) # Aggregate interactions and up-project embeddings x_kj = self.lay_pool([rbf, x_kj, id_expand], **kwargs) diff --git a/kgcnn/literature/HamNet/_layers.py b/kgcnn/literature/HamNet/_layers.py index 496b5ccd..8111a79d 100644 --- a/kgcnn/literature/HamNet/_layers.py +++ b/kgcnn/literature/HamNet/_layers.py @@ -530,7 +530,7 @@ def call(self, inputs, **kwargs): q_u_ftr, q_v_ftr = self.gather_p([q_ftr, edi], **kwargs) p_u_ftr, p_v_ftr = self.gather_q([p_ftr, edi], **kwargs) p_uv_ftr = self.lazy_sub_p([p_v_ftr, p_u_ftr], **kwargs) - q_uv_ftr = self.lazy_sub_p([q_v_ftr, q_u_ftr], **kwargs) + q_uv_ftr = self.lazy_sub_q([q_v_ftr, q_u_ftr], **kwargs) attend_ftr = self.dense_attend(hv_v_ftr, **kwargs) diff --git a/kgcnn/literature/MAT/_layers.py b/kgcnn/literature/MAT/_layers.py index feee25a0..9b2d2af9 100644 --- a/kgcnn/literature/MAT/_layers.py +++ b/kgcnn/literature/MAT/_layers.py @@ -40,6 +40,7 @@ class MATDistanceMatrix(ks.layers.Layer): def __init__(self, trafo: Union[str, None] = "exp", **kwargs): super(MATDistanceMatrix, self).__init__(**kwargs) self.trafo = trafo + # self._softmax = ks.layers.Softmax(axis=2) if self.trafo not in [None, "exp", "softmax"]: raise ValueError("`trafo` must be in [None, 'exp', 'softmax']") diff --git a/kgcnn/literature/MEGAN/_make.py b/kgcnn/literature/MEGAN/_make.py index 81d516ac..279303a9 100644 --- a/kgcnn/literature/MEGAN/_make.py +++ b/kgcnn/literature/MEGAN/_make.py @@ 
-1,3 +1,4 @@ +import keras as ks from ._model import MEGAN from kgcnn.models.utils import update_model_kwargs from kgcnn.layers.modules import Input diff --git a/kgcnn/literature/MXMNet/_layers.py b/kgcnn/literature/MXMNet/_layers.py index e424a1d7..8b6bc14f 100644 --- a/kgcnn/literature/MXMNet/_layers.py +++ b/kgcnn/literature/MXMNet/_layers.py @@ -119,7 +119,7 @@ def __init__(self, units: int = 64, output_units: int = 1, activation: str = "sw self.lin_rbf_out = Dense(self.dim, use_bias=False, activation="linear") - self.h_mlp = GraphMLP(self.dim, activation=activation) + # Fix for kgcnn==4.0.1: removed overwrite mlp here. Should not change model but prevents unused layers. self.y_mlp = GraphMLP([self.dim, self.dim, self.dim], activation=activation) self.y_W = Dense(self.output_dim, activation="linear", diff --git a/training/hyper/hyper_esol.py b/training/hyper/hyper_esol.py index 8a140c00..79945ec7 100644 --- a/training/hyper/hyper_esol.py +++ b/training/hyper/hyper_esol.py @@ -478,11 +478,13 @@ "config": { "name": "PAiNN", "inputs": [ - {"shape": [None], "name": "node_number", "dtype": "int64", "ragged": True}, - {"shape": [None, 3], "name": "node_coordinates", "dtype": "float32", "ragged": True}, - {"shape": [None, 2], "name": "range_indices", "dtype": "int64", "ragged": True} + {"shape": [None], "name": "node_number", "dtype": "int64"}, + {"shape": [None, 3], "name": "node_coordinates", "dtype": "float32"}, + {"shape": [None, 2], "name": "range_indices", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_ranges", "dtype": "int64"} ], - "input_tensor_type": "ragged", + "input_tensor_type": "padded", "cast_disjoint_kwargs": {}, "input_node_embedding": {"input_dim": 95, "output_dim": 128}, "bessel_basis": {"num_radial": 20, "cutoff": 5.0, "envelope_exponent": 5}, @@ -522,7 +524,9 @@ "config": {}, "methods": [ {"set_attributes": {"add_hydrogen": True}}, - {"map_list": {"method": "set_range", "max_distance": 3, "max_neighbours": 10000}} + {"map_list": {"method": "set_range", "max_distance": 3, "max_neighbours": 10000}}, + {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", + "count_edges": "range_indices"}}, ] }, "data_unit": "mol/L" diff --git a/training/hyper/hyper_mp_jdft2d.py b/training/hyper/hyper_mp_jdft2d.py index f11ea109..7b358b74 100644 --- a/training/hyper/hyper_mp_jdft2d.py +++ b/training/hyper/hyper_mp_jdft2d.py @@ -153,16 +153,18 @@ "config": { 'name': 'CGCNN', 'inputs': [ - {'shape': (None,), 'name': 'node_number', 'dtype': 'int64', 'ragged': True}, - {'shape': (None, 3), 'name': 'node_frac_coordinates', 'dtype': 'float64', 'ragged': True}, - {'shape': (None, 2), 'name': 'range_indices', 'dtype': 'int64', 'ragged': True}, - {'shape': (None, 3), 'name': 'range_image', 'dtype': 'float32', 'ragged': True}, - {'shape': (3, 3), 'name': 'graph_lattice', 'dtype': 'float64', 'ragged': False}, + {'shape': (None,), 'name': 'node_number', 'dtype': 'int64'}, + {'shape': (None, 3), 'name': 'node_frac_coordinates', 'dtype': 'float64'}, + {'shape': (None, 2), 'name': 'range_indices', 'dtype': 'int64'}, + {'shape': (None, 3), 'name': 'range_image', 'dtype': 'float32'}, + {'shape': (3, 3), 'name': 'graph_lattice', 'dtype': 'float64'}, # For `representation="asu"`: # {'shape': (None, 1), 'name': 'multiplicities', 'dtype': 'float32', 'ragged': True}, # {'shape': (None, 4, 4), 'name': 'symmops', 'dtype': 'float64', 'ragged': True}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), 
"name": "total_ranges", "dtype": "int64"} ], - "input_tensor_type": "ragged", + "input_tensor_type": "padded", 'input_node_embedding': {'input_dim': 95, 'output_dim': 64}, 'representation': 'unit', # None, 'asu' or 'unit' 'expand_distance': True, @@ -205,15 +207,18 @@ "config": {"with_std": True, "with_mean": True, "copy": True} }, }, - "data": { - "dataset": { + "dataset": { "class_name": "MatProjectJdft2dDataset", "module_name": "kgcnn.data.datasets.MatProjectJdft2dDataset", "config": {}, "methods": [ - {"map_list": {"method": "set_range_periodic", "max_distance": 6.0}} + {"map_list": {"method": "set_range_periodic", "max_distance": 6.0}}, + {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", + "count_edges": "range_indices", "count_nodes": "node_number", + "total_nodes": "total_nodes"}}, ] - }, + }, + "data": { "data_unit": "meV/atom" }, "info": { diff --git a/training/train_graph.py b/training/train_graph.py index 2ffdc7fc..0069e107 100644 --- a/training/train_graph.py +++ b/training/train_graph.py @@ -153,7 +153,10 @@ # Model summary model.summary() print(" Compiled with jit: %s" % model._jit_compile) # noqa - print(" Model is built: %s" % all([layer.built for layer in model._flatten_layers()])) # noqa + print(" Model is built: %s, with unbuilt: %s" % ( + all([layer.built for layer in model._flatten_layers()]), # noqa + [layer.name for layer in model._flatten_layers() if not layer.built] + )) # Run keras model-fit and take time for training. start = time.time()