Add graph transformer #10313
Open · omarkhater wants to merge 143 commits into pyg-team:master from omarkhater:add_graph_transformer
Conversation
…tests
- Introduce GraphTransformer module:
- _readout() → global_mean_pool aggregation
- classifier → nn.Linear(hidden_dim → num_classes)
- Add parameterized pytest to verify output shape across:
* various num_graphs, num_nodes, feature_dims, num_classes
- Establishes TDD baseline for subsequent transformer layers
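A minimal sketch of the baseline described above, mirroring only the pieces named in this commit (the `_readout()` mean-pool aggregation and the linear classifier); names and defaults are illustrative, not the final merged code.

```python
import torch
from torch import nn
from torch_geometric.nn import global_mean_pool


class GraphTransformerBaseline(nn.Module):
    """Illustrative baseline: pool node features, then classify per graph."""

    def __init__(self, hidden_dim: int, num_classes: int):
        super().__init__()
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        # Aggregate node embeddings into one vector per graph.
        return global_mean_pool(x, batch)

    def forward(self, data) -> torch.Tensor:
        x = self._readout(data.x, data.batch)   # (num_graphs, hidden_dim)
        return self.classifier(x)               # (num_graphs, num_classes)
```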
feat: add super-node ([CLS] token) readout support
- Introduce `use_super_node` flag and `cls_token` parameter
- Implement `_add_cls_token` and branch in `_readout`
- Add `test_super_node_readout` to verify zero-bias behavior
- Introduce `node_feature_encoder` argument (default Identity)
- Wire encoder into forward via `_encode_nodes`
- Add `test_node_feature_encoder_identity` to verify scaling logic
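A small sketch of that hook, assuming the default is `nn.Identity()` so existing behaviour is untouched; `_encode_nodes` and the argument name follow the commit message.

```python
from typing import Optional

import torch
from torch import nn


class NodeFeatureEncoding(nn.Module):
    def __init__(self, node_feature_encoder: Optional[nn.Module] = None):
        super().__init__()
        # Default to a no-op so models without an encoder behave as before.
        self.node_feature_encoder = node_feature_encoder or nn.Identity()

    def _encode_nodes(self, x: torch.Tensor) -> torch.Tensor:
        # Users can plug in e.g. nn.Linear(in_dim, hidden_dim) here.
        return self.node_feature_encoder(x)
```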
Introduce the core Transformer-backbone scaffold without changing model behavior:
- Add `IdentityLayer(nn.Module)` that accepts `(x, batch)` and simply returns `x`.
- Create `EncoderLayers` container wrapping a `ModuleList` of `IdentityLayer` instances.
- Wire `self.encoder = EncoderLayers(num_encoder_layers)` into `GraphTransformer.__init__`.
- Update `forward()` to call `x = self.encoder(x, data.batch)` before readout.
- Preserve existing functionality: mean-pool readout, super-node, and node-feature encoder all still work.
- Add `test_transformer_block_identity` to verify that replacing one stub layer with a custom module (e.g. `AddOneLayer`) correctly affects the output shape.
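A sketch of that scaffold, matching the names in the bullets above; both classes are placeholders that keep behaviour unchanged.

```python
import torch
from torch import nn


class IdentityLayer(nn.Module):
    def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        return x  # stub: model behaviour is unchanged


class EncoderLayers(nn.Module):
    def __init__(self, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([IdentityLayer() for _ in range(num_layers)])

    def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, batch)
        return x
```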
- Test that gradients flow properly through the GraphTransformer model
- Test that GraphTransformer can be traced with TorchScript
…p stubs

This commit completes Cycle 4-A (RED → GREEN → REFACTOR): the codebase now contains a real, pluggable encoder skeleton ready for multi-head attention and feed-forward logic in the next mini-cycle, while all existing behaviour, TorchScript traceability, and test guarantees remain intact.

Added
- `torch_geometric.contrib.nn.layers.transformer.GraphTransformerEncoderLayer`
  - contains the two LayerNorm blocks that will precede MHA & FFN in the next iteration
  - forwards `x, batch` unchanged for now (keeps tests green)
- `GraphTransformerEncoder` — a thin wrapper that holds an arbitrary number of identical encoder layers and exposes `__len__`, `__getitem__`, and `__setitem__` so tests (and users) can hot-swap layers.

Changed
- `GraphTransformer`
  - builds either a single `GraphTransformerEncoderLayer` (`num_encoder_layers == 0`) or a full `GraphTransformerEncoder` stack.
  - keeps the existing public API and the super-node/mean-pool readouts.
- Test suite
  - rewired to use the new stack (`model.encoder[0]` instead of `model.encoder.layers[0]`); added `test_encoder_layer_type`, which asserts the first layer is a `GraphTransformerEncoderLayer`.

Removed
- Legacy scaffolding `IdentityLayer` and `EncoderLayers`, which were only placeholders in Cycle 3.
- Introduces configurable `num_heads`, `dropout`, `ffn_hidden_dim`, and `activation` params to `GraphTransformerEncoderLayer`
- Adds a placeholder `self.self_attn = nn.Identity()` so the interface is ready for real multi-head attention in a future cycle
- Implements the full Add & Norm ➜ FFN ➜ Add & Norm pattern, meaning the layer now changes the tensor values (tests confirm) while preserving shape
- Exposes a JIT-ignored static helper `_build_key_padding_mask` (returns None for now) to prepare for masking logic
- Updates `GraphTransformer` to forward the new constructor args.
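A hedged sketch of that pattern: the placeholder `nn.Identity()` attention and the constructor arguments follow the commit text, while defaults such as `4 * hidden_dim` for the FFN width are assumptions.

```python
import torch
from torch import nn


class EncoderLayerSkeleton(nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int = 4, dropout: float = 0.0,
                 ffn_hidden_dim: int = None, activation=nn.ReLU):
        super().__init__()
        ffn_hidden_dim = ffn_hidden_dim or 4 * hidden_dim  # assumed default
        # num_heads is kept for interface parity; unused while attention is a stub.
        self.self_attn = nn.Identity()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_hidden_dim),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(ffn_hidden_dim, hidden_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x + self.self_attn(x))  # Add & Norm around (stub) attention
        x = self.norm2(x + self.ffn(x))        # Add & Norm around the FFN
        return x
```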
… mask
- adds `nn.MultiheadAttention(batch_first=True)` to the encoder layer
- implements a per-graph `key_padding_mask` so distinct graphs don’t cross-talk
- updates tests to confirm the layer now transforms identical rows
- refactors shape helpers for readability
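Illustrative sketch of per-graph masked self-attention, assuming the flat node tensor is first padded to `(num_graphs, max_nodes, C)`; this shows the general technique, not this PR's exact helper.

```python
import torch
from torch import nn
from torch_geometric.utils import to_dense_batch

hidden_dim, num_heads = 64, 4
attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

x = torch.randn(7, hidden_dim)                 # 7 nodes across 2 graphs
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])

x_dense, node_mask = to_dense_batch(x, batch)  # (2, 4, 64), (2, 4); True = real node
key_padding_mask = ~node_mask                  # True marks padded positions to ignore
out, _ = attn(x_dense, x_dense, x_dense, key_padding_mask=key_padding_mask)
out_flat = out[node_mask]                      # back to the flat (7, 64) layout
```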
Changed the test to verify that the encoder affects outputs rather than assuming direct pooling:
- Use an Identity encoder for a simpler test
- Compare outputs with raw vs. scaled features
- Update the assertion to check that scaling changes the output
Introduces an optional degree_encoder hook to enrich node embeddings with structural information (in-/out-degree, or any custom degree-based feature). When provided, the encoder’s output is added to the raw (optionally pre-encoded) node features before the transformer stack, enabling degree awareness without altering existing workflows.
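A sketch of such a hook, assuming an embedding over clamped node degrees; the `degree_encoder` name and the additive combination follow the text above, the rest is illustrative.

```python
import torch
from torch import nn
from torch_geometric.utils import degree

max_degree, hidden_dim = 64, 32
degree_encoder = nn.Embedding(max_degree + 1, hidden_dim)

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
x = torch.randn(3, hidden_dim)

# Out-degree per node, clamped so rare high degrees share one embedding row.
deg = degree(edge_index[0], num_nodes=x.size(0)).long().clamp(max=max_degree)
x = x + degree_encoder(deg)  # add structural signal to the node features
```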
Extract “sum up extra encoders” into a private helper to keep forward tidy
Ensures logits differ when a structural mask is supplied.
Keeps all earlier behaviour while unlocking arbitrary structural masks.
Add attn_mask to AddOneLayer
- `_prepend_cls_token_flat` adds one learnable row per graph and returns the updated batch vector.
- `forward` uses the new helper so the encoder still receives a flat tensor.
- `_readout` now picks CLS rows via vectorised indexing (`first_idx = cumsum(graph_sizes) - graph_sizes`).
- Removes the temporary `_add_cls_token` call inside `forward`.

This lets the CLS token flow through attention while keeping the encoder API unchanged; the new `test_cls_token_transformation` turns green.
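The vectorised CLS-row selection described above, sketched in isolation: with one CLS row stored at the start of each graph's node block, its index is `cumsum(sizes) - sizes`.

```python
import torch


def cls_readout(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    graph_sizes = torch.bincount(batch)                        # nodes (incl. CLS) per graph
    first_idx = torch.cumsum(graph_sizes, dim=0) - graph_sizes # first row of each graph
    return x[first_idx]                                        # one CLS embedding per graph
```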
…mation

The CLS token was used directly in readout; the model now transforms it through the encoder.
Use alias imports to avoid multi-line import formatting conflicts between yapf and isort tools. The alias import pattern (import module as _alias; Class = _alias.Class) prevents the formatting loop that occurred with parenthesized multi-line imports exceeding the 79-character line limit.
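For illustration, the pattern looks like this; the module path is taken from the commit messages above (it lives in this PR's contrib code, not in released PyG).

```python
# Single-line alias import keeps every line under 79 characters, so yapf and
# isort stop reformatting each other's output.
import torch_geometric.contrib.nn.layers.transformer as _transformer

GraphTransformerEncoderLayer = _transformer.GraphTransformerEncoderLayer
```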
- Reduce the complexity of the `__init__` constructor from B-8 to A-1
- Average complexity dropped from 3.75 to 3.48
- Improve docstring
…bmodules
- Add `reset_parameters` to all bias providers
- Add `reset_parameters` to all positional encoders
- Add `reset_parameters` to encoder layer
- Add `reset_parameters` to optional GNN modules
- Basic unit test to ensure that `reset_parameters` changes weights
Introduced a 'basic_encoder_config' pytest fixture for GraphTransformer tests in conftest.py. Updated bias provider tests to use the new fixture and the 'encoder_cfg' argument, improving test clarity and maintainability.
Updated the bias provider test cases to construct GraphTransformer instances using the `encoder_cfg` argument instead of passing `attn_bias_providers` and `num_encoder_layers` directly. This aligns the tests with the updated GraphTransformer API and improves clarity.
Applied the @onlyFullTest decorator to test_encoder_speed, test_encoder_with_bias_speed, and the parameterized test to ensure these performance tests only run in full test environments.
Introduces a cache for attention masks in the GraphTransformer model to avoid redundant computation of key padding masks across forward passes. Updates the forward logic to use the cached mask and adds a test to verify that the mask is built only once per unique configuration. The cache is cleared on parameter reset to ensure consistency.
Allow caching only when requested, allowing key-padding masks to be cached and reused for performance improvements. Updates tests to verify correct mask caching behavior based on the new parameter.
Introduces a parameterized test to validate attention-mask caching behavior in GraphTransformer, covering scenarios such as shape changes, head changes, cache resets, and super-node usage. Ensures correct mask building and cache invalidation logic.
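A sketch of how such a cache might work, keyed on the configuration pieces these commits mention (batch shape, number of heads, super-node usage); the class and key choice are assumptions, not the PR's exact code.

```python
import torch


class KeyPaddingMaskCache:
    """Illustrative cache: rebuild the mask only when the configuration changes."""

    def __init__(self):
        self._key = None
        self._mask = None

    def get(self, batch: torch.Tensor, num_heads: int, use_super_node: bool, build_fn):
        key = (tuple(batch.shape), int(batch.max()), num_heads, use_super_node)
        if key != self._key:                      # cache miss: build once
            self._key, self._mask = key, build_fn(batch)
        return self._mask

    def clear(self):
        # Called from reset_parameters() so stale masks never survive a reset.
        self._key = self._mask = None
```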
Moved the patching of `build_key_padding` to a pytest fixture (`patched_build_key_padding`) in conftest.py, simplifying test code in test_graph_transformer.py. This centralizes the mocking logic, reduces duplication, and ensures consistent behavior across tests that require mocking of key padding mask construction.
Refactor test cases for GraphTransformer to utilize DEFAULT_ENCODER and DEFAULT_GNN constants for encoder_cfg and gnn_cfg construction. This improves consistency with model defaults and reduces duplication. Also updates the definition of these constants in the model to avoid using Python dict literal syntax, for clarity.
Renamed variable 'Ni' to 'num_nodes' for clarity when creating zero blocks in GraphAttnEdgeBias. This improves code readability without changing functionality.
The simple_none_batch fixture now explicitly sets the num_nodes attribute based on the edge_index tensor. This ensures correct graph construction when node features are None.
- Constructor takes `hidden_dim` and an optional `out_channels`; removes the legacy name `num_class`.
Replaces the 'classifier' layer with 'output_projection' in GraphTransformer, allowing out_channels to be None and returning hidden_dim-sized outputs in that case. Updates all relevant test cases to expect per-node outputs instead of per-graph outputs, and adds a test for the no-head (out_channels=None) scenario. This change improves flexibility for downstream tasks and clarifies output shapes.
Cyclomatic-complexity report, before:
F 56:0 merge_masks - C (11)
F 34:0 _to_additive - A (3)
F 7:0 build_key_padding - A (2)
3 blocks (classes, functions, methods) analyzed.
Average complexity: B (5.33)

After:
F 93:0 merge_masks - B (10)
F 34:0 _to_additive - A (3)
F 56:0 _expand_key_pad - A (3)
F 73:0 _combine_additive - A (3)
F 7:0 build_key_padding - A (2)
5 blocks (classes, functions, methods) analyzed.
Average complexity: A (4.2)
Changed the initialization of node features from zeros to ones in the GraphTransformer model. This ensures that every feature-less graph (one that provides no `data.x`) gains a non-zero, trainable signal, guaranteeing immediate gradient flow and letting the model discover the optimal constant input during training.
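Conceptually (names here are illustrative, not the PR's code): ones instead of zeros means the downstream encoder's weights, not just its bias, receive gradient from the first step.

```python
import torch
from torch import nn

hidden_dim = 32
node_encoder = nn.Linear(hidden_dim, hidden_dim)  # any trainable encoder

x, num_nodes = None, 5                            # a graph with no data.x
if x is None:
    # All-ones default: a non-zero, constant input the model can adapt to.
    x = torch.ones(num_nodes, hidden_dim)
h = node_encoder(x)
```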
Introduces a cast_bias parameter to GraphTransformer, allowing attention bias to be cast to the same dtype as node features. Adds comprehensive tests for attention bias dtype behavior and mixed-precision AMP stability.
Replaces custom dense batch padding and unpadding methods with torch_geometric's to_dense_batch utility in GraphTransformerEncoderLayer. This simplifies the code and leverages a well-tested utility for handling batched graph data.
This patch implements the July-2025 PyG transformer guideline that recommends starting the learnable super-node (CLS) embedding from a small-variance truncated normal distribution rather than an all-zero vector.

Why the change?
- Distinct token signal → better learning dynamics: a non-zero initialisation ensures each attention head receives a unique, low-magnitude input from the CLS token on the very first optimisation step, improving gradient flow and preventing early-epoch degeneration into a constant feature.
- Consistent with PyTorch Geometric’s upcoming defaults: aligns with the “Transformer note (July-2025)” so downstream users see the same behaviour across PyG models.
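Sketched with PyTorch's built-in truncated-normal initialiser; the `std=0.02` value is an assumption, not necessarily the PR's choice.

```python
import torch
from torch import nn

hidden_dim = 64
cls_token = nn.Parameter(torch.empty(1, hidden_dim))
nn.init.trunc_normal_(cls_token, std=0.02)  # small-variance, non-zero start
```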
…ass. Refactors GraphTransformerEncoderLayer and related code to operate on padded (dense) node feature tensors instead of flat (N, C) inputs, removing the need to pass batch vectors through layers. Introduces pad_cache for efficient key padding mask reuse across layers. Updates tests and model code to match the new interface and data flow.
Rewrites the _prepend_cls_token_flat method to avoid creating full-size zero tensors and using scatter, instead relying on torch.cat and torch.argsort for improved efficiency. Updates docstrings for clarity and adjusts variable naming and comments for better readability.
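A sketch of the described approach (one CLS row per graph, no full-size zero tensors or scatter); `stable=True` for `torch.argsort` requires a recent PyTorch, and the helper below is illustrative rather than the PR's exact implementation.

```python
import torch


def prepend_cls_token_flat(x, batch, cls_token):
    """One learnable CLS row per graph, prepended in flat (N, C) layout."""
    num_graphs = int(batch.max()) + 1
    cls_rows = cls_token.expand(num_graphs, -1)                # (G, C)
    cls_batch = torch.arange(num_graphs, device=batch.device)  # CLS row i -> graph i
    x_all = torch.cat([cls_rows, x], dim=0)
    batch_all = torch.cat([cls_batch, batch], dim=0)
    # Stable sort groups rows by graph while keeping each CLS row (listed
    # first) ahead of that graph's nodes.
    perm = torch.argsort(batch_all, stable=True)
    return x_all[perm], batch_all[perm]
```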
…ed from 89% to 92%) Split the monolithic test_transformer_init_arg_validation into focused, parameterized tests for each argument rule. This improves test clarity, maintainability, and error isolation for GraphTransformer configuration validation.
Introduces new parameterized tests to validate input dimensions, types, and bias provider enforcement in the GraphTransformer model. These tests ensure that invalid configurations raise appropriate exceptions, improving robustness and error reporting. Increase test coverage from 92% to 95%
…rage) Rewrites and extends test cases for build_key_padding and merge_masks, improving coverage for edge cases, dtype handling, gradient flow, and error conditions. Tests now validate correct behavior for various input shapes, mask types, and parameter combinations, ensuring robustness of mask merging logic.
Introduces a test to ensure that merge_masks raises a TypeError when provided with a non-float attention mask.
Deleted the unused _mixed_batch function to clean up the test file. No functional changes to existing tests.
Adds logic to handle Data objects without a batch attribute in GraphTransformer by defaulting to a zero batch vector. Updates tests to verify that the model can process data without batch information, and documents this behavior in the code and docstrings.
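A minimal sketch of that default, assuming `num_nodes` is available on the Data object; a missing `batch` is treated as a single graph with batch index 0 for every node.

```python
import torch


def ensure_batch(data):
    # Single graph without `batch`: all nodes belong to graph 0.
    if getattr(data, 'batch', None) is None:
        data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
    return data
```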
for more information, see https://pre-commit.ci
This PR introduces a modular and extensible GraphTransformer model for PyTorch Geometric, following the architecture described in "Transformer for Graphs: An Overview from Architecture Perspective".
Key Features
Benchmarking
I have thoroughly tested this model on a variety of node and graph classification benchmarks, including CORA, CITESEER, PUBMED, MUTAG, and PROTEINS. Results and reproducible scripts are available in my benchmarking repository:
👉 omarkhater/benchmark_graph_transformer
The model achieves strong performance across all tested datasets and supports a wide range of configurations (see the benchmark repo for details).
Future Work
If desired, I am happy to continue working on this model to add regression-task support and address any additional requirements from the PyG team.
Let me know if you have feedback or requests for further improvements!