Explainer minor improvements and bug fixes #10154

Status: Open. Wants to merge 14 commits into `master`.
1 change: 1 addition & 0 deletions .gitignore
@@ -21,6 +21,7 @@ coverage.xml
.vscode
.idea
.venv
+venv/
*.out
*.pt
*.onnx
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -65,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

+- Fixed `GraphMaskExplainer` with models with multiple heads ([#10154](https://github.com/pyg-team/pytorch_geometric/pull/10154))
- Fixed `_recursive_config()` for `torch.nn.ModuleList` and `torch.nn.ModuleDict` ([#10124](https://github.com/pyg-team/pytorch_geometric/pull/10124), [#10129](https://github.com/pyg-team/pytorch_geometric/pull/10129))
- Fixed the `k_hop_subgraph()` method for directed graphs ([#9756](https://github.com/pyg-team/pytorch_geometric/pull/9756))
- Fixed `utils.group_cat` concatenating dimension ([#9766](https://github.com/pyg-team/pytorch_geometric/pull/9766))
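To make the headline fix concrete: a `GraphMaskExplainer` driving a two-head `GATConv` model is the kind of setup that previously failed and that the updated tests below exercise. A minimal sketch (the graph data and hyperparameters are illustrative, not from this PR):

```python
import torch

from torch_geometric.explain import Explainer, GraphMaskExplainer
from torch_geometric.nn import GATConv


class GAT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GATConv(3, 16, heads=2)  # multi-head: emits 16 * 2 features
        self.conv2 = GATConv(16 * 2, 7, heads=1)

    def forward(self, x, edge_index):
        return self.conv2(self.conv1(x, edge_index).relu(), edge_index)


x = torch.randn(8, 3)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6],
                           [1, 2, 3, 4, 5, 6, 7]])

explainer = Explainer(
    model=GAT(),
    algorithm=GraphMaskExplainer(num_layers=2, epochs=5),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(mode='multiclass_classification', task_level='node',
                      return_type='raw'),
)
explanation = explainer(x, edge_index, index=0)  # crashed before this fix
```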
13 changes: 9 additions & 4 deletions examples/explain/captum_explainer.py
@@ -7,9 +7,8 @@
from torch_geometric.explain import CaptumExplainer, Explainer
from torch_geometric.nn import GCNConv

-dataset = 'Cora'
-path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
-dataset = Planetoid(path, dataset)
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Cora')
+dataset = Planetoid(path, 'Cora')
data = dataset[0]


@@ -26,7 +25,13 @@ def forward(self, x, edge_index):
        return F.log_softmax(x, dim=1)


-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')

model = GCN().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
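Since the same device-selection block now appears in several examples, readers adapting them could factor it into a helper. A small sketch under that assumption (the `pick_device` name is ours, not part of the PR; the `hasattr` guard covers PyTorch builds that predate `torch.backends.mps`):

```python
import torch


def pick_device() -> torch.device:
    # Prefer CUDA, then Apple's Metal (MPS) backend, then fall back to CPU.
    if torch.cuda.is_available():
        return torch.device('cuda')
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


device = pick_device()
```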
7 changes: 6 additions & 1 deletion examples/explain/captum_explainer_hetero_link.py
@@ -9,7 +9,12 @@
from torch_geometric.explain import CaptumExplainer, Explainer
from torch_geometric.nn import SAGEConv, to_hetero

-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens')
dataset = MovieLens(path, model_name='all-MiniLM-L6-v2')
18 changes: 12 additions & 6 deletions examples/explain/gnn_explainer.py
@@ -7,9 +7,8 @@
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.nn import GCNConv

-dataset = 'Cora'
-path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
-dataset = Planetoid(path, dataset)
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Cora')
+dataset = Planetoid(path, 'Cora')
data = dataset[0]


@@ -26,12 +25,19 @@ def forward(self, x, edge_index):
        return F.log_softmax(x, dim=1)


-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')

model = GCN().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
+epochs = 200

-for epoch in range(1, 201):
+for epoch in range(1, epochs + 1):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
@@ -41,7 +47,7 @@ def forward(self, x, edge_index):

explainer = Explainer(
    model=model,
-    algorithm=GNNExplainer(epochs=200),
+    algorithm=GNNExplainer(epochs=epochs),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
8 changes: 7 additions & 1 deletion examples/explain/gnn_explainer_ba_shapes.py
@@ -22,7 +22,13 @@
idx = torch.arange(data.num_nodes)
train_idx, test_idx = train_test_split(idx, train_size=0.8, stratify=data.y)

-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')

data = data.to(device)
model = GCN(data.num_node_features, hidden_channels=20, num_layers=3,
            out_channels=dataset.num_classes).to(device)
12 changes: 6 additions & 6 deletions examples/explain/gnn_explainer_link_pred.py
@@ -16,15 +16,15 @@
else:
    device = torch.device('cpu')

-dataset = 'Cora'
-path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Cora')
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True),
])
-dataset = Planetoid(path, dataset, transform=transform)
+dataset = Planetoid(path, 'Cora', transform=transform)
train_data, val_data, test_data = dataset[0]
+epochs = 200


class GCN(torch.nn.Module):
@@ -70,7 +70,7 @@ def test(data):
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())


-for epoch in range(1, 201):
+for epoch in range(1, epochs + 1):
    loss = train()
    if epoch % 20 == 0:
        val_auc = test(val_data)
@@ -90,7 +90,7 @@ def test(data):
explainer = Explainer(
    model=model,
    explanation_type='model',
-    algorithm=GNNExplainer(epochs=200),
+    algorithm=GNNExplainer(epochs=epochs),
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=model_config,
@@ -109,7 +109,7 @@ def test(data):
explainer = Explainer(
    model=model,
    explanation_type='phenomenon',
-    algorithm=GNNExplainer(epochs=200),
+    algorithm=GNNExplainer(epochs=epochs),
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=model_config,
5 changes: 2 additions & 3 deletions examples/explain/graphmask_explainer.py
@@ -7,9 +7,8 @@
from torch_geometric.explain import Explainer, GraphMaskExplainer
from torch_geometric.nn import GATConv, GCNConv

-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid')
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Cora')
dataset = Planetoid(path, name='Cora')
data = dataset[0].to(device)
14 changes: 7 additions & 7 deletions test/explain/algorithm/test_graphmask_explainer.py
@@ -9,10 +9,10 @@
    ModelReturnType,
    ModelTaskLevel,
)
-from torch_geometric.nn import GCNConv, global_add_pool
+from torch_geometric.nn import GATConv, global_add_pool


-class GCN(torch.nn.Module):
+class GAT(torch.nn.Module):
    def __init__(self, model_config: ModelConfig):
        super().__init__()
        self.model_config = model_config
@@ -22,8 +22,8 @@ def __init__(self, model_config: ModelConfig):
        else:
            out_channels = 1

-        self.conv1 = GCNConv(3, 16)
-        self.conv2 = GCNConv(16, out_channels)
+        self.conv1 = GATConv(3, 16, heads=2)
+        self.conv2 = GATConv(16 * 2, out_channels, heads=1)

    def forward(self, x, edge_index, batch=None, edge_label_index=None):
        x = self.conv1(x, edge_index).relu()
@@ -110,7 +110,7 @@ def test_graph_mask_explainer_binary_classification(
        return_type=return_type,
    )

-    model = GCN(model_config)
+    model = GAT(model_config)

    target = None
    if explanation_type == 'phenomenon':
@@ -162,7 +162,7 @@ def test_graph_mask_explainer_multiclass_classification(
        return_type=return_type,
    )

-    model = GCN(model_config)
+    model = GAT(model_config)

    target = None
    if explanation_type == 'phenomenon':
@@ -207,7 +207,7 @@ def test_graph_mask_explainer_regression(
        task_level=task_level,
    )

-    model = GCN(model_config)
+    model = GAT(model_config)

    target = None
    if explanation_type == 'phenomenon':
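Background for the switch from `GCNConv` to `GATConv` in these tests: with `heads=2` (and the default `concat=True`), `GATConv` concatenates the per-head outputs, so the layer's effective output width is `out_channels * heads`. A quick shape check (toy graph, our own numbers):

```python
import torch
from torch_geometric.nn import GATConv

conv1 = GATConv(3, 16, heads=2)      # per-head width 16, concatenated to 32
conv2 = GATConv(16 * 2, 7, heads=1)  # must consume the concatenated width

x = torch.randn(5, 3)  # 5 nodes, 3 features
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])

h = conv1(x, edge_index)
print(h.shape)                     # torch.Size([5, 32])
print(conv2(h, edge_index).shape)  # torch.Size([5, 7])
```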
5 changes: 4 additions & 1 deletion torch_geometric/explain/algorithm/gnn_explainer.py
@@ -197,6 +197,8 @@ def _train(
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> None:
+        # Initialize the node and edge masks based on the node mask
+        # type and edge mask type defined in the explainer config.
        ...

    def _train(
@@ -654,7 +656,8 @@ def explain_node(
        return self._convert_output(explanation, edge_index, index=node_idx,
                                    x=x)

-    def _convert_output(self, explanation, edge_index, index=None, x=None):
+    def _convert_output(self, explanation, edge_index, index=None,
+                        x=None) -> Tuple[Tensor, Tensor]:
        node_mask = explanation.get('node_mask')
        edge_mask = explanation.get('edge_mask')
68 changes: 46 additions & 22 deletions torch_geometric/explain/algorithm/graphmask_explainer.py
@@ -14,6 +14,13 @@


def explain_message(self, out: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor:
+    orig_size = out.size()
+    if out.ndim == 3:
+        out = out.view(out.size(0), -1)
+    if x_i.ndim == 3:
+        x_i = x_i.view(x_i.size(0), -1)
+    if x_j.ndim == 3:
+        x_j = x_j.view(x_j.size(0), -1)
    basis_messages = F.layer_norm(out, (out.size(-1), )).relu()

    if getattr(self, 'message_scale', None) is not None:
@@ -33,6 +40,8 @@ def explain_message(self, out: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor:
    self.latest_source_embeddings = x_j
    self.latest_target_embeddings = x_i

+    if len(orig_size) == 3:
+        basis_messages = basis_messages.view(orig_size[0], orig_size[1], -1)
    return basis_messages
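The flatten-and-restore logic above exists because multi-head layers such as `GATConv` produce messages of shape `[num_edges, heads, channels]`, while the `F.layer_norm` call expects the 2-D `[num_edges, features]` view. A standalone sketch of the same idea (illustrative, not the patched function itself):

```python
import torch
import torch.nn.functional as F

# [num_edges, heads, channels], as from a 2-head GATConv:
out = torch.randn(10, 2, 16)

orig_size = out.size()
if out.ndim == 3:
    out = out.view(out.size(0), -1)  # flatten heads: [10, 32]

basis_messages = F.layer_norm(out, (out.size(-1), )).relu()

if len(orig_size) == 3:
    # Restore the per-head layout afterwards.
    basis_messages = basis_messages.view(orig_size[0], orig_size[1], -1)

print(basis_messages.shape)  # torch.Size([10, 2, 16])
```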


@@ -179,13 +188,7 @@ def _hard_concrete(

        return clipped_s, penalty

-    def _set_masks(
-        self,
-        i_dim: List[int],
-        j_dim: List[int],
-        h_dim: List[int],
-        x: Tensor,
-    ):
+    def _set_masks(self, x: Tensor):
        r"""Sets the node masks and edge masks."""
        (num_nodes, num_feat), std, device = x.size(), 0.1, x.device
        self.feat_mask_type = self.explainer_config.node_mask_type
@@ -200,12 +203,21 @@ def _set_masks(self, x: Tensor):
            self.node_feat_mask = torch.nn.Parameter(
                torch.randn(1, num_feat, device=device) * std)

+    def _set_trainable(
+        self,
+        i_dims: List[int],
+        j_dims: List[int],
+        h_dims: List[int],
+        device: torch.device,
+    ):
        baselines, self.gates, full_biases = [], torch.nn.ModuleList(), []
+        zipped = zip(i_dims, j_dims, h_dims)

-        for v_dim, m_dim, h_dim in zip(i_dim, j_dim, h_dim):
+        for items in zipped:
+            v_dim, m_dim, h_dim = items
            self.transform, self.layer_norm = [], []
            input_dims = [v_dim, m_dim, v_dim]
-            for _, input_dim in enumerate(input_dims):
+            for input_dim in input_dims:
                self.transform.append(
                    Linear(input_dim, h_dim, bias=False).to(device))
                self.layer_norm.append(LayerNorm(h_dim).to(device))
@@ -363,9 +375,16 @@ def _train_explainer(
        for module in model.modules():
            if isinstance(module, MessagePassing):
                input_dims.append(module.in_channels)
-                output_dims.append(module.out_channels)
+                if hasattr(module, 'heads'):
+                    heads = module.heads
+                else:
+                    heads = 1
+                # If multi-head attention is used, the output channels are
+                # multiplied by the number of heads.
+                output_dims.append(module.out_channels * heads)

-        self._set_masks(input_dims, output_dims, output_dims, x)
+        self._set_masks(x)
+        self._set_trainable(input_dims, output_dims, output_dims, x.device)

        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
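For intuition on the `heads` handling above: a module's `out_channels` attribute reports the per-head width, so a 2-head `GATConv` with `out_channels=16` actually emits 32 features. A hedged sketch of the dimension bookkeeping (the model is a toy stand-in):

```python
import torch
from torch_geometric.nn import GATConv, GCNConv, MessagePassing

model = torch.nn.ModuleList([GATConv(3, 16, heads=2), GCNConv(16 * 2, 7)])

input_dims, output_dims = [], []
for module in model.modules():
    if isinstance(module, MessagePassing):
        heads = getattr(module, 'heads', 1)  # non-attention convs carry no heads
        input_dims.append(module.in_channels)
        output_dims.append(module.out_channels * heads)

print(output_dims)  # [32, 7] -- not [16, 7]
```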

@@ -385,7 +404,7 @@ def _train_explainer(
                    f'Train explainer for graph {index} with layer '
                    f'{layer}')
            self._enable_layer(layer)
-            for epoch in range(self.epochs):
+            for _ in range(self.epochs):
                with torch.no_grad():
                    model(x, edge_index, **kwargs)
                gates, total_penalty = [], 0
@@ -405,19 +424,20 @@ def _train_explainer(
                for i in range(self.num_layers):
                    output = self.full_biases[i]
                    for j in range(len(gate_input)):
+                        input = gate_input[j][i]
+                        if input.ndim == 3:
+                            input = input.view(input.size(0), -1)
                        try:
-                            partial = self.gates[i * 4][j](gate_input[j][i])
+                            partial = self.gates[i * 4][j](input)
                        except Exception:
                            try:
-                                self._set_masks(output_dims, output_dims,
-                                                output_dims, x)
-                                partial = self.gates[i * 4][j](
-                                    gate_input[j][i])
+                                self._set_trainable(output_dims, output_dims,
+                                                    output_dims, x.device)
+                                partial = self.gates[i * 4][j](input)
                            except Exception:
-                                self._set_masks(input_dims, input_dims,
-                                                output_dims, x)
-                                partial = self.gates[i * 4][j](
-                                    gate_input[j][i])
+                                self._set_trainable(input_dims, input_dims,
+                                                    output_dims, x.device)
+                                partial = self.gates[i * 4][j](input)
                        result = self.gates[(i * 4) + 1][j](partial)
                        output = output + result
                    relu_output = self.gates[(i * 4) + 2](output /
@@ -511,7 +531,7 @@ def _explain(
            pbar = tqdm(total=self.num_layers)
        for i in range(self.num_layers):
            if self.log:
-                pbar.set_description("Explain")
+                pbar.set_description(f"Explain layer {i}")
            output = self.full_biases[i]
            for j in range(len(gate_input)):
                partial = self.gates[i * 4][j](gate_input[j][i])
@@ -522,6 +542,10 @@ def _explain(
                                              3](relu_output).squeeze(dim=-1)
                sampling_weights, _ = self._hard_concrete(
                    sampling_weights, training=False)
+                # TODO: This is a hack to make the explainer work for
+                # GAT, where the weights here have three dimensions.
+                if sampling_weights.ndim == 2:
+                    sampling_weights = sampling_weights.mean(dim=-1)
                if i == 0:
                    edge_weight = sampling_weights
                else:
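On the averaging step just above: for a multi-head layer, the hard-concrete gate can yield one weight per head for each edge, while the final explanation needs a single scalar per edge; taking the mean over the head dimension is the simple collapse chosen here. A toy illustration (values made up):

```python
import torch

# Hypothetical gate outputs for 4 edges and 2 heads:
sampling_weights = torch.tensor([[0.9, 0.7],
                                 [0.1, 0.3],
                                 [1.0, 0.0],
                                 [0.5, 0.5]])

if sampling_weights.ndim == 2:
    # Collapse the head dimension to one weight per edge.
    sampling_weights = sampling_weights.mean(dim=-1)

print(sampling_weights)  # tensor([0.8000, 0.2000, 0.5000, 0.5000])
```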