Torch Geometric Engine
A battle-tested skill for building graph neural networks, with structured workflows, validation checks, and reusable patterns for scientific computing.
Build and train Graph Neural Networks (GNNs) using PyTorch Geometric (PyG), a library for deep learning on graphs and irregular structures. This skill covers graph construction, message passing layers, node/edge/graph classification, link prediction, graph generation, and scaling to large graphs.
When to Use This Skill
Choose Torch Geometric Engine when you need to:
- Train GNNs for node classification, link prediction, or graph-level tasks
- Process molecular graphs, social networks, or knowledge graphs with neural networks
- Implement custom message passing schemes and graph transformations
- Scale graph learning to large datasets with mini-batching and sampling
Consider alternatives when:
- You need graph analysis without neural networks (use NetworkX)
- You need specialized molecular models for drug discovery (use TorchDrug)
- You need static graph algorithms and analytics (use igraph or graph-tool)
Quick Start
```bash
pip install torch torch-geometric
```

```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

# Load Cora citation network
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
print(f"Features: {data.num_node_features}, Classes: {dataset.num_classes}")

# Define GCN model
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Train
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(dataset.num_features, 16, dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Evaluate
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f"Test accuracy: {acc:.4f}")
```
Core Concepts
GNN Layer Types
| Layer | Class | Mechanism | Best For |
|---|---|---|---|
| GCN | GCNConv | Spectral convolution | Homogeneous graphs |
| GAT | GATConv | Attention-weighted aggregation | Variable-importance neighbors |
| GraphSAGE | SAGEConv | Sampling + aggregation | Large graphs, inductive |
| GIN | GINConv | Sum aggregation (WL-test powerful) | Graph classification |
| EdgeConv | EdgeConv | Dynamic graph construction | Point clouds |
| TransformerConv | TransformerConv | Multi-head attention on graphs | Heterogeneous features |
Graph Classification with Pooling
```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_mean_pool, global_add_pool
from torch.nn import Linear, Sequential, ReLU, BatchNorm1d

# Load molecular dataset
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
print(f"Graphs: {len(dataset)}, Classes: {dataset.num_classes}")

# Split
train_dataset = dataset[:150]
test_dataset = dataset[150:]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

class GINClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden, out_channels, num_layers=3):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        for i in range(num_layers):
            in_ch = in_channels if i == 0 else hidden
            mlp = Sequential(
                Linear(in_ch, hidden), BatchNorm1d(hidden), ReLU(),
                Linear(hidden, hidden), ReLU()
            )
            self.convs.append(GINConv(mlp))
            self.bns.append(BatchNorm1d(hidden))
        self.classifier = Sequential(
            Linear(hidden, hidden), ReLU(),
            Linear(hidden, out_channels)
        )

    def forward(self, x, edge_index, batch):
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
        x = global_add_pool(x, batch)  # Graph-level readout
        return self.classifier(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GINClassifier(dataset.num_features, 64, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = F.cross_entropy(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

model.eval()
correct = 0
total = 0
for batch in test_loader:
    batch = batch.to(device)
    pred = model(batch.x, batch.edge_index, batch.batch).argmax(dim=1)
    correct += (pred == batch.y).sum().item()
    total += batch.y.size(0)
print(f"Test accuracy: {correct/total:.4f}")
```
Configuration
| Parameter | Description | Default |
|---|---|---|
| `hidden_channels` | Hidden layer dimensionality | 64 |
| `num_layers` | Number of GNN layers | 3 |
| `dropout` | Dropout rate | 0.5 |
| `learning_rate` | Optimizer learning rate | 0.01 |
| `weight_decay` | L2 regularization | 5e-4 |
| `batch_size` | Mini-batch size for graph-level tasks | 32 |
| `heads` | Number of attention heads (GAT) | 8 |
| `aggr` | Aggregation function (`add`, `mean`, `max`) | `"add"` |
Best Practices
- **Start with 2-3 GNN layers; rarely go beyond 5.** GNNs suffer from over-smoothing: too many layers make all node embeddings converge to the same value. For most tasks, 2-3 layers capture sufficient neighborhood information. Add skip (residual) connections if you need deeper models.
- **Use mini-batching for graph-level tasks.** PyG's `DataLoader` batches multiple graphs into a single disconnected graph using the `batch` tensor. Always pass `batch` to pooling functions (`global_mean_pool`, `global_add_pool`); without it, the model treats all graphs as one.
- **Choose the aggregation function based on your task.** `sum` aggregation (GIN) is the most expressive for graph isomorphism tasks. `mean` aggregation (GCN, GraphSAGE) normalizes by degree and works better for node-level predictions. `max` aggregation captures the most extreme neighbor feature.
- **Normalize node features before training.** GNN performance is sensitive to feature scale. Standardize continuous features and use one-hot encoding for categorical node attributes. Without normalization, high-magnitude features dominate message passing.
- **Use `NeighborLoader` for large-graph training.** Full-batch training on graphs with millions of nodes exhausts GPU memory. `NeighborLoader` samples a fixed number of neighbors per layer, enabling mini-batch training on arbitrarily large graphs with constant memory usage.
Common Issues
**Out of memory on large graphs.** Full-batch GCN on a million-node graph requires storing the full adjacency and feature matrices. Use `NeighborLoader` or `ClusterLoader` for mini-batch training. Reduce `hidden_channels` and `num_layers` as a quick fix while debugging.
**Node classification accuracy is random (~1/num_classes).** Check that `data.train_mask` is not empty and properly selects training nodes. Also verify that `edge_index` contains valid edges; missing or corrupted edges mean nodes receive no neighbor information. Print `data` to inspect all attributes.
**Custom dataset fails with "Data object has no attribute 'x'".** PyG expects specific attribute names: `x` for node features, `edge_index` for edges (shape `[2, num_edges]`), `y` for labels, `edge_attr` for edge features. When building custom `Data` objects, ensure `edge_index` is a `LongTensor` and is 0-indexed.