Source code for unravel.american_football.graphs.graph_converter
from warnings import warn
from dataclasses import dataclass
import polars as pl
import numpy as np
from typing import List, Optional
from ..dataset import BigDataBowlDataset, Group, Column, Constant
from .graph_settings import (
AmericanFootballGraphSettings,
AmericanFootballPitchDimensions,
)
from .features import (
compute_node_features,
compute_edge_features,
compute_adjacency_matrix,
)
from ...utils import *
[docs]
@dataclass(repr=True)
class AmericanFootballGraphConverter(DefaultGraphConverter):
"""Convert NFL Big Data Bowl tracking data to graph structures for GNN training.
This class transforms American Football tracking data from Polars DataFrames into
graph representations suitable for Graph Neural Networks (GNNs). Each frame becomes
a graph with players and the football as nodes, with edges representing spatial
relationships or team affiliations.
The converter supports two GNN frameworks:
- PyTorch Geometric (recommended) via :meth:`to_pytorch_graphs`
- Spektral (deprecated, Python 3.11 only) via :meth:`to_spektral_graphs`
Graph Structure:
- **Nodes**: Players (22 total: 11 offense + 11 defense) and football
- **Node Features**: Position, velocity, acceleration, orientation, direction,
height, weight, position type (12+ default features)
- **Edges**: Defined by adjacency_matrix_type (team-based, spatial, or dense)
- **Edge Features**: Distances, angles, relative velocities, orientations (7 default features)
- **Global Features**: Optional play-level features attached to football node
- **Labels**: Play outcome or custom labels (e.g., yards gained, tackle probability)
The graph structure captures:
- Offensive and defensive formations
- Player movements and accelerations
- Spatial relationships between players
- Position-specific information (QB, WR, CB, etc.)
- Body orientations and movement directions
- Anthropometric data (height, weight)
Args:
dataset (BigDataBowlDataset): Preprocessed NFL tracking data with player positions,
velocities, and play information.
chunk_size (int, optional): Number of frames to process per batch for memory
efficiency. Defaults to 2000.
attacking_non_qb_node_value (float, optional): Node feature value (0-1) assigned
to offensive players who are not the quarterback. Used to distinguish QB from
other offensive players in node features. Defaults to 0.1.
graph_feature_cols (Optional[List[str]], optional): List of column names containing
graph-level features (e.g., win probability, expected points) to attach to the
football node. These columns must have the same value for all nodes in each frame.
Defaults to None (no graph features).
**kwargs: Additional parameters passed to DefaultGraphConverter, including:
- adjacency_matrix_type: Edge connectivity pattern
- label_col: Column name for graph labels
- graph_id_col: Column name for graph identifiers
- prediction: Whether in prediction mode (no labels)
Attributes:
dataset (pl.DataFrame): Processed tracking data from BigDataBowlDataset.
settings (AmericanFootballGraphSettings): Configuration with pitch dimensions,
adjacency patterns, and feature settings.
label_column (str): Name of the label column for supervised learning.
graph_id_column (str): Name of the graph ID column for batching.
Raises:
Exception: If dataset is not an instance of BigDataBowlDataset.
Exception: If label_column or graph_id_column are not strings.
Exception: If label_column is missing when not in prediction mode.
Exception: If graph_id_column is missing from dataset.
Exception: If attacking_non_qb_node_value is not float or int.
Exception: If frames with missing football or insufficient players are detected.
Example:
>>> from unravel.american_football.dataset import BigDataBowlDataset
>>> from unravel.american_football.graphs import AmericanFootballGraphConverter
>>>
>>> # Load Big Data Bowl data
>>> dataset = BigDataBowlDataset(
... tracking_file_path="tracking.csv",
... players_file_path="players.csv",
... plays_file_path="plays.csv"
... )
>>>
>>> # Add labels and graph IDs
>>> dataset.add_dummy_labels()
>>> dataset.add_graph_ids()
>>>
>>> # Initialize converter
>>> converter = AmericanFootballGraphConverter(
... dataset=dataset,
... adjacency_matrix_type="delaunay",
... label_col="label",
... graph_id_col="graph_id"
... )
>>>
>>> # Convert to PyTorch Geometric format
>>> pyg_dataset = converter.to_pytorch_graphs()
>>> print(f"Number of graphs: {len(pyg_dataset)}")
>>> print(f"Node features: {pyg_dataset[0].x.shape}")
>>> print(f"Edge features: {pyg_dataset[0].edge_attr.shape}")
>>>
>>> # Add graph-level features (e.g., expected points)
>>> # First, join expected points to dataset
>>> dataset.data = dataset.data.join(
... expected_points_df,
... on=["game_id", "play_id"],
... how="left"
... )
>>> converter = AmericanFootballGraphConverter(
... dataset=dataset,
... graph_feature_cols=["expected_points", "win_probability"]
... )
>>> pyg_dataset = converter.to_pytorch_graphs()
Note:
- The converter automatically filters out frames with missing footballs or
insufficient players (< 10 per frame).
- Node ordering: Offensive players (sorted by ascending team_id sort), then
defensive players, then football (always last).
- The QB receives a special node feature value (1.0), while other offensive
players receive attacking_non_qb_node_value (default 0.1).
- Graph-level features must be constant within each frame. If they vary, a
ValueError is raised.
- Position names are encoded as one-hot vectors in node features.
Warning:
Spektral support is deprecated and only works on Python 3.11. Use PyTorch
Geometric for new projects.
See Also:
:class:`~unravel.american_football.dataset.BigDataBowlDataset`: Data loading
and preprocessing.
:meth:`to_pytorch_graphs`: Convert to PyTorch Geometric DataLoader.
:meth:`to_spektral_graphs`: Convert to Spektral format (deprecated).
:doc:`../tutorials/american_football`: Complete tutorial on NFL GNN modeling.
"""
[docs]
def __init__(
self,
dataset: BigDataBowlDataset,
chunk_size: int = 2_000,
attacking_non_qb_node_value: float = 0.1,
graph_feature_cols: Optional[List[str]] = None,
**kwargs,
):
super().__init__(**kwargs)
if not isinstance(dataset, BigDataBowlDataset):
raise Exception("'dataset' should be an instance of BigDataBowlDataset")
self.label_column: str = (
self.label_col if self.label_col is not None else dataset._label_column
)
self.graph_id_column: str = (
self.graph_id_col
if self.graph_id_col is not None
else dataset._graph_id_column
)
self.sample_rate = kwargs.get("sample_rate", None)
self.chunk_size = chunk_size
self.attacking_non_qb_node_value = attacking_non_qb_node_value
self.graph_feature_cols = graph_feature_cols
self.settings = self._apply_graph_settings(settings=dataset.settings)
self.dataset: pl.DataFrame = dataset.data
self._sport_specific_checks()
self._sample()
self._shuffle()
@staticmethod
def _sort(df):
sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - (
(pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID))
& (pl.col(Column.TEAM_ID) != Constant.BALL)
).cast(int)
df = df.sort([*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)])
return df
def _sample(self):
if self.sample_rate is None:
return
else:
self.dataset = self.dataset.filter(
pl.col(Column.FRAME_ID) % (1.0 / self.sample_rate) == 0
)
def _sport_specific_checks(self):
def __remove_with_missing_values(min_object_count: int = 10):
cs = (
self.dataset.group_by(Group.BY_FRAME, maintain_order=True)
.agg(pl.len().alias("size"))
.filter(
pl.col("size") < min_object_count
) # Step 2: Keep groups with size < 10
)
self.dataset = self.dataset.join(cs, on=Group.BY_FRAME, how="anti")
if len(cs) > 0:
warn(
f"Removed {len(cs)} frames with less than {min_object_count} objects...",
UserWarning,
)
def __remove_with_missing_football():
cs = (
self.dataset.group_by(Group.BY_FRAME, maintain_order=True)
.agg(
[
pl.len().alias("size"), # Count total rows in each group
pl.col(Column.TEAM_ID)
.filter(pl.col(Column.TEAM_ID) == Constant.BALL)
.count()
.alias("football_count"), # Count rows where team == 'football'
]
)
.filter(
(pl.col("football_count") == 0)
) # Step 2: Keep groups with size < 10 and no "football"
)
self.dataset = self.dataset.join(cs, on=Group.BY_FRAME, how="anti")
if len(cs) > 0:
warn(
f"Removed {len(cs)} frames with a missing '{Constant.BALL}' object...",
UserWarning,
)
if not isinstance(self.label_column, str):
raise Exception("'label_col' should be of type string (str)")
if not isinstance(self.graph_id_column, str):
raise Exception("'graph_id_col' should be of type string (str)")
if not isinstance(self.chunk_size, int):
raise Exception("chunk_size should be of type integer (int)")
if not self.label_column in self.dataset.columns and not self.prediction:
raise Exception(
"Please specify a 'label_col' and add that column to your 'dataset' or set 'prediction=True' if you want to use the converted dataset to make predictions on."
)
if not self.graph_id_column in self.dataset.columns:
raise Exception(
"Please specify a 'graph_id_col' and add that column to your 'dataset' ..."
)
# Parameter Checks
if not isinstance(self.attacking_non_qb_node_value, (int, float)):
raise Exception(
"'attacking_non_qb_node_value' should be of type float or integer (int)"
)
__remove_with_missing_values(min_object_count=10)
__remove_with_missing_football()
def _apply_graph_settings(self, settings):
return AmericanFootballGraphSettings(
pitch_dimensions=settings.pitch_dimensions,
max_player_speed=settings.max_player_speed,
max_ball_speed=settings.max_ball_speed,
max_ball_acceleration=settings.max_ball_acceleration,
max_player_acceleration=settings.max_player_acceleration,
self_loop_ball=self.self_loop_ball,
adjacency_matrix_connect_type=self.adjacency_matrix_connect_type,
adjacency_matrix_type=self.adjacency_matrix_type,
label_type=self.label_type,
defending_team_node_value=self.defending_team_node_value,
attacking_non_qb_node_value=self.attacking_non_qb_node_value,
random_seed=self.random_seed,
pad=self.pad,
verbose=self.verbose,
)
@property
def _exprs_variables(self):
exprs_variables = [
Column.X,
Column.Y,
Column.SPEED,
Column.ACCELERATION,
Column.ORIENTATION,
Column.DIRECTION,
Column.TEAM_ID,
Column.POSITION_NAME,
Column.BALL_OWNING_TEAM_ID,
Column.HEIGHT_CM,
Column.WEIGHT_KG,
self.graph_id_column,
self.label_column,
]
exprs = (
exprs_variables
if self.graph_feature_cols is None
else exprs_variables + self.graph_feature_cols
)
return exprs
def _compute(self, args: List[pl.Series]) -> dict:
d = {col: args[i].to_numpy() for i, col in enumerate(self._exprs_variables)}
frame_id = args[-1][0]
if self.graph_feature_cols is not None:
failed = [
col
for col in self.graph_feature_cols
if not np.all(d[col] == d[col][0])
]
if failed:
raise ValueError(
f"""graph_feature_cols contains multiple different values for a group in the groupby ({Group.BY_FRAME}) selection for the columns {failed}. Make sure each group has the same values per individual column."""
)
graph_features = (
np.asarray([d[col] for col in self.graph_feature_cols]).T[0]
if self.graph_feature_cols
else None
)
if not np.all(d[self.graph_id_column] == d[self.graph_id_column][0]):
raise Exception(
"GraphId selection contains multiple different values. Make sure each graph_id is unique by at least game_id and frame_id..."
)
if not self.prediction and not np.all(
d[self.label_column] == d[self.label_column][0]
):
raise Exception(
"""Label selection contains multiple different values for a single selection (group by) of game_id and frame_id,
make sure this is not the case. Each group can only have 1 label."""
)
adjacency_matrix = compute_adjacency_matrix(
team=d[Column.TEAM_ID],
possession_team=d[Column.BALL_OWNING_TEAM_ID],
settings=self.settings,
)
edge_features = compute_edge_features(
adjacency_matrix=adjacency_matrix,
p=np.stack((d[Column.X], d[Column.Y]), axis=-1),
s=d[Column.SPEED],
a=d[Column.ACCELERATION],
dir=d[Column.DIRECTION],
o=d[Column.ORIENTATION],
team=d[Column.TEAM_ID],
settings=self.settings,
)
node_features = compute_node_features(
x=d[Column.X],
y=d[Column.Y],
s=d[Column.SPEED],
a=d[Column.ACCELERATION],
dir=d[Column.DIRECTION],
o=d[Column.ORIENTATION],
team=d[Column.TEAM_ID],
official_position=d[Column.POSITION_NAME],
possession_team=d[Column.BALL_OWNING_TEAM_ID],
height=d[Column.HEIGHT_CM],
weight=d[Column.WEIGHT_KG],
graph_features=graph_features,
settings=self.settings,
)
return {
"e": edge_features.tolist(), # Remove pl.Series wrapper
"x": node_features.tolist(), # Remove pl.Series wrapper
"a": adjacency_matrix.tolist(), # Remove pl.Series wrapper
"e_shape_0": edge_features.shape[0],
"e_shape_1": edge_features.shape[1],
"x_shape_0": node_features.shape[0],
"x_shape_1": node_features.shape[1],
"a_shape_0": adjacency_matrix.shape[0],
"a_shape_1": adjacency_matrix.shape[1],
self.graph_id_column: d[self.graph_id_column][0],
self.label_column: d[self.label_column][0],
"frame_id": frame_id,
}
@property
def return_dtypes(self):
return pl.Struct(
{
"e": pl.List(pl.List(pl.Float64)),
"x": pl.List(pl.List(pl.Float64)),
"a": pl.List(pl.List(pl.Int32)),
"e_shape_0": pl.Int64,
"e_shape_1": pl.Int64,
"x_shape_0": pl.Int64,
"x_shape_1": pl.Int64,
"a_shape_0": pl.Int64,
"a_shape_1": pl.Int64,
self.graph_id_column: pl.String,
self.label_column: pl.Int64,
"frame_id": pl.Int64,
}
)
def _convert(self):
# Group and aggregate in one step
return (
self.dataset.group_by(Group.BY_FRAME, maintain_order=True)
.agg(
pl.map_groups(
exprs=self._exprs_variables + [Column.FRAME_ID],
function=self._compute,
return_dtype=self.return_dtypes,
returns_scalar=True,
).alias("result_dict")
)
.with_columns(
[
*[
pl.col("result_dict").struct.field(f).alias(f)
for f in [
"a",
"e",
"x",
self.graph_id_column,
self.label_column,
"frame_id",
]
],
*[
pl.col("result_dict")
.struct.field(f"{m}_shape_{i}")
.alias(f"{m}_shape_{i}")
for m in ["a", "e", "x"]
for i in [0, 1]
],
]
)
.drop("result_dict")
)