diff --git a/.gitignore b/.gitignore index 81e4628baa7a..1b12f3075687 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ coverage.xml .vscode .idea .venv +venv/ *.out *.pt *.onnx diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fcddc5443be..e82d846c6a01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `GraphMaskExplainer` with models with multiple heads ([#10154](https://github.com/pyg-team/pytorch_geometric/pull/10154)) - Fixed `_recursive_config()` for `torch.nn.ModuleList` and `torch.nn.ModuleDict` ([#10124](https://github.com/pyg-team/pytorch_geometric/pull/10124), [#10129](https://github.com/pyg-team/pytorch_geometric/pull/10129)) - Fixed the `k_hop_subgraph()` method for directed graphs ([#9756](https://github.com/pyg-team/pytorch_geometric/pull/9756)) - Fixed `utils.group_cat` concatenating dimension ([#9766](https://github.com/pyg-team/pytorch_geometric/pull/9766)) diff --git a/examples/explain/captum_explainer.py b/examples/explain/captum_explainer.py index 7c582591ef4e..2a6995333600 100644 --- a/examples/explain/captum_explainer.py +++ b/examples/explain/captum_explainer.py @@ -7,9 +7,8 @@ from torch_geometric.explain import CaptumExplainer, Explainer from torch_geometric.nn import GCNConv -dataset = 'Cora' -path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') -dataset = Planetoid(path, dataset) +path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Cora') +dataset = Planetoid(path, 'Cora') data = dataset[0] @@ -26,7 +25,13 @@ def forward(self, x, edge_index): return F.log_softmax(x, dim=1) -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.cuda.is_available(): + device = torch.device('cuda') +elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + device = torch.device('mps') +else: + device = torch.device('cpu') + model = GCN().to(device) data = 
data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) diff --git a/examples/explain/captum_explainer_hetero_link.py b/examples/explain/captum_explainer_hetero_link.py index 3523d07ba172..35d014ec4568 100644 --- a/examples/explain/captum_explainer_hetero_link.py +++ b/examples/explain/captum_explainer_hetero_link.py @@ -9,7 +9,12 @@ from torch_geometric.explain import CaptumExplainer, Explainer from torch_geometric.nn import SAGEConv, to_hetero -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.cuda.is_available(): + device = torch.device('cuda') +elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + device = torch.device('mps') +else: + device = torch.device('cpu') path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens') dataset = MovieLens(path, model_name='all-MiniLM-L6-v2') diff --git a/examples/explain/gnn_explainer.py b/examples/explain/gnn_explainer.py index 22ffa2462778..3256539a959d 100644 --- a/examples/explain/gnn_explainer.py +++ b/examples/explain/gnn_explainer.py @@ -7,9 +7,8 @@ from torch_geometric.explain import Explainer, GNNExplainer from torch_geometric.nn import GCNConv -dataset = 'Cora' -path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') -dataset = Planetoid(path, dataset) +path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Cora') +dataset = Planetoid(path, 'Cora') data = dataset[0] @@ -26,12 +25,19 @@ def forward(self, x, edge_index): return F.log_softmax(x, dim=1) -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.cuda.is_available(): + device = torch.device('cuda') +elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + device = torch.device('mps') +else: + device = torch.device('cpu') + model = GCN().to(device) data = data.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) +epochs = 200 -for epoch 
in range(1, 201): +for epoch in range(1, epochs + 1): model.train() optimizer.zero_grad() out = model(data.x, data.edge_index) @@ -41,7 +47,7 @@ def forward(self, x, edge_index): explainer = Explainer( model=model, - algorithm=GNNExplainer(epochs=200), + algorithm=GNNExplainer(epochs=epochs), explanation_type='model', node_mask_type='attributes', edge_mask_type='object', diff --git a/examples/explain/gnn_explainer_ba_shapes.py b/examples/explain/gnn_explainer_ba_shapes.py index d9c82a3cfae5..f6c4ea321379 100644 --- a/examples/explain/gnn_explainer_ba_shapes.py +++ b/examples/explain/gnn_explainer_ba_shapes.py @@ -22,7 +22,13 @@ idx = torch.arange(data.num_nodes) train_idx, test_idx = train_test_split(idx, train_size=0.8, stratify=data.y) -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.cuda.is_available(): + device = torch.device('cuda') +elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + device = torch.device('mps') +else: + device = torch.device('cpu') + data = data.to(device) model = GCN(data.num_node_features, hidden_channels=20, num_layers=3, out_channels=dataset.num_classes).to(device) diff --git a/examples/explain/gnn_explainer_link_pred.py b/examples/explain/gnn_explainer_link_pred.py index 080a588801f7..b466e34750e5 100644 --- a/examples/explain/gnn_explainer_link_pred.py +++ b/examples/explain/gnn_explainer_link_pred.py @@ -16,15 +16,15 @@ else: device = torch.device('cpu') -dataset = 'Cora' -path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') +path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Cora') transform = T.Compose([ T.NormalizeFeatures(), T.ToDevice(device), T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True), ]) -dataset = Planetoid(path, dataset, transform=transform) +dataset = Planetoid(path, 'Cora', transform=transform) train_data, val_data, test_data = dataset[0] +epochs = 200 class GCN(torch.nn.Module): @@ -70,7 +70,7 @@ def 
test(data): return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy()) -for epoch in range(1, 201): +for epoch in range(1, epochs + 1): loss = train() if epoch % 20 == 0: val_auc = test(val_data) @@ -90,7 +90,7 @@ def test(data): explainer = Explainer( model=model, explanation_type='model', - algorithm=GNNExplainer(epochs=200), + algorithm=GNNExplainer(epochs=epochs), node_mask_type='attributes', edge_mask_type='object', model_config=model_config, @@ -109,7 +109,7 @@ def test(data): explainer = Explainer( model=model, explanation_type='phenomenon', - algorithm=GNNExplainer(epochs=200), + algorithm=GNNExplainer(epochs=epochs), node_mask_type='attributes', edge_mask_type='object', model_config=model_config, diff --git a/examples/explain/graphmask_explainer.py b/examples/explain/graphmask_explainer.py index 88eb8dbe9c4c..fd58a16c5a04 100644 --- a/examples/explain/graphmask_explainer.py +++ b/examples/explain/graphmask_explainer.py @@ -7,9 +7,8 @@ from torch_geometric.explain import Explainer, GraphMaskExplainer from torch_geometric.nn import GATConv, GCNConv -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid') +device = 'cuda' if torch.cuda.is_available() else 'cpu' +path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Cora') dataset = Planetoid(path, name='Cora') data = dataset[0].to(device) diff --git a/test/explain/algorithm/test_graphmask_explainer.py b/test/explain/algorithm/test_graphmask_explainer.py index 0bb7eacc0767..0cdc1b02a812 100644 --- a/test/explain/algorithm/test_graphmask_explainer.py +++ b/test/explain/algorithm/test_graphmask_explainer.py @@ -9,10 +9,10 @@ ModelReturnType, ModelTaskLevel, ) -from torch_geometric.nn import GCNConv, global_add_pool +from torch_geometric.nn import GATConv, global_add_pool -class GCN(torch.nn.Module): +class GAT(torch.nn.Module): def __init__(self, model_config: ModelConfig): super().__init__() 
self.model_config = model_config @@ -22,8 +22,8 @@ def __init__(self, model_config: ModelConfig): else: out_channels = 1 - self.conv1 = GCNConv(3, 16) - self.conv2 = GCNConv(16, out_channels) + self.conv1 = GATConv(3, 16, heads=2) + self.conv2 = GATConv(16 * 2, out_channels, heads=1) def forward(self, x, edge_index, batch=None, edge_label_index=None): x = self.conv1(x, edge_index).relu() @@ -110,7 +110,7 @@ def test_graph_mask_explainer_binary_classification( return_type=return_type, ) - model = GCN(model_config) + model = GAT(model_config) target = None if explanation_type == 'phenomenon': @@ -162,7 +162,7 @@ def test_graph_mask_explainer_multiclass_classification( return_type=return_type, ) - model = GCN(model_config) + model = GAT(model_config) target = None if explanation_type == 'phenomenon': @@ -207,7 +207,7 @@ def test_graph_mask_explainer_regression( task_level=task_level, ) - model = GCN(model_config) + model = GAT(model_config) target = None if explanation_type == 'phenomenon': diff --git a/torch_geometric/explain/algorithm/gnn_explainer.py b/torch_geometric/explain/algorithm/gnn_explainer.py index 2aeb1c5a50cc..53aa067d5c9c 100644 --- a/torch_geometric/explain/algorithm/gnn_explainer.py +++ b/torch_geometric/explain/algorithm/gnn_explainer.py @@ -197,6 +197,8 @@ def _train( index: Optional[Union[int, Tensor]] = None, **kwargs, ) -> None: + # Initialize the node and edge masks based on the node mask + # type and edge mask type defined in the explainer config. ... 
def _train( @@ -654,7 +656,8 @@ def explain_node( return self._convert_output(explanation, edge_index, index=node_idx, x=x) - def _convert_output(self, explanation, edge_index, index=None, x=None): + def _convert_output(self, explanation, edge_index, index=None, + x=None) -> Tuple[Tensor, Tensor]: node_mask = explanation.get('node_mask') edge_mask = explanation.get('edge_mask') diff --git a/torch_geometric/explain/algorithm/graphmask_explainer.py b/torch_geometric/explain/algorithm/graphmask_explainer.py index 998d48766c5b..e79887db1a10 100644 --- a/torch_geometric/explain/algorithm/graphmask_explainer.py +++ b/torch_geometric/explain/algorithm/graphmask_explainer.py @@ -14,6 +14,13 @@ def explain_message(self, out: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor: + orig_size = out.size() + if out.ndim == 3: + out = out.view(out.size(0), -1) + if x_i.ndim == 3: + x_i = x_i.view(x_i.size(0), -1) + if x_j.ndim == 3: + x_j = x_j.view(x_j.size(0), -1) basis_messages = F.layer_norm(out, (out.size(-1), )).relu() if getattr(self, 'message_scale', None) is not None: @@ -33,6 +40,8 @@ def explain_message(self, out: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor: self.latest_source_embeddings = x_j self.latest_target_embeddings = x_i + if len(orig_size) == 3: + basis_messages = basis_messages.view(orig_size[0], orig_size[1], -1) return basis_messages @@ -179,13 +188,7 @@ def _hard_concrete( return clipped_s, penalty - def _set_masks( - self, - i_dim: List[int], - j_dim: List[int], - h_dim: List[int], - x: Tensor, - ): + def _set_masks(self, x: Tensor): r"""Sets the node masks and edge masks.""" (num_nodes, num_feat), std, device = x.size(), 0.1, x.device self.feat_mask_type = self.explainer_config.node_mask_type @@ -200,12 +203,21 @@ def _set_masks( self.node_feat_mask = torch.nn.Parameter( torch.randn(1, num_feat, device=device) * std) + def _set_trainable( + self, + i_dims: List[int], + j_dims: List[int], + h_dims: List[int], + device: torch.device, + ): baselines, self.gates, 
full_biases = [], torch.nn.ModuleList(), [] + zipped = zip(i_dims, j_dims, h_dims) - for v_dim, m_dim, h_dim in zip(i_dim, j_dim, h_dim): + for items in zipped: + v_dim, m_dim, h_dim = items self.transform, self.layer_norm = [], [] input_dims = [v_dim, m_dim, v_dim] - for _, input_dim in enumerate(input_dims): + for input_dim in input_dims: self.transform.append( Linear(input_dim, h_dim, bias=False).to(device)) self.layer_norm.append(LayerNorm(h_dim).to(device)) @@ -363,9 +375,16 @@ def _train_explainer( for module in model.modules(): if isinstance(module, MessagePassing): input_dims.append(module.in_channels) - output_dims.append(module.out_channels) + if hasattr(module, 'heads'): + heads = module.heads + else: + heads = 1 + # If multihead attention is used, the output channels are + # multiplied by the number of heads + output_dims.append(module.out_channels * heads) - self._set_masks(input_dims, output_dims, output_dims, x) + self._set_masks(x) + self._set_trainable(input_dims, output_dims, output_dims, x.device) optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) @@ -385,7 +404,7 @@ def _train_explainer( f'Train explainer for graph {index} with layer ' f'{layer}') self._enable_layer(layer) - for epoch in range(self.epochs): + for _ in range(self.epochs): with torch.no_grad(): model(x, edge_index, **kwargs) gates, total_penalty = [], 0 @@ -405,19 +424,20 @@ def _train_explainer( for i in range(self.num_layers): output = self.full_biases[i] for j in range(len(gate_input)): + input = gate_input[j][i] + if input.ndim == 3: + input = input.view(input.size(0), -1) try: - partial = self.gates[i * 4][j](gate_input[j][i]) + partial = self.gates[i * 4][j](input) except Exception: try: - self._set_masks(output_dims, output_dims, - output_dims, x) - partial = self.gates[i * 4][j]( - gate_input[j][i]) + self._set_trainable(output_dims, output_dims, + output_dims, x.device) + partial = self.gates[i * 4][j](input) except Exception: - self._set_masks(input_dims, 
input_dims, - output_dims, x) - partial = self.gates[i * 4][j]( - gate_input[j][i]) + self._set_trainable(input_dims, input_dims, + output_dims, x.device) + partial = self.gates[i * 4][j](input) result = self.gates[(i * 4) + 1][j](partial) output = output + result relu_output = self.gates[(i * 4) + 2](output / @@ -511,7 +531,7 @@ def _explain( pbar = tqdm(total=self.num_layers) for i in range(self.num_layers): if self.log: - pbar.set_description("Explain") + pbar.set_description(f"Explain layer {i}") output = self.full_biases[i] for j in range(len(gate_input)): partial = self.gates[i * 4][j](gate_input[j][i]) @@ -522,6 +542,10 @@ def _explain( 3](relu_output).squeeze(dim=-1) sampling_weights, _ = self._hard_concrete( sampling_weights, training=False) + # TODO: This is a hack to make the explainer work for GAT, + # where the weights are still 2-D after the squeeze above + if sampling_weights.ndim == 2: + sampling_weights = sampling_weights.mean(dim=-1) if i == 0: edge_weight = sampling_weights else: diff --git a/torch_geometric/explain/config.py b/torch_geometric/explain/config.py index 21f16fdf9b12..05e123809276 100644 --- a/torch_geometric/explain/config.py +++ b/torch_geometric/explain/config.py @@ -77,7 +77,7 @@ class ExplainerConfig(CastMixin): - :obj:`"attributes"`: Will mask each feature across all nodes. edge_mask_type (MaskType or str, optional): The type of mask to apply - on edges. Has the sample possible values as :obj:`node_mask_type`. + on edges. Has the same values as :obj:`node_mask_type`. (default: :obj:`None`) """ explanation_type: ExplanationType @@ -193,7 +193,7 @@ class ThresholdConfig(CastMixin): - :obj:`"topk_hard"`: Same as :obj:`"topk"` but values are set to :obj:`1` for all elements which are kept. - value (int or float, optional): The value to use when thresholding. + value (int or float, optional): The value to use for thresholding.
(default: :obj:`None`) """ type: ThresholdType diff --git a/torch_geometric/explain/explainer.py b/torch_geometric/explain/explainer.py index 118cc51b4833..a64f1a4aa2cf 100644 --- a/torch_geometric/explain/explainer.py +++ b/torch_geometric/explain/explainer.py @@ -59,7 +59,7 @@ class Explainer: - :obj:`"attributes"`: Will mask each feature across all nodes. edge_mask_type (MaskType or str, optional): The type of mask to apply - on edges. Has the sample possible values as :obj:`node_mask_type`. + on edges. Has the same values as :obj:`node_mask_type`. (default: :obj:`None`) threshold_config (ThresholdConfig, optional): The threshold configuration. diff --git a/torch_geometric/explain/explanation.py b/torch_geometric/explain/explanation.py index d9702ce8abd2..42f844bd5f4d 100644 --- a/torch_geometric/explain/explanation.py +++ b/torch_geometric/explain/explanation.py @@ -110,7 +110,7 @@ def threshold( *args, **kwargs, ) -> Union['Explanation', 'HeteroExplanation']: - """Thresholds the explanation masks according to the thresholding + """Thresholds the explanation masks according to the threshold method. Args: @@ -206,7 +206,7 @@ def visualize_feature_importance( feat_labels: Optional[List[str]] = None, top_k: Optional[int] = None, ): - r"""Creates a bar plot of the node feature importances by summing up + r"""Creates a bar plot of the node feature importance by summing up the node mask across all nodes. Args: @@ -329,7 +329,7 @@ def visualize_feature_importance( feat_labels: Optional[Dict[NodeType, List[str]]] = None, top_k: Optional[int] = None, ): - r"""Creates a bar plot of the node feature importances by summing up + r"""Creates a bar plot of the node feature importance by summing up node masks across all nodes for each node type. Args: