Classifiers

Graph Neural Network classifiers for sports analytics.

The classifiers module provides pre-built Graph Neural Network architectures for sports tracking data. Implementations are available for PyTorch Geometric and for Spektral (deprecated).

PyTorch Geometric

class unravel.classifiers.PyGLightningCrystalGraphClassifier

Bases: LightningModule

PyTorch Lightning wrapper around the Crystal Graph Classifier, including its training loop.

This class wraps PyGCrystalGraphClassifier with PyTorch Lightning functionality, providing automatic training loops, logging, checkpointing, and metrics tracking for binary classification tasks.

The model includes:
  • Automatic training/validation/test loops

  • AUROC and accuracy metric tracking

  • Learning rate scheduling with ReduceLROnPlateau

  • Automatic checkpointing and logging

  • Easy prediction interface

Parameters:
  • n_layers (int, optional) – Number of CGConv layers. Defaults to 3.

  • channels (int, optional) – Hidden dimension size. Defaults to 128.

  • drop_out (float, optional) – Dropout probability. Defaults to 0.5.

  • n_out (int, optional) – Number of output features. Defaults to 1.

  • lr (float, optional) – Learning rate for Adam optimizer. Defaults to 0.001.

  • weight_decay (float, optional) – L2 penalty coefficient. Defaults to 0.0.
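The lr and weight_decay parameters are assumed here to be passed straight to torch.optim.Adam (the text above only says Adam is used and that weight_decay is an L2 penalty). In Adam the L2 penalty is added to the gradient before the moment estimates are updated, so it is rescaled by Adam's adaptive step instead of being applied as decoupled decay (which is what AdamW does). A minimal scalar sketch of one such update, with the hypothetical helper name adam_step_l2:

```python
import math


def adam_step_l2(param, grad, m, v, t, lr=0.001, weight_decay=0.0,
                 beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter with a coupled L2 penalty.

    Hypothetical illustration; t is the 1-based step count.
    """
    # Coupled L2 penalty: folded into the gradient before the moment updates.
    grad = grad + weight_decay * param
    # Exponential moving averages of the gradient and squared gradient.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction for the zero-initialised moments.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v
```

With a zero data gradient, a non-zero weight_decay alone still shrinks the parameter: on the first step the update is roughly lr in magnitude, whatever the penalty's size.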

Raises:

ImportError – If PyTorch Lightning or torchmetrics is not installed.

model

The underlying GNN model.

Type:

PyGCrystalGraphClassifier

criterion

Binary cross-entropy loss function.

Type:

torch.nn.BCELoss
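torch.nn.BCELoss expects probabilities rather than logits, which implies the model's output has already passed through a sigmoid. With the default reduction="mean", it averages the per-sample terms -[y log p + (1 - y) log(1 - p)] and clamps each log term at -100. A stdlib sketch of that definition (the function name bce_loss is illustrative):

```python
import math


def bce_loss(probs, targets):
    """Mean binary cross-entropy, following torch.nn.BCELoss's definition."""
    total = 0.0
    for p, y in zip(probs, targets):
        # BCELoss clamps log outputs at -100 so p = 0 or p = 1 stays finite.
        log_p = max(math.log(p), -100.0) if p > 0 else -100.0
        log_1mp = max(math.log(1 - p), -100.0) if p < 1 else -100.0
        total += -(y * log_p + (1 - y) * log_1mp)
    return total / len(probs)
```

For example, bce_loss([0.5], [1]) is ln 2 (about 0.693), and a confidently wrong prediction such as bce_loss([0.0], [1]) is capped at 100.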

train_auc

Training AUROC metric.

Type:

AUROC

train_acc

Training accuracy metric.

Type:

Accuracy

val_auc

Validation AUROC metric.

Type:

AUROC

val_acc

Validation accuracy metric.

Type:

Accuracy

test_auc

Test AUROC metric.

Type:

AUROC

test_acc

Test accuracy metric.

Type:

Accuracy
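The train/val/test metrics are torchmetrics objects. As a rough illustration of what the two numbers mean, here is a stdlib sketch: AUROC as the probability that a randomly chosen positive graph scores higher than a randomly chosen negative one (ties count half), and accuracy after thresholding the predicted probabilities. The 0.5 threshold and the function names are assumptions for illustration, not the class's configuration.

```python
def binary_auroc(scores, labels):
    """AUROC via pairwise ranking: P(score_pos > score_neg), ties count 0.5.

    O(P * N) pairwise form, for illustration only.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def binary_accuracy(scores, labels, threshold=0.5):
    """Fraction of samples whose thresholded score matches the label."""
    # Scores strictly above the threshold are predicted as the positive class.
    preds = [1 if s > threshold else 0 for s in scores]
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)
```

Because AUROC depends only on the ranking of scores, it is unaffected by the threshold choice, while accuracy changes with it.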

Example

>>> from unravel.classifiers import PyGLightningCrystalGraphClassifier
>>> import pytorch_lightning as pyl
>>> from torch_geometric.loader import DataLoader
>>>
>>> # Initialize model
>>> model = PyGLightningCrystalGraphClassifier(
...     n_layers=3,
...     channels=128,
...     lr=0.001
... )
>>>
>>> # Train
>>> trainer = pyl.Trainer(max_epochs=50, accelerator="auto")
>>> trainer.fit(model, train_loader, val_loader)
>>>
>>> # Test
>>> trainer.test(model, test_loader)
>>>
>>> # Predict
>>> predictions = trainer.predict(model, pred_loader)
>>>
>>> # Save/load checkpoint
>>> trainer.save_checkpoint("model.ckpt")
>>> model = PyGLightningCrystalGraphClassifier.load_from_checkpoint("model.ckpt")

Note

This model uses binary cross-entropy loss and is designed for binary classification tasks. For multi-class or regression tasks, you may need to modify the loss function and output activation.
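The class does not expose a multi-class option; the stdlib sketch below only illustrates the change the note describes. With one output, a sigmoid turns the score into a probability for binary cross-entropy. With n_out set to the number of classes (an assumption about how one might adapt the model), a softmax over the outputs feeds a categorical cross-entropy. In PyTorch this would typically mean returning raw logits and using torch.nn.CrossEntropyLoss, which applies log-softmax internally. All function names here are illustrative.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function (binary output activation)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softmax(logits):
    """Numerically stable softmax over a list of class logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Categorical cross-entropy for one sample with integer class target."""
    return -math.log(softmax(logits)[target])
```

A two-class softmax over logits [0, z] gives the same class-1 probability as sigmoid(z), so the binary setup is the two-class special case.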

__init__(n_layers=3, channels=128, drop_out=0.5, n_out=1, lr=0.001, weight_decay=0.0)

forward(x, edge_index, edge_attr, batch)

Forward pass through the model.

Parameters:
  • x (torch.Tensor) – Node features.

  • edge_index (torch.LongTensor) – Edge indices.

  • edge_attr (torch.Tensor) – Edge features.

  • batch (torch.LongTensor) – Batch vector.

Returns:

Predictions with shape [batch_size].

Return type:

torch.Tensor
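The output has shape [batch_size] because the `batch` vector assigns each node to a graph (batch[i] is the index of the graph that node i belongs to, as in PyTorch Geometric), and a graph-level readout collapses all node embeddings of a graph into one value. The pooling operation used inside PyGCrystalGraphClassifier is not documented here; the stdlib sketch below assumes mean pooling, uses scalar node values in place of embedding vectors, and assumes graph indices run contiguously from 0.

```python
from collections import defaultdict


def global_mean_pool(node_values, batch):
    """Average per-node values into one value per graph (illustrative).

    batch[i] is the graph index of node i; every graph index from 0 to
    max(batch) is assumed to have at least one node.
    """
    sums = defaultdict(float)
    counts = defaultdict(int)
    for value, graph_idx in zip(node_values, batch):
        sums[graph_idx] += value
        counts[graph_idx] += 1
    num_graphs = max(batch) + 1
    return [sums[g] / counts[g] for g in range(num_graphs)]
```

For a mini-batch of two graphs with two and three nodes, global_mean_pool([1.0, 3.0, 10.0, 20.0, 30.0], [0, 0, 1, 1, 1]) yields one value per graph: [2.0, 20.0].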

training_step(batch, batch_idx)

Training step executed for each batch.

Parameters:
  • batch – Batch of graph data from DataLoader.

  • batch_idx (int) – Index of the current batch.

Returns:

Training loss for this batch.

Return type:

torch.Tensor

validation_step(batch, batch_idx)

Validation step executed for each batch.

Parameters:
  • batch – Batch of graph data from DataLoader.

  • batch_idx (int) – Index of the current batch.

Returns:

Validation loss for this batch.

Return type:

torch.Tensor

test_step(batch, batch_idx)

Test step for model evaluation.

Computes test loss and metrics (AUROC and accuracy) for the given batch.

Parameters:
  • batch – Batch of graph data from DataLoader.

  • batch_idx (int) – Index of the current batch.

Returns:

Test loss for this batch.

Return type:

torch.Tensor

predict_step(batch, batch_idx)

Prediction step for inference.

Returns predicted probabilities for the given batch without computing loss.

Parameters:
  • batch – Batch of graph data from DataLoader.

  • batch_idx (int) – Index of the current batch.

Returns:

Predicted probabilities with shape [batch_size] and values in the range [0, 1].

Return type:

torch.Tensor

Example

>>> import torch
>>> predictions = trainer.predict(model, pred_loader)
>>> # predictions is a list of tensors, one per batch
>>> all_preds = torch.cat(predictions)

configure_optimizers()

Configure optimizer and learning rate scheduler.

Uses Adam optimizer with learning rate scheduling via ReduceLROnPlateau. The learning rate is reduced by a factor of 0.5 when validation loss plateaus for 3 epochs.

Returns:

Dictionary containing:
  • 'optimizer': Adam optimizer instance

  • 'lr_scheduler': Dict with scheduler and monitoring configuration

Return type:

dict

Note

The learning rate scheduler monitors ‘val_loss’ and reduces the learning rate when validation loss stops improving.
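A stdlib simulation of that schedule, assuming the scheduler is torch.optim.lr_scheduler.ReduceLROnPlateau with mode="min", factor=0.5, patience=3 and otherwise default settings (relative threshold 1e-4, no cooldown, no minimum learning rate). Under PyTorch's semantics the reduction fires once the count of consecutive non-improving epochs exceeds patience, i.e. on the fourth such epoch. The function name is hypothetical.

```python
def simulate_reduce_lr_on_plateau(val_losses, lr=0.001, factor=0.5,
                                  patience=3, threshold=1e-4):
    """Return the learning rate in effect after each epoch's scheduler step.

    Mimics ReduceLROnPlateau(mode="min", threshold_mode="rel") with no
    cooldown and no minimum learning rate.
    """
    best = float("inf")
    num_bad_epochs = 0
    lrs = []
    for loss in val_losses:
        # An improvement must beat the best loss by a relative margin.
        if loss < best * (1 - threshold):
            best = loss
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        # Reduce once the plateau has lasted longer than `patience` epochs.
        if num_bad_epochs > patience:
            lr *= factor
            num_bad_epochs = 0
        lrs.append(lr)
    return lrs
```

With val_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9], epochs 3 to 6 fail to improve on 0.9, so the learning rate stays at 0.001 for five epochs and halves to 0.0005 after the sixth.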

Spektral

class unravel.classifiers.CrystalGraphClassifier

Bases: Model

Default graph classifier built from CrystalConvolution layers, as presented in Sahasrabudhe & Bekkers (2023).
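For reference, the crystal graph convolution as documented for PyTorch Geometric's CGConv (Spektral's CrystalConv uses the same update) refines each node embedding from its neighbours. Here $\mathbf{z}_{i,j}$ concatenates the two node embeddings and the edge features, $\sigma$ is the logistic sigmoid, $g$ is softplus, and $\odot$ is element-wise multiplication. How unravel stacks these layers and pools them into a graph-level prediction is not specified in this section.

```latex
\mathbf{x}'_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)}
  \sigma\!\left(\mathbf{z}_{i,j}\mathbf{W}_f + \mathbf{b}_f\right)
  \odot g\!\left(\mathbf{z}_{i,j}\mathbf{W}_s + \mathbf{b}_s\right),
\qquad
\mathbf{z}_{i,j} = \left[\mathbf{x}_i \,\Vert\, \mathbf{x}_j \,\Vert\, \mathbf{e}_{i,j}\right]
```

The residual term $\mathbf{x}_i$ requires input and output widths to match, which is consistent with a single hidden dimension (channels) shared across layers.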

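__init__(n_layers=3, channels=128, drop_out=0.5, n_out=1, **kwargs)

Parameters:
  • n_layers (int, optional) – Number of CrystalConv layers. Defaults to 3.

  • channels (int, optional) – Hidden dimension size. Defaults to 128.

  • drop_out (float, optional) – Dropout probability. Defaults to 0.5.

  • n_out (int, optional) – Number of output features. Defaults to 1.

  • **kwargs – Additional keyword arguments.

call(inputs)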

Usage Examples

PyTorch Geometric

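from unravel.classifiers import PyGLightningCrystalGraphClassifier
import pytorch_lightning as pyl
from torch_geometric.loader import DataLoader

# train_graphs, val_graphs, test_graphs and pred_graphs are placeholders
# (hypothetical names) for your own lists of torch_geometric.data.Data objects
train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=32)
test_loader = DataLoader(test_graphs, batch_size=32)
pred_loader = DataLoader(pred_graphs, batch_size=32)

# Initialize model with its documented hyperparameters
model = PyGLightningCrystalGraphClassifier(
    n_layers=3,
    channels=128,
    drop_out=0.5,
    n_out=1,
    lr=0.001,
    weight_decay=0.0,
)

# Train
trainer = pyl.Trainer(max_epochs=50, accelerator="auto")
trainer.fit(model, train_loader, val_loader)

# Test
trainer.test(model, test_loader)

# Predict (returns a list of per-batch probability tensors)
predictions = trainer.predict(model, pred_loader)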

Spektral

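from unravel.classifiers import CrystalGraphClassifier
from spektral.data import DisjointLoader

# Initialize model with its documented hyperparameters
model = CrystalGraphClassifier(
    n_layers=3,
    channels=128,
    drop_out=0.5,
    n_out=1,
)

# Compile
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

# train_dataset and val_dataset are placeholders (hypothetical names)
# for your own spektral.data.Dataset objects
loader_tr = DisjointLoader(train_dataset, batch_size=32, epochs=50, shuffle=True)
loader_va = DisjointLoader(val_dataset, batch_size=32, shuffle=False)

# Train
model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    epochs=50,
    validation_data=loader_va.load(),
    validation_steps=loader_va.steps_per_epoch,
)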