Skip to content

Commit

Permalink
refactor layers
Browse files Browse the repository at this point in the history
  • Loading branch information
aimat committed May 15, 2021
1 parent 82512e8 commit 027d133
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
17 changes: 12 additions & 5 deletions kgcnn/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

class GraphBaseLayer(tf.keras.layers.Layer):
"""
Base layer of graph layers used in kgcnn that holds some additional information about the graph, which can
improve performance if set differently. Also input type check to support different tensor in- and output.
Base layer for graph layers used in kgcnn that holds some additional information about the graph, which can
improve performance, if set differently. Also input type check to support different tensor in- and output.
Args:
node_indexing (str): Indices referring to 'sample' or to the continuous 'batch'.
Expand All @@ -17,13 +17,14 @@ class GraphBaseLayer(tf.keras.layers.Layer):
ragged_validate (bool): Whether to validate ragged tensor. Default is False.
is_sorted (bool): If the edge indices are sorted for first ingoing index. Default is False.
has_unconnected (bool): If unconnected nodes are allowed. Default is True.
is_directed (bool): If the graph can be considered directed.
"""

def __init__(self,
node_indexing="sample",
partition_type="row_length",
input_tensor_type="ragged",
output_tensor_type="ragged",
output_tensor_type=None,
ragged_validate=False,
is_sorted=False,
has_unconnected=True,
Expand All @@ -36,7 +37,10 @@ def __init__(self,
self.node_indexing = node_indexing
self.partition_type = partition_type
self.input_tensor_type = input_tensor_type
self.output_tensor_type = output_tensor_type
if output_tensor_type is None:
self.output_tensor_type = input_tensor_type
else:
self.output_tensor_type = output_tensor_type
self.ragged_validate = ragged_validate
self.is_sorted = is_sorted
self.has_unconnected = has_unconnected
Expand All @@ -45,7 +49,7 @@ def __init__(self,
self._tensor_input_type_implemented = ["ragged", "values_partition", "disjoint",
"tensor", "RaggedTensor", "Tensor"]


self._tensor_input_type_found = []
self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(self.input_tensor_type,
self._tensor_input_type_implemented,
self.node_indexing)
Expand All @@ -63,3 +67,6 @@ def get_config(self):
})


def _kgcnn_map_input(self, inputs, num_input):

return input
2 changes: 1 addition & 1 deletion kgcnn/layers/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class GatherNodes(ks.layers.Layer):
"""
Gather nodes by edge_indices. Index list must match node tensor.
Gather nodes by node indices. Index list must match node tensor. An edge is defined by index tuple (i,j).
If graphs edge_indices were in 'sample' mode, the edge_indices must be corrected for disjoint graphs.
Expand Down

0 comments on commit 027d133

Please sign in to comment.