Classifiers
Graph Neural Network classifiers for sports analytics.
The classifiers module provides pre-built Graph Neural Network architectures designed for sports tracking data, with implementations for PyTorch Geometric and for Spektral (deprecated).
PyTorch Geometric
- class unravel.classifiers.PyGLightningCrystalGraphClassifier
Bases: LightningModule
PyTorch Lightning wrapper for the Crystal Graph Classifier, including the training loop.
This class wraps PyGCrystalGraphClassifier with PyTorch Lightning functionality, providing automatic training loops, logging, checkpointing, and metrics tracking for binary classification tasks.
The model includes:
- Automatic training/validation/test loops
- AUROC and accuracy metric tracking
- Learning rate scheduling with ReduceLROnPlateau
- Automatic checkpointing and logging
- A simple prediction interface
- Parameters:
n_layers (int, optional) – Number of CGConv layers. Defaults to 3.
channels (int, optional) – Hidden dimension size. Defaults to 128.
drop_out (float, optional) – Dropout probability. Defaults to 0.5.
n_out (int, optional) – Number of output features. Defaults to 1.
lr (float, optional) – Learning rate for the Adam optimizer. Defaults to 0.001.
weight_decay (float, optional) – L2 penalty coefficient. Defaults to 0.0.
- Raises:
ImportError – If PyTorch Lightning or torchmetrics is not installed.
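The lr and weight_decay parameters are documented as the Adam learning rate and an L2 penalty coefficient. Assuming they are passed straight to torch.optim.Adam, which adds weight_decay * param to the gradient before the usual Adam update (coupled L2 decay, unlike AdamW), one update of a single scalar weight can be sketched in plain Python. adam_step is a hypothetical helper for illustration only, not part of unravel.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.001, weight_decay=0.0,
              beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar weight with coupled L2 weight decay.

    Hypothetical helper mirroring the update rule of torch.optim.Adam
    (amsgrad=False). Returns (new_param, new_m, new_v).
    """
    # L2 penalty: weight_decay * param is added to the raw gradient
    g = grad + weight_decay * param
    # Exponential moving averages of the gradient and the squared gradient
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    # Bias correction for the zero-initialised moments (t starts at 1)
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    new_param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return new_param, m, v


# With a zero loss gradient, only the weight-decay term moves the weight
p_no_decay, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, weight_decay=0.0)
p_decay, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, weight_decay=0.1)
print(p_no_decay, p_decay)
```

With a zero loss gradient, weight_decay=0.0 leaves the weight at 1.0, while weight_decay=0.1 moves it to roughly 0.999 on the first step.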
- model
The underlying GNN model.
- Type:
PyGCrystalGraphClassifier
- criterion
Binary cross-entropy loss function.
- Type:
- train_auc
Training AUROC metric.
- Type:
AUROC
- train_acc
Training accuracy metric.
- Type:
Accuracy
- val_auc
Validation AUROC metric.
- Type:
AUROC
- val_acc
Validation accuracy metric.
- Type:
Accuracy
- test_auc
Test AUROC metric.
- Type:
AUROC
- test_acc
Test accuracy metric.
- Type:
Accuracy
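The AUROC and Accuracy attributes track standard binary classification metrics. As a plain-Python illustration of what they measure (not how torchmetrics computes them), accuracy thresholds each probability at 0.5, and AUROC equals the probability that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counting as half. The helper names below are hypothetical.

```python
def binary_accuracy(probs, labels, threshold=0.5):
    """Fraction of predictions whose thresholded class matches the 0/1 label."""
    correct = sum((p > threshold) == bool(y) for p, y in zip(probs, labels))
    return correct / len(labels)


def binary_auroc(probs, labels):
    """AUROC as the probability that a random positive outscores a random negative."""
    pos = [p for p, y in zip(probs, labels) if y == 1]
    neg = [p for p, y in zip(probs, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(pos) * len(neg))


probs = [0.9, 0.8, 0.3, 0.6, 0.2]
labels = [1, 1, 0, 0, 1]
print(binary_accuracy(probs, labels), binary_auroc(probs, labels))
```

This pairwise version is O(n_pos * n_neg) and only meant to show the definition; production metrics use sorting or binned thresholds instead.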
Example
>>> from unravel.classifiers import PyGLightningCrystalGraphClassifier
>>> import pytorch_lightning as pyl
>>> from torch_geometric.loader import DataLoader
>>>
>>> # train_loader, val_loader, test_loader and pred_loader are
>>> # torch_geometric DataLoader objects built from your graph dataset
>>>
>>> # Initialize model
>>> model = PyGLightningCrystalGraphClassifier(
...     n_layers=3,
...     channels=128,
...     lr=0.001
... )
>>>
>>> # Train
>>> trainer = pyl.Trainer(max_epochs=50, accelerator="auto")
>>> trainer.fit(model, train_loader, val_loader)
>>>
>>> # Test
>>> trainer.test(model, test_loader)
>>>
>>> # Predict
>>> predictions = trainer.predict(model, pred_loader)
>>>
>>> # Save/load checkpoint
>>> trainer.save_checkpoint("model.ckpt")
>>> model = PyGLightningCrystalGraphClassifier.load_from_checkpoint("model.ckpt")
Note
This model uses binary cross-entropy loss and is designed for binary classification tasks. For multi-class or regression tasks, you may need to modify the loss function and output activation.
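For a predicted probability p and a label y in {0, 1}, binary cross-entropy is -[y log p + (1 - y) log(1 - p)], averaged over the batch. The sketch below assumes the loss is applied to probabilities (as torch.nn.BCELoss is) rather than logits; the source only says "binary cross-entropy", so whether the model uses BCELoss or BCEWithLogitsLoss is not confirmed here. Log terms are clamped at -100, mirroring the clamping documented for torch.nn.BCELoss, so a confidently wrong prediction yields a large finite loss instead of infinity.

```python
import math


def clamped_log(x):
    """Natural log clamped at -100, so log(0) does not produce -inf."""
    return max(math.log(x), -100.0) if x > 0 else -100.0


def binary_cross_entropy(probs, labels):
    """Mean binary cross-entropy between probabilities and 0/1 labels."""
    total = 0.0
    for p, y in zip(probs, labels):
        total += -(y * clamped_log(p) + (1 - y) * clamped_log(1 - p))
    return total / len(labels)


print(binary_cross_entropy([0.9, 0.2], [1, 0]))
```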
- __init__(n_layers=3, channels=128, drop_out=0.5, n_out=1, lr=0.001, weight_decay=0.0)
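- forward(x, edge_index, edge_attr, batch)
Forward pass through the model.
- Parameters:
x (torch.Tensor) – Node features with shape [num_nodes, num_node_features].
edge_index (torch.LongTensor) – Edge indices in COO format with shape [2, num_edges].
edge_attr (torch.Tensor) – Edge features with shape [num_edges, num_edge_features].
batch (torch.LongTensor) – Batch vector assigning each node to the index of its graph.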
- Returns:
Predictions with shape [batch_size].
- Return type:
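The forward arguments follow PyTorch Geometric's mini-batching convention: graphs in a batch are merged into one disconnected graph, edge_index stores edges as a row of source nodes and a row of target nodes with node indices offset per graph, and batch maps every node to the index of its graph. The plain-Python sketch below uses lists in place of tensors and assumes a mean readout to show how node-level values become the documented [batch_size] output; the pooling actually used inside PyGCrystalGraphClassifier is not stated in this section.

```python
# Two toy graphs batched together: graph 0 has nodes 0-2, graph 1 has nodes 3-4.
x = [[1.0], [2.0], [3.0], [10.0], [20.0]]    # node features, shape [num_nodes, 1]
edge_index = [[0, 1, 1, 2, 3, 4],            # source nodes
              [1, 0, 2, 1, 4, 3]]            # target nodes, shape [2, num_edges]
edge_attr = [[0.5], [0.5], [1.0], [1.0], [2.0], [2.0]]  # one feature row per edge
batch = [0, 0, 0, 1, 1]                      # graph index of each node


def mean_readout(node_values, batch):
    """Average node values per graph, giving one value per graph (assumed pooling)."""
    num_graphs = max(batch) + 1
    sums = [0.0] * num_graphs
    counts = [0] * num_graphs
    for value, g in zip(node_values, batch):
        sums[g] += value
        counts[g] += 1
    return [s / c for s, c in zip(sums, counts)]


# Stand-in for the per-node output of the message-passing layers
node_scores = [row[0] for row in x]
print(mean_readout(node_scores, batch))  # [2.0, 15.0]
```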
- training_step(batch, batch_idx)
Training step executed for each batch.
- Parameters:
batch – Batch of graph data from DataLoader.
batch_idx (int) – Index of the current batch.
- Returns:
Training loss for this batch.
- Return type:
- validation_step(batch, batch_idx)
Validation step executed for each batch.
- Parameters:
batch – Batch of graph data from DataLoader.
batch_idx (int) – Index of the current batch.
- Returns:
Validation loss for this batch.
- Return type:
- test_step(batch, batch_idx)
Test step for model evaluation.
Computes test loss and metrics (AUROC and accuracy) for the given batch.
- Parameters:
batch – Batch of graph data from DataLoader.
batch_idx (int) – Index of the current batch.
- Returns:
Test loss for this batch.
- Return type:
- predict_step(batch, batch_idx)
Prediction step for inference.
Returns predicted probabilities for the given batch without computing loss.
- Parameters:
batch – Batch of graph data from DataLoader.
batch_idx (int) – Index of the current batch.
- Returns:
- Predicted probabilities with shape [batch_size], with values in the range [0, 1].
- Return type:
Example
>>> import torch
>>> predictions = trainer.predict(model, pred_loader)
>>> # predictions is a list of tensors, one per batch
>>> all_preds = torch.cat(predictions)
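trainer.predict returns one tensor of probabilities per batch. Turning them into hard labels needs a decision threshold; 0.5 is assumed below and may need tuning for imbalanced tasks. The sketch uses plain lists to stand in for the per-batch tensors, and flatten_and_threshold is a hypothetical helper; with real tensors the same step is torch.cat(predictions) followed by (all_preds > 0.5).long().

```python
def flatten_and_threshold(batched_probs, threshold=0.5):
    """Concatenate per-batch probability lists and convert them to 0/1 labels."""
    all_probs = [p for batch in batched_probs for p in batch]  # like torch.cat
    labels = [1 if p > threshold else 0 for p in all_probs]
    return all_probs, labels


batched = [[0.91, 0.12], [0.55, 0.49, 0.03]]  # hypothetical output of two batches
probs, labels = flatten_and_threshold(batched)
print(labels)  # [1, 0, 1, 0, 0]
```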
- configure_optimizers()
Configure optimizer and learning rate scheduler.
Uses Adam optimizer with learning rate scheduling via ReduceLROnPlateau. The learning rate is reduced by a factor of 0.5 when validation loss plateaus for 3 epochs.
- Returns:
- Dictionary containing:
'optimizer': Adam optimizer instance
'lr_scheduler': dict with the scheduler and its monitoring configuration
- Return type:
Note
The learning rate scheduler monitors ‘val_loss’ and reduces the learning rate when validation loss stops improving.
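The scheduler behaviour described here (factor 0.5, patience of 3 epochs, monitoring val_loss) can be simulated in plain Python. This is a simplified model of torch.optim.lr_scheduler.ReduceLROnPlateau in mode='min': any strict decrease counts as an improvement, whereas PyTorch by default requires a relative improvement of 1e-4 and also supports cooldown and min_lr. In PyTorch the rate is cut once the number of consecutive non-improving epochs exceeds patience, so with patience=3 the first cut happens on the fourth such epoch. plateau_schedule is a hypothetical helper.

```python
def plateau_schedule(val_losses, lr=0.001, factor=0.5, patience=3):
    """Return the learning rate in effect after each epoch's validation loss.

    Simplified ReduceLROnPlateau (mode='min'): any strict decrease counts as
    an improvement; PyTorch's relative threshold, cooldown and min_lr are ignored.
    """
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Cut the rate once the loss has failed to improve for more than `patience` epochs
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


losses = [0.70, 0.60, 0.61, 0.62, 0.60, 0.63, 0.59]
print(plateau_schedule(losses))
```

For these losses the rate stays at 0.001 for five epochs and halves to 0.0005 at epoch six, after four consecutive epochs without beating the best loss of 0.60.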
Spektral
Usage Examples
PyTorch Geometric
from unravel.classifiers import PyGLightningCrystalGraphClassifier
import pytorch_lightning as pyl
from torch_geometric.loader import DataLoader
# train_loader, val_loader, test_loader and pred_loader are DataLoader
# objects wrapping your torch_geometric graph dataset
# Initialize model using the constructor arguments documented above
model = PyGLightningCrystalGraphClassifier(
    n_layers=3,
    channels=128,
    drop_out=0.5,
    n_out=1,
    lr=0.001,
)
# Train
trainer = pyl.Trainer(max_epochs=50)
trainer.fit(model, train_loader, val_loader)
# Test
trainer.test(model, test_loader)
# Predict
predictions = trainer.predict(model, pred_loader)
Spektral
from unravel.classifiers import CrystalGraphClassifier
# Initialize model
model = CrystalGraphClassifier(
node_features=12,
edge_features=6,
output_features=1,
)
# Compile
model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy']
)
# Train
model.fit(x=train_data, y=train_labels, epochs=50, validation_data=(val_data, val_labels))