import logging
import sys
from dataclasses import dataclass, field
from typing import List, Union, Dict, Literal, Any, Optional, Callable, TYPE_CHECKING
import inspect
import pathlib
from kloppy.domain import MetricPitchDimensions, Orientation
if TYPE_CHECKING:
from spektral.data import Graph
from .graph_settings import GraphSettingsPolars
from ..dataset.kloppy_polars import KloppyPolarsDataset, Column, Group, Constant
from .features import (
compute_node_features,
add_global_features,
compute_adjacency_matrix,
compute_edge_features,
)
from ...utils import *
# Module-level logger for this converter: level DEBUG, with a handler that
# writes records to standard output.
logger = logging.getLogger(name=__name__)
stdout_handler = logging.StreamHandler(stream=sys.stdout)
logger.addHandler(stdout_handler)
logger.setLevel(level=logging.DEBUG)
[docs]
@dataclass(repr=True)
class SoccerGraphConverter(DefaultGraphConverter):
    """Convert soccer tracking data from Polars DataFrame to graph structures for GNN training.

    This class transforms soccer tracking data into graph representations suitable for
    Graph Neural Networks. Each frame of tracking data becomes a graph with players and
    the ball as nodes, with edges representing spatial relationships or team affiliations.

    The converter supports two GNN frameworks:

    - PyTorch Geometric (recommended) via :meth:`to_pytorch_graphs`
    - Spektral (deprecated, Python 3.11 only) via :meth:`to_spektral_graphs`

    Graph Structure:
        - **Nodes**: Players (home team, away team) and ball
        - **Node Features**: Normalized position, speed and velocity, distance and
          angle to goal and ball, plus possession-team / goalkeeper / ball /
          ball-carrier flags (12 default node feature functions)
        - **Edges**: Defined by adjacency_matrix_type (team-based, spatial, or dense)
        - **Edge Features**: Normalized distances, speed differences, angles and
          velocity differences between objects (4 default edge feature functions)
        - **Global Features**: Optional graph-level features from
          ``global_feature_cols``, attached to the ball node
          (``global_feature_type="ball"``) or to all nodes (``"all"``)

    Key Features:
        - Configurable node and edge feature engineering
        - Multiple adjacency matrix types (split_by_team, delaunay, dense)
        - Custom feature functions via decorators
        - Automatic padding for fixed-size graphs
        - Ball connection strategies (all players, carrier only, none)
        - Permutation invariance via random node ordering

    Args:
        dataset (KloppyPolarsDataset): Polars dataset with tracking data. Must have
            been processed with :meth:`~unravel.soccer.KloppyPolarsDataset.add_graph_ids`
            and optionally :meth:`~unravel.soccer.KloppyPolarsDataset.add_dummy_labels`.
        chunk_size (int, optional): Number of graphs to process simultaneously. Higher
            values use more memory but may be faster. Defaults to 20000.
        non_potential_receiver_node_value (float, optional): Node feature value (0-1)
            assigned to defending team players. Used to distinguish attackers from
            defenders. Defaults to 0.1.
        edge_feature_funcs (List[Callable], optional): Custom edge feature functions
            decorated with ``@graph_feature(type="edge")``. If None or empty, uses
            defaults. Defaults to None.
        node_feature_funcs (List[Callable], optional): Custom node feature functions
            decorated with ``@graph_feature(type="node")``. If None or empty, uses
            defaults. Defaults to None.
        global_feature_cols (List[str], optional): Column names from the dataset to use
            as graph-level features (e.g., match score, team ratings). Must be constant
            within each graph_id group. Defaults to empty list.
        global_feature_type (Literal["ball", "all"], optional): Where to attach global
            features. "ball" attaches to ball node only, "all" attaches to all nodes.
            Defaults to "ball".
        additional_feature_cols (List[str], optional): Extra columns from dataset to
            make available to custom feature functions (e.g., player height, position).
            Defaults to empty list.

    Attributes:
        settings (GraphSettingsPolars): Configuration for graph conversion including
            adjacency matrix type, padding, and feature settings.
        n_node_features (int): Total number of node features per node.
        n_edge_features (int): Total number of edge features per edge.
        n_graph_features (int): Total number of global/graph-level features.

    Raises:
        ValueError: If dataset is not a KloppyPolarsDataset.
        ValueError: If required columns (graph_id, label) are missing.
        ValueError: If custom feature functions are not properly decorated.

    Example:
        >>> from unravel.soccer import KloppyPolarsDataset, SoccerGraphConverter
        >>> from kloppy import sportec
        >>>
        >>> # Load and prepare data
        >>> kloppy_dataset = sportec.load_open_tracking_data(only_alive=True)
        >>> polars_dataset = KloppyPolarsDataset(kloppy_dataset=kloppy_dataset)
        >>> polars_dataset.add_dummy_labels(by=["frame_id"])
        >>> polars_dataset.add_graph_ids(by=["frame_id"])
        >>>
        >>> # Create converter
        >>> converter = SoccerGraphConverter(
        ...     dataset=polars_dataset,
        ...     self_loop_ball=True,
        ...     adjacency_matrix_connect_type="ball",
        ...     adjacency_matrix_type="split_by_team",
        ...     label_type="binary",
        ... )
        >>>
        >>> # Convert to PyTorch Geometric format
        >>> graphs = converter.to_pytorch_graphs()
        >>> print(f"Created {len(graphs)} graphs")
        >>> print(f"Node features: {converter.n_node_features}")
        >>> print(f"Edge features: {converter.n_edge_features}")

    Note:
        For detailed configuration options, see :class:`~unravel.soccer.GraphSettingsPolars`.
        For custom features, see :func:`~unravel.utils.features.graph_feature` decorator.

    Warning:
        With ``pad=False``, frames that do not contain exactly three distinct
        ``team_id`` values (both teams and the ball) are dropped, and the remaining
        graphs can have a variable number of nodes. With ``pad=True``, teams are
        padded to 11 players and a missing ball row is added; frames that still do
        not consist of 11 + 11 players and 1 ball may be dropped.

    See Also:
        :class:`~unravel.soccer.KloppyPolarsDataset`: Prepare tracking data.
        :class:`~unravel.utils.GraphDataset`: Wrap graphs for training.
        :func:`~unravel.utils.features.graph_feature`: Create custom features.
        :doc:`../tutorials/soccer_gnn`: Complete GNN training tutorial.
        `Graph FAQ <https://github.com/unravelsports/unravelsports/blob/main/examples/graphs_faq.md>`_: Detailed configuration guide.
    """

    # Replaced by its underlying Polars DataFrame (``.data``) in __post_init__.
    dataset: KloppyPolarsDataset = None
    chunk_size: int = 20_000
    non_potential_receiver_node_value: float = 0.1
    edge_feature_funcs: List[Callable[[Dict[str, Any]], np.ndarray]] = field(
        repr=False, default_factory=list
    )
    node_feature_funcs: List[Callable[[Dict[str, Any]], np.ndarray]] = field(
        repr=False, default_factory=list
    )
    global_feature_cols: Optional[List[str]] = field(repr=False, default_factory=list)
    global_feature_type: Literal["ball", "all"] = "ball"
    additional_feature_cols: Optional[List[str]] = field(
        repr=False, default_factory=list
    )
    # Feature name -> width, overwritten per frame by _compute from
    # compute_edge_features; not a constructor argument.
    _edge_feature_dims: Dict[str, int] = field(
        repr=False, default_factory=dict, init=False
    )
    # Feature name -> width, overwritten per frame by _compute from
    # compute_node_features (plus one entry of width 1 per global feature column).
    _node_feature_dims: Dict[str, int] = field(
        repr=False, default_factory=dict, init=False
    )
def __post_init__(self):
    """Validate the inputs and prepare ``self.dataset`` for graph conversion.

    Resolves the label and graph-id column names (an explicit ``label_col`` /
    ``graph_id_col`` takes precedence over the dataset's own), swaps the
    ``KloppyPolarsDataset`` for its Polars DataFrame, falls back to the default
    feature functions when none were given, runs validation, builds
    ``self.settings`` and then pads (``pad=True``) or drops incomplete frames
    before sampling and shuffling.

    Raises:
        ValueError: If ``dataset`` is not a ``KloppyPolarsDataset``.
    """
    if not isinstance(self.dataset, KloppyPolarsDataset):
        raise ValueError("dataset should be of type KloppyPolarsDataset...")

    source = self.dataset
    self.pitch_dimensions: MetricPitchDimensions = source.settings.pitch_dimensions
    self._kloppy_settings = source.settings

    resolved_label = (
        source._label_column if self.label_col is None else self.label_col
    )
    self.label_column: str = resolved_label
    resolved_graph_id = (
        source._graph_id_column if self.graph_id_col is None else self.graph_id_col
    )
    self.graph_id_column: str = resolved_graph_id

    # From here on the converter works on the raw Polars frame.
    self.dataset = source.data

    self.edge_feature_funcs = (
        self.edge_feature_funcs or self.default_edge_feature_funcs
    )
    self._verify_feature_funcs(self.edge_feature_funcs, feature_type="edge")
    self.node_feature_funcs = (
        self.node_feature_funcs or self.default_node_feature_funcs
    )
    self._verify_feature_funcs(self.node_feature_funcs, feature_type="node")

    self._sport_specific_checks()
    self.settings = self._apply_graph_settings()

    self.dataset = (
        self._apply_padding() if self.pad else self._remove_incomplete_frames()
    )
    self._sample()
    self._shuffle()
def _sample(self):
    """Downsample ``self.dataset`` according to ``sample_rate``.

    Replaces ``self.dataset`` with only the rows whose ``frame_id`` modulo
    ``1.0 / sample_rate`` equals 0. Does nothing when ``sample_rate`` is None.
    """
    if self.sample_rate is not None:
        stride = 1.0 / self.sample_rate
        on_stride = pl.col(Column.FRAME_ID) % stride == 0
        self.dataset = self.dataset.filter(on_stride)
@staticmethod
def _sort(df):
    """Return ``df`` sorted by frame, then ball-owning team, other team, ball.

    Rows are ordered by ``Group.BY_FRAME``, then by a rank (ball-owning team
    rows -1, other non-ball rows 0, ball rows 2), then by ``object_id``. If the
    ball-owning-team comparison is null for a row, its rank is null and its
    position follows Polars' default null ordering.
    """
    team = pl.col(Column.TEAM_ID)
    ball_weight = (team == Constant.BALL).cast(int) * 2
    owning_player = (pl.col(Column.BALL_OWNING_TEAM_ID) == team) & (
        team != Constant.BALL
    )
    order_rank = ball_weight - owning_player.cast(int)
    return df.sort([*Group.BY_FRAME, order_rank, pl.col(Column.OBJECT_ID)])
def _remove_incomplete_frames(self) -> pl.DataFrame:
    """Keep only frames that contain exactly three distinct ``team_id`` values.

    Frames are identified by ``Group.BY_FRAME``. When frames were removed and
    ``verbose`` is set, a warning reports how many.

    Returns:
        pl.DataFrame: ``self.dataset`` restricted to the retained frames.
    """
    tracking = self.dataset
    frame_count = tracking.select(Group.BY_FRAME).unique().height

    teams_per_frame = tracking.group_by(Group.BY_FRAME, maintain_order=True).agg(
        pl.col(Column.TEAM_ID).n_unique().alias("unique_teams")
    )
    retained = teams_per_frame.filter(pl.col("unique_teams") == 3).select(
        Group.BY_FRAME
    )

    removed = frame_count - retained.height
    if removed > 0 and self.verbose:
        self.__warn_dropped_frames(removed, frame_count)
    return tracking.join(retained, on=Group.BY_FRAME)
def _apply_padding(self) -> pl.DataFrame:
    """Pad frames to 11 players per team plus one ball, then drop frames that still don't fit.

    Rows are grouped by (game, period, frame, team, ball-owning team). Every
    group with fewer rows than its target (1 for the ball, 11 otherwise) gets
    padded rows that copy its frame-level columns (timestamp, ball state,
    label, graph id and global feature columns). Every frame without a ball
    row gets one. On padded rows the object-level columns are filled with
    ``create_default_expression`` and all other user-defined columns are null.
    A team with no rows at all in a frame has no group and is not padded.

    Afterwards only frames with exactly 1 ball row, exactly 11 rows of the
    ball-owning team and exactly 11 other rows are kept. This check always
    runs, even when nothing needed padding, so the retained frames do not
    depend on the rest of the dataset. A warning is emitted (when
    ``verbose``) if frames were dropped.

    Returns:
        pl.DataFrame: Padded and filtered tracking data with exactly the
        columns of ``self.dataset``.
    """
    df = self.dataset
    schema = df.schema

    # Frame-level values copied onto every padded row.
    keep_columns = [
        Column.TIMESTAMP,
        Column.BALL_STATE,
        self.label_column,
        self.graph_id_column,
    ]
    # Object-level values that padded rows fill via create_default_expression.
    empty_columns = [
        Column.POSITION_NAME,
        Column.OBJECT_ID,
        Column.IS_BALL_CARRIER,
        Column.X,
        Column.Y,
        Column.Z,
        Column.VX,
        Column.VY,
        Column.VZ,
        Column.SPEED,
        Column.AX,
        Column.AY,
        Column.AZ,
        Column.ACCELERATION,
    ]
    # One padding group per team (and the ball) within a frame.
    group_by_columns = [
        Column.GAME_ID,
        Column.PERIOD_ID,
        Column.FRAME_ID,
        Column.TEAM_ID,
        Column.BALL_OWNING_TEAM_ID,
    ]
    # Any other column (e.g. additional feature columns) is null on padded rows.
    handled_columns = (
        keep_columns + group_by_columns + empty_columns + self.global_feature_cols
    )
    user_defined_columns = [col for col in df.columns if col not in handled_columns]

    counts = df.group_by(group_by_columns, maintain_order=True).agg(
        pl.len().alias("count"),
        *[
            pl.first(col).alias(col)
            for col in keep_columns + self.global_feature_cols
        ],
    )
    # The ball group needs exactly one row; every team group needs 11.
    counts = counts.with_columns(
        pl.when(pl.col(Column.TEAM_ID) == Constant.BALL)
        .then(1)
        .otherwise(11)
        .alias("target_length")
    )
    groups_to_pad = counts.filter(
        pl.col("count") < pl.col("target_length")
    ).with_columns((pl.col("target_length") - pl.col("count")).alias("repeats"))

    padding_rows = []
    # Pad players. An existing ball group always has count >= 1, so it never
    # qualifies; frames without any ball row are handled below.
    player_base_columns = keep_columns + group_by_columns + self.global_feature_cols
    for row in groups_to_pad.iter_rows(named=True):
        base_row = {col: row[col] for col in player_base_columns}
        padding_rows.extend([base_row] * row["repeats"])

    # Add a ball row to every frame that has none.
    frame_keys = [Column.GAME_ID, Column.PERIOD_ID, Column.FRAME_ID]
    all_frames = df.select(
        frame_keys
        + [Column.BALL_OWNING_TEAM_ID]
        + keep_columns
        + self.global_feature_cols
    ).unique()
    frames_with_ball = (
        df.filter(pl.col(Column.TEAM_ID) == Constant.BALL).select(frame_keys).unique()
    )
    frames_missing_ball = all_frames.join(frames_with_ball, on=frame_keys, how="anti")
    if frames_missing_ball.height > 0:
        ball_rows_to_add = frames_missing_ball.with_columns(
            [
                pl.lit(Constant.BALL).alias(Column.TEAM_ID),
                pl.lit(Constant.BALL).alias(Column.POSITION_NAME),
            ]
        )
        ball_base_columns = (
            keep_columns
            + group_by_columns
            + [Column.POSITION_NAME]
            + self.global_feature_cols
        )
        for row in ball_rows_to_add.iter_rows(named=True):
            padding_rows.append(
                {col: row[col] for col in ball_base_columns if col in row}
            )

    if padding_rows:
        # Ball rows carry POSITION_NAME while player rows do not; scan every row
        # so keys first seen late are not dropped by a limited inference window.
        padding_df = pl.DataFrame(padding_rows, infer_schema_length=None)
        padding_df = padding_df.with_columns(
            [create_default_expression(col, schema[col]) for col in empty_columns]
            + [
                pl.lit(None).cast(schema[col]).alias(col)
                for col in user_defined_columns
            ]
        )
        # Match dtypes and column order of the original frame before stacking.
        padding_df = padding_df.with_columns(
            [pl.col(col).cast(schema[col]).alias(col) for col in df.columns]
        ).select(df.columns)
        result = pl.concat([df, padding_df], how="vertical")
    else:
        # Nothing to pad, but the completeness filter below still applies.
        result = df

    total_frames = result.select(Group.BY_FRAME).unique().height
    is_ball = pl.col(Column.TEAM_ID).eq(Constant.BALL)
    is_owning = pl.col(Column.TEAM_ID).eq(pl.col(Column.BALL_OWNING_TEAM_ID))
    complete_frames = (
        result.group_by(Group.BY_FRAME, maintain_order=True)
        .agg(
            [
                (is_ball.sum() == 1).alias("has_ball"),
                (is_owning.sum() == 11).alias("has_owning_team"),
                ((~is_ball & ~is_owning).sum() == 11).alias("has_other_team"),
            ]
        )
        .filter(
            pl.col("has_ball")
            & pl.col("has_owning_team")
            & pl.col("has_other_team")
        )
        # Keep only the frame keys so the helper flags are not joined into the data.
        .select(Group.BY_FRAME)
    )

    dropped_frames = total_frames - complete_frames.height
    if dropped_frames > 0 and self.verbose:
        self.__warn_dropped_frames(dropped_frames, total_frames)
    return result.join(complete_frames, on=Group.BY_FRAME, how="inner")
@staticmethod
def __warn_dropped_frames(dropped_frames, total_frames):
    """Emit a warning that incomplete frames were removed.

    Used by both the padding (``pad=True``) and the non-padding
    (``pad=False``) paths, so the message describes both filters.

    Args:
        dropped_frames (int): Number of frames that were removed.
        total_frames (int): Number of frames before filtering.
    """
    import warnings

    # Guard the percentage so a zero total can never raise ZeroDivisionError.
    percentage = (dropped_frames / total_frames) * 100 if total_frames else 0.0
    warnings.warn(
        "Incomplete frames were dropped. With pad=False a frame is kept only if it "
        "contains exactly three distinct team_id values (both teams and the ball); "
        "with pad=True a frame is kept only if, after padding, it contains exactly "
        "1 ball and 11 players for each team.\n"
        f"This operation dropped {dropped_frames} incomplete frames out of "
        f"{total_frames} total frames ({percentage:.2f}%)"
    )
def _apply_graph_settings(self):
    """Build the ``GraphSettingsPolars`` used during conversion.

    Team ids, players and orientation come from the kloppy dataset settings,
    feature names from the configured feature functions, and the speed /
    acceleration limits from the current ``self.settings``. All remaining
    options are copied from this converter's attributes.
    """
    kloppy_settings = self._kloppy_settings
    current_settings = self.settings
    feature_names = {
        "edge": [func.__name__ for func in self.edge_feature_funcs],
        "node": [func.__name__ for func in self.node_feature_funcs],
        "global": self.global_feature_cols,
    }
    return GraphSettingsPolars(
        home_team_id=str(kloppy_settings.home_team_id),
        away_team_id=str(kloppy_settings.away_team_id),
        players=kloppy_settings.players,
        features=feature_names,
        orientation=kloppy_settings.orientation,
        pitch_dimensions=self.pitch_dimensions,
        max_player_speed=current_settings.max_player_speed,
        max_ball_speed=current_settings.max_ball_speed,
        max_player_acceleration=current_settings.max_player_acceleration,
        max_ball_acceleration=current_settings.max_ball_acceleration,
        self_loop_ball=self.self_loop_ball,
        adjacency_matrix_connect_type=self.adjacency_matrix_connect_type,
        adjacency_matrix_type=self.adjacency_matrix_type,
        label_type=self.label_type,
        defending_team_node_value=self.defending_team_node_value,
        non_potential_receiver_node_value=self.non_potential_receiver_node_value,
        random_seed=self.random_seed,
        pad=self.pad,
        verbose=self.verbose,
    )
def _sport_specific_checks(self):
    """Validate converter options and the presence of required columns.

    When the label column is missing and ``prediction`` is True, a null label
    column is added to ``self.dataset`` instead of raising.

    Raises:
        TypeError: If ``label_col`` / ``graph_id_col`` are not strings,
            ``chunk_size`` is not an int, or a truthy
            ``non_potential_receiver_node_value`` is not a float.
        ValueError: If the label column is missing (and ``prediction`` is
            False) or the graph id column is missing.
    """
    # Both exception types subclass Exception, so existing broad handlers still work.
    if not isinstance(self.label_column, str):
        raise TypeError("'label_col' should be of type string (str)")
    if not isinstance(self.graph_id_column, str):
        raise TypeError("'graph_id_col' should be of type string (str)")
    if not isinstance(self.chunk_size, int):
        raise TypeError("chunk_size should be of type integer (int)")

    if self.label_column not in self.dataset.columns:
        if not self.prediction:
            raise ValueError(
                "Please specify a 'label_col' and add that column to your 'dataset' or set 'prediction=True' if you want to use the converted dataset to make predictions on."
            )
        # Prediction mode: labels are unknown, add a null placeholder column.
        self.dataset = self.dataset.with_columns(
            pl.lit(None).alias(self.label_column)
        )

    if self.graph_id_column not in self.dataset.columns:
        raise ValueError(
            "Please specify a 'graph_id_col' and add that column to your 'dataset' ..."
        )
    # Falsy values (None, 0) are accepted without a type check.
    if self.non_potential_receiver_node_value and not isinstance(
        self.non_potential_receiver_node_value, float
    ):
        raise TypeError(
            "'non_potential_receiver_node_value' should be of type float"
        )
@property
def _exprs_variables(self):
    """Column names handed, in this order, to ``_compute`` for every frame.

    The fixed tracking columns come first, then the graph id and label
    columns, then ``global_feature_cols`` and ``additional_feature_cols``.
    A new list is built on every access.
    """
    tracking_columns = [
        Column.X,
        Column.Y,
        Column.Z,
        Column.SPEED,
        Column.VX,
        Column.VY,
        Column.VZ,
        Column.ACCELERATION,
        Column.AX,
        Column.AY,
        Column.AZ,
        Column.TEAM_ID,
        Column.POSITION_NAME,
        Column.BALL_OWNING_TEAM_ID,
        Column.IS_BALL_CARRIER,
        Column.OBJECT_ID,
    ]
    identifier_columns = [self.graph_id_column, self.label_column]
    return (
        tracking_columns
        + identifier_columns
        + self.global_feature_cols
        + self.additional_feature_cols
    )
@property
def default_node_feature_funcs(self) -> list:
    """Node feature functions used when no ``node_feature_funcs`` are given.

    Returns a fresh list on every access.
    """
    funcs = (
        x_normed,
        y_normed,
        speeds_normed,
        velocity_components_2d_normed,
        distance_to_goal_normed,
        distance_to_ball_normed,
        is_possession_team,
        is_gk,
        is_ball,
        angle_to_goal_components_2d_normed,
        angle_to_ball_components_2d_normed,
        is_ball_carrier,
    )
    return list(funcs)
@property
def default_edge_feature_funcs(self) -> list:
    """Edge feature functions used when no ``edge_feature_funcs`` are given.

    Returns a fresh list on every access.
    """
    funcs = (
        distances_between_players_normed,
        speed_difference_normed,
        angle_between_players_normed,
        velocity_difference_normed,
    )
    return list(funcs)
def __add_additional_kwargs(self, d):
    """Add derived per-frame entries to the column dictionary ``d``.

    Adds ``ball_id``, ``possession_team_id``, ``is_gk``, ``position``,
    ``velocity``, ``ball_position``, ``ball_idx`` and ``ball_carrier_idx``.
    ``d`` is modified in place and also returned.

    Args:
        d (dict): Column name -> numpy array for the objects of one frame, as
            built in ``_compute`` from ``_exprs_variables``.

    Returns:
        dict: The same dictionary ``d``. ``ball_idx`` is the index array of the
        ball row(s) when present and the int 0 otherwise; ``ball_carrier_idx``
        is the first ball-carrier index, or None if there is none.
    """
    d["ball_id"] = Constant.BALL
    # Taken from the first row of the frame.
    d["possession_team_id"] = d[Column.BALL_OWNING_TEAM_ID][0]
    d["is_gk"] = np.where(
        d[Column.POSITION_NAME] == self.settings.goalkeeper_id, True, False
    )
    # NaN -> 1e-10 and +/-inf -> +/-1e3 so the stacked (n, 3) arrays stay finite.
    d["position"] = np.nan_to_num(
        np.stack((d[Column.X], d[Column.Y], d[Column.Z]), axis=-1),
        nan=1e-10,
        posinf=1e3,
        neginf=-1e3,
    )
    d["velocity"] = np.nan_to_num(
        np.stack((d[Column.VX], d[Column.VY], d[Column.VZ]), axis=-1),
        nan=1e-10,
        posinf=1e3,
        neginf=-1e3,
    )

    # Look the team column up via the Column constant like every other access,
    # and compute the ball mask only once.
    ball_rows = np.where(d[Column.TEAM_ID] == d["ball_id"])[0]
    if len(ball_rows) >= 1:
        ball_index = ball_rows
        ball_position = d["position"][ball_index][0]
    else:
        # No ball row in this frame: fall back to the origin and index 0.
        ball_position = np.asarray([0.0, 0.0, 0.0])
        ball_index = 0

    # Element-wise comparison; `is True` would not broadcast over the array.
    ball_carriers = np.where(d[Column.IS_BALL_CARRIER] == True)[0]  # noqa: E712
    ball_carrier_idx = ball_carriers[0] if len(ball_carriers) > 0 else None

    d["ball_position"] = ball_position
    d["ball_idx"] = ball_index
    d["ball_carrier_idx"] = ball_carrier_idx
    return d
def _compute(self, args: List[pl.Series]) -> dict:
    """Turn the column series of a single frame into one graph record.

    ``args`` holds one series per entry of ``_exprs_variables`` followed by the
    frame-id series (as passed by ``_convert``). The method builds the
    adjacency matrix, edge features and node features (plus optional global
    features). It returns them as nested Python lists together with their
    shapes and the frame's identifiers.

    Raises:
        ValueError: If the graph id, the label (unless ``prediction``) or any
            global feature column is not constant within the frame.
    """

    def _is_constant(values):
        # True when every element equals the first one.
        return np.all(values == values[0])

    frame_data: dict = {
        name: args[position].to_numpy()
        for position, name in enumerate(self._exprs_variables)
    }
    frame_data = self.__add_additional_kwargs(frame_data)

    frame_id = args[-1][0]
    ball_owning_team_id = frame_data[Column.BALL_OWNING_TEAM_ID][0]

    if not _is_constant(frame_data[self.graph_id_column]):
        raise ValueError(
            "graph_id selection contains multiple different values. Make sure each graph_id is unique by at least game_id and frame_id..."
        )
    if not self.prediction and not _is_constant(frame_data[self.label_column]):
        raise ValueError(
            "Label selection contains multiple different values for a single selection (group by) of game_id and frame_id,\n"
            "make sure this is not the case. Each group can only have 1 label."
        )

    adjacency_matrix = compute_adjacency_matrix(settings=self.settings, **frame_data)
    edge_features, self._edge_feature_dims = compute_edge_features(
        adjacency_matrix=adjacency_matrix,
        funcs=self.edge_feature_funcs,
        opts=self.feature_opts,
        settings=self.settings,
        **frame_data,
    )
    node_features, self._node_feature_dims = compute_node_features(
        funcs=self.node_feature_funcs,
        opts=self.feature_opts,
        settings=self.settings,
        **frame_data,
    )

    global_features = None
    if self.global_feature_cols:
        failed = [
            col
            for col in self.global_feature_cols
            if not _is_constant(frame_data[col])
        ]
        if failed:
            raise ValueError(
                f"""graph_feature_cols contains multiple different values for a group in the groupby ({Group.BY_FRAME}) selection for the columns {failed}. Make sure each group has the same values per individual column."""
            )
        # One value per global column, taken from the first object of the frame.
        global_features = np.asarray(
            [frame_data[col] for col in self.global_feature_cols]
        ).T[0]
    for col in self.global_feature_cols:
        self._node_feature_dims[col] = 1

    node_features = add_global_features(
        node_features=node_features,
        global_features=global_features,
        global_feature_type=self.global_feature_type,
        **frame_data,
    )

    matrices = {"e": edge_features, "x": node_features, "a": adjacency_matrix}
    # .tolist() turns the numpy arrays into nested Python lists.
    record = {key: matrix.tolist() for key, matrix in matrices.items()}
    for key, matrix in matrices.items():
        record[f"{key}_shape_0"] = matrix.shape[0]
        record[f"{key}_shape_1"] = matrix.shape[1]
    record[self.graph_id_column] = frame_data[self.graph_id_column][0]
    record[self.label_column] = frame_data[self.label_column][0]
    record["frame_id"] = frame_id
    record["object_ids"] = frame_data[Column.OBJECT_ID].tolist()
    record["ball_owning_team_id"] = ball_owning_team_id
    return record
def _convert(self):
    """Compute one graph record per frame and unpack it into columns.

    Groups ``self.dataset`` by ``Group.BY_FRAME`` (keeping order), runs
    ``_compute`` on each group and expands the resulting struct. The output
    has the group keys plus ``a``, ``e``, ``x``, the graph id and label
    columns, ``frame_id``, ``object_ids``, ``ball_owning_team_id`` and the six
    ``{a,e,x}_shape_{0,1}`` columns.
    """
    value_fields = [
        "a",
        "e",
        "x",
        self.graph_id_column,
        self.label_column,
        "frame_id",
        "object_ids",
        "ball_owning_team_id",
    ]
    shape_fields = [
        f"{matrix}_shape_{axis}" for matrix in ["a", "e", "x"] for axis in [0, 1]
    ]

    grouped = self.dataset.group_by(Group.BY_FRAME, maintain_order=True)
    graph_struct = pl.map_groups(
        exprs=self._exprs_variables + [Column.FRAME_ID],
        function=self._compute,
        return_dtype=self.return_dtypes,
        returns_scalar=True,
    ).alias("result_dict")

    unpacked = [
        pl.col("result_dict").struct.field(name).alias(name)
        for name in value_fields + shape_fields
    ]
    return grouped.agg(graph_struct).with_columns(unpacked).drop("result_dict")
[docs]
def get_players_by_team_id(self, team_id):
    """Return all player dicts from ``settings.players`` with a matching ``team_id``.

    Args:
        team_id: Value compared against each player's ``"team_id"`` entry.

    Returns:
        list: Matching player dicts in ``settings.players`` order (empty if none).
    """

    def _on_team(player):
        return player["team_id"] == team_id

    return list(filter(_on_team, self.settings.players))
[docs]
def get_player_by_id(self, player_id):
    """Return the first player dict in ``settings.players`` with this ``player_id``.

    Args:
        player_id: Value compared against each player's ``"player_id"`` entry.

    Returns:
        dict | None: The first matching player, or None when there is no match.
    """
    matches = (
        player
        for player in self.settings.players
        if player["player_id"] == player_id
    )
    return next(matches, None)
[docs]
def plot(
self,
file_path: str,
fps: int = None,
timestamp: pl.duration = None,
end_timestamp: pl.duration = None,
period_id: int = None,
team_color_a: str = "#CD0E61",
team_color_b: str = "#0066CC",
ball_color: str = "black",
sort: bool = True,
color_by: Literal["ball_owning", "static_home_away"] = "ball_owning",
anonymous: bool = False,
plot_type: Literal["pitch_only", "graph_only", "full"] = "full",
show_label: bool = True,
show_ball_label: bool = False,
show_timestamp: bool = True,
next_closest_timestamp: bool = False,
):
"""
Plot tracking data as a static image or video file.
This method visualizes tracking data for players and the ball. It can generate either:
- A single PNG image (if either fps or end_timestamp is None, or both are None)
- An MP4 video (if both fps and end_timestamp are provided)
Parameters
----------
file_path : str
The output path where the PNG or MP4 file will be saved
fps : int, optional
Frames per second for video output. If None, a static image is generated
timestamp : pl.duration, optional
The starting timestamp to plot. If None, starts from the beginning of available data
end_timestamp : pl.duration, optional
The ending timestamp for video output. If None, a static image is generated
period_id : int, optional
ID of the match period to visualize. If None, all periods are included
team_color_a : str, default "#CD0E61"
Hex color code for Team A visualization
team_color_b : str, default "#0066CC"
Hex color code for Team B visualization
ball_color : str, default "black"
Color for ball visualization
color_by : Literal["ball_owning", "static_home_away"], default "ball_owning"
Method for coloring the teams:
- "ball_owning": Colors teams based on ball possession
- "static_home_away": Uses static colors for home and away teams
anonymous : bool, default False
Whether to anonymize player labels
plot_type : Literal["pitch_only", "graph_only", "full"], default "full"
Type of plot to generate:
- "pitch_only": Shows only the soccer pitch visualization
- "graph_only": Shows only the graph features (node features, adjacency matrix, edge features)
- "full": Shows both pitch and graph visualizations
show_pitch_label : bool, default True
Whether to show the label on the pitch visualization
show_pitch_timestamp : bool, default True
Whether to show the timestamp on the pitch visualization
next_closest_timestamp : bool, default False
When plotting a .png and the timestamp isn't 100% correct we find the next correct timestamp and use that to plot.
Returns
-------
None
The function saves the output file to the specified file_path but doesn't return any value
Notes
-----
Output file type is determined by parameters:
- PNG: Generated when either fps or end_timestamp is None, or both are None
- MP4: Generated when both fps and end_timestamp are provided
Raises
------
ValueError
If file extension doesn't match the parameters provided (e.g., .mp4 extension
but missing fps or end_timestamp, or .png extension with both fps and end_timestamp)
"""
try:
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
except ImportError:
raise ImportError(
"Seems like you don't have matplotlib installed. Please"
" install it using: pip install matplotlib"
)
if (fps is None and end_timestamp is not None) or (
fps is not None and end_timestamp is None
):
raise ValueError(
"Both 'fps' and 'end_timestamp' must be provided together to generate a video. "
)
# Determine the output type based on parameters
generate_video = fps is not None and end_timestamp is not None
# Get file extension if it exists
path = pathlib.Path(file_path)
file_extension = path.suffix.lower() if path.suffix else ""
# If no extension, add the appropriate one based on parameters
if not file_extension:
suffix = ".mp4" if generate_video else ".png"
file_path = str(path.with_suffix(suffix))
# Otherwise, validate that the extension matches the parameters
else:
if generate_video and file_extension != ".mp4":
raise ValueError(
f"Parameters fps and end_timestamp indicate video output, "
f"but file extension is '{file_extension}'. Use '.mp4' extension for video output."
)
elif not generate_video and file_extension == ".mp4":
raise ValueError(
"To generate an MP4 video, both 'fps' and 'end_timestamp' must be provided. "
"For static image output, use a '.png' extension."
)
elif not generate_video and file_extension != ".png":
raise ValueError(
f"For static image output, use '.png' extension instead of '{file_extension}'."
)
self._team_color_a = team_color_a
self._team_color_b = team_color_b
self._ball_color = ball_color
self._color_by = color_by
self._plot_type = plot_type
self._show_label = show_label
self._show_ball_label = show_ball_label
self._show_pitch_timestamp = show_timestamp
self._next_closest_timestamp = next_closest_timestamp
self._ball_carrier_color = "black"
if period_id is not None and not isinstance(period_id, int):
raise TypeError("period_id should be of type integer")
if all(x is None for x in [timestamp, end_timestamp, period_id]):
# No filters specified, use the entire dataset
df = self.dataset
elif timestamp is not None and period_id is not None:
if end_timestamp is not None:
# Both timestamp and end_timestamp provided - filter for a range
df = self.dataset.filter(
(pl.col(Column.TIMESTAMP).is_between(timestamp, end_timestamp))
& (pl.col(Column.PERIOD_ID) == period_id)
)
else:
# Only timestamp provided (no end_timestamp) - filter for specific timestamp
df = self.dataset.filter(
(pl.col(Column.TIMESTAMP) == timestamp)
& (pl.col(Column.PERIOD_ID) == period_id)
)
# Handle the case where a single timestamp has multiple frame_ids
df = (
df.with_columns(
pl.col(Column.FRAME_ID)
.rank(method="min")
.over(Column.TIMESTAMP)
.alias("frame_rank")
)
# Keep only rows where the frame has rank = 1 (first frame for each timestamp)
.filter(pl.col("frame_rank") == 1).drop("frame_rank")
)
else:
raise ValueError(
"Please specify both timestamp and period_id, or specify all of timestamp, end_timestamp, and period_id, or none of them."
)
if df.is_empty():
if not generate_video and self._next_closest_timestamp:
idx = self.dataset.sort(Column.FRAME_ID)[
Column.TIMESTAMP
].search_sorted(timestamp)
result = self.dataset[idx]
df = self.dataset.filter(
(pl.col(Column.TIMESTAMP) == result[Column.TIMESTAMP][0])
& (pl.col(Column.PERIOD_ID) == result[Column.PERIOD_ID][0])
)
# Handle the case where a single timestamp has multiple frame_ids
df = (
df.with_columns(
pl.col(Column.FRAME_ID)
.rank(method="min")
.over(Column.TIMESTAMP)
.alias("frame_rank")
)
# Keep only rows where the frame has rank = 1 (first frame for each timestamp)
.filter(pl.col("frame_rank") == 1).drop("frame_rank")
)
else:
if not generate_video:
raise ValueError(
"Selection is empty, please try different timestamp(s) or set next_closest_timestamp=True..."
)
raise ValueError(
"Selection is empty, please try different timestamp(s)..."
)
def setup_gridspec():
"""Setup GridSpec based on plot_type"""
if self._plot_type == "pitch_only":
return GridSpec(1, 1, left=0.05, right=0.95, bottom=0.05, top=0.95)
elif self._plot_type == "graph_only":
return GridSpec(
2,
2,
width_ratios=[1.2, 0.8],
height_ratios=[1, 1],
wspace=0.2, # Increased spacing
hspace=0.3, # Increased spacing
left=0.08,
right=0.92,
bottom=0.1, # More bottom margin
top=0.9,
) # More top margin
else: # "full"
return GridSpec(
2,
3,
width_ratios=[2, 1, 3],
height_ratios=[1, 1],
wspace=0.15, # Increased spacing
hspace=0.1, # Increased spacing
left=0.05,
right=0.98, # Slightly reduced right margin
bottom=0.08, # More bottom margin
top=0.95,
)
def plot_graph():
"""Plot graph features (node features, adjacency matrix, edge features)"""
import matplotlib.pyplot as plt
num_rows = self._graph["x"].shape[0]
labels = (
[
(
self.get_player_by_id(pid)["jersey_no"]
if pid != Constant.BALL
else Constant.BALL
)
for pid in self._graph["object_ids"]
]
if not anonymous
else [str(i) for i in range(num_rows)]
)
# Determine subplot positions based on plot_type
if self._plot_type == "graph_only":
node_pos = (0, 0)
adj_pos = (1, 0)
edge_pos = (slice(None), 1)
else: # "full"
node_pos = (0, 0)
adj_pos = (1, 0)
edge_pos = (slice(None), 1)
# Plot node features
ax1 = self._fig.add_subplot(self._gs[node_pos])
ax1.imshow(self._graph["x"], aspect="auto", cmap="YlOrRd")
ax1.set_xlabel(f"Node Features {self._graph['x'].shape}")
# Set y labels to integers
ax1.set_yticks(range(num_rows))
ax1.set_yticklabels(labels)
node_feature_yticklabels = feature_ticklabels(self._node_feature_dims)
ax1.xaxis.set_ticks_position("top")
ax1.set_xticks(range(len(node_feature_yticklabels)))
ax1.set_xticklabels(node_feature_yticklabels, rotation=45, ha="left")
# Plot adjacency matrix
ax2 = self._fig.add_subplot(self._gs[adj_pos])
ax2.imshow(self._graph["a"].toarray(), aspect="auto", cmap="YlOrRd")
ax2.set_xlabel(f"Adjacency Matrix {self._graph['a'].shape}")
# Set both x and y labels to integers
num_rows_a = self._graph["a"].toarray().shape[0]
num_cols_a = self._graph["a"].toarray().shape[1]
ax2.set_yticks(range(num_rows_a))
ax2.set_yticklabels(labels)
ax2.xaxis.set_ticks_position("top")
ax2.set_xticks(range(num_cols_a))
ax2.set_xticklabels(labels)
# Plot Edge Features
ax3 = self._fig.add_subplot(self._gs[edge_pos])
_, size_a = non_zeros(
self._graph["a"].toarray()[0 : self._ball_carrier_idx]
)
ball_carrier_edge_idx, num_rows_e = non_zeros(
np.asarray(
[list(x) for x in self._graph["a"].toarray()][
self._ball_carrier_idx
]
)
)
im3 = ax3.imshow(
self._graph["e"][size_a : num_rows_e + size_a, :],
aspect="auto",
cmap="YlOrRd",
)
ax3.set_yticks(range(num_rows_e))
ax3.set_yticklabels(list(ball_carrier_edge_idx[0]), fontsize=18)
ball_carrier_edge_idxs = list(ball_carrier_edge_idx[0])
ax3.set_xlabel(f"Edge Features {self._graph['e'].shape}")
ax3_labels = ax3.get_yticklabels()
if self._ball_carrier_idx in ball_carrier_edge_idx[0]:
idx_position = list(ball_carrier_edge_idx[0]).index(
self._ball_carrier_idx
)
# Modify just that specific label
ax3_labels[idx_position].set_color(self._ball_carrier_color)
ax3_labels[idx_position].set_fontweight("bold")
# Set the modified labels back
ax3.set_yticklabels([labels[i] for i in ball_carrier_edge_idxs])
# Set x labels to edge function names at the top, rotated 45 degrees
edge_feature_xticklabels = feature_ticklabels(self._edge_feature_dims)
ax3.xaxis.set_ticks_position("top")
ax3.set_xticks(range(len(edge_feature_xticklabels)))
ax3.set_xticklabels(edge_feature_xticklabels, rotation=45, ha="left")
plt.colorbar(im3, ax=ax3, fraction=0.1, pad=0.2)
def plot_vertical_pitch(frame_data: pl.DataFrame):
"""Plot the soccer pitch visualization"""
try:
from mplsoccer import VerticalPitch
except ImportError:
raise ImportError(
"Seems like you don't have mplsoccer installed. Please"
" install it using: pip install mplsoccer"
)
# Determine subplot position based on plot_type
if self._plot_type == "pitch_only":
pitch_pos = (0, 0)
else: # "full"
pitch_pos = (slice(None), 2)
ax4 = self._fig.add_subplot(self._gs[pitch_pos])
pitch = VerticalPitch(
pitch_type="secondspectrum",
pitch_length=self.pitch_dimensions.pitch_length,
pitch_width=self.pitch_dimensions.pitch_width,
pitch_color="#ffffff",
pad_top=-0.05,
)
pitch.draw(ax=ax4)
player_and_ball(frame_data=frame_data, ax=ax4)
direction_of_play_arrow(ax=ax4)
def feature_ticklabels(feature_dims):
    """Expand a ``{feature_name: n_columns}`` mapping into one tick label per column.

    The feature name labels its first column. A feature whose column count is
    not 1 gets ``n_columns - 1`` trailing ``None`` entries, so each name shows up
    once while the list length still matches the number of matrix columns.

    Args:
        feature_dims: Mapping of feature name to the number of columns it occupies.

    Returns:
        list: Tick labels (feature names or ``None``), in mapping order.
    """
    labels = []
    for name, n_columns in feature_dims.items():
        labels.append(name)
        if n_columns != 1:
            labels += [None] * (n_columns - 1)
    return labels
def direction_of_play_arrow(ax):
    """Draw a grey arrow on ``ax`` at x=-30 indicating the direction of play.

    The arrow starts at y=-7.5 and points +15 units along y. With
    ``Orientation.STATIC_HOME_AWAY`` it is mirrored (start and length negated)
    when the ball-owning team is not the home team. ``Orientation.BALL_OWNING_TEAM``
    keeps the default arrow.

    Raises:
        ValueError: For any other ``self.settings.orientation``.
    """
    start_x, start_y = -30, -7.5
    delta_x, delta_y = 0, 15

    orientation = self.settings.orientation
    if orientation == Orientation.STATIC_HOME_AWAY:
        if self._ball_owning_team_id != self.settings.home_team_id:
            start_y, delta_y = -start_y, -delta_y
    elif orientation != Orientation.BALL_OWNING_TEAM:
        raise ValueError(f"Unsupported orientation {self.settings.orientation}")

    # Create the arrow to indicate direction of play (drawn behind other artists).
    ax.arrow(
        start_x,
        start_y,
        delta_x,
        delta_y,
        head_width=3,
        head_length=2,
        fc="#c2c2c2",
        ec="#c2c2c2",
        width=0.5,
        length_includes_head=True,
        zorder=-1,
    )
def player_and_ball(frame_data, ax):
    """Draw every player and the ball of one frame onto ``ax``.

    For each row it draws the following:
    - A dot, in team colour for players and ``self._ball_color`` for the ball.
    - For players moving faster than 1.0 (``Column.SPEED``), a velocity arrow.
    - A white-outlined text label: the jersey number, or the row index when
      ``anonymous`` is set.

    A player row is drawn in ``self._team_color_a`` when its team id matches the
    reference team, otherwise in ``self._team_color_b``. The reference team is
    picked by ``self._color_by``: ``"ball_owning"`` uses
    ``self._ball_owning_team_id`` and ``"static_home_away"`` uses the home team
    id. If that id is None, the first player row's team is used. The colour of
    the ball-carrier row is stored in ``self._ball_carrier_color``.

    Args:
        frame_data: Rows of a single frame.
        ax: Matplotlib axes to draw on.

    Raises:
        ValueError: If ``self._color_by`` is not a supported value.
    """
    import matplotlib.patheffects as path_effects

    if self._color_by == "ball_owning":
        team_id = self._ball_owning_team_id
    elif self._color_by == "static_home_away":
        team_id = self.settings.home_team_id
    else:
        raise ValueError(f"Unsupported color_by {self._color_by}")

    for i, r in enumerate(frame_data.iter_rows(named=True)):
        # Pitch X/Y and the velocity components are deliberately swapped here:
        # points go straight to ``ax.scatter`` on a VerticalPitch axis, so the
        # pitch's X axis is drawn vertically.
        v, vy, vx, y, x = (
            r[Column.SPEED],
            r[Column.VX],
            r[Column.VY],
            r[Column.X],
            r[Column.Y],
        )
        is_ball = r[Column.TEAM_ID] == self.settings.ball_id

        if not is_ball:
            if team_id is None:
                team_id = r[Column.TEAM_ID]
            color = (
                self._team_color_a
                if r[Column.TEAM_ID] == team_id
                else self._team_color_b
            )
            if r[Column.IS_BALL_CARRIER] == True:
                # Remembered so the graph panel can highlight the carrier's label.
                self._ball_carrier_color = color
            ax.scatter(x, y, color=color, s=450)
            if v > 1.0:
                ax.annotate(
                    "",
                    xy=(x + vx, y + vy),
                    xytext=(x, y),
                    arrowprops=dict(arrowstyle="->", color=color, lw=3),
                )
        else:
            ax.scatter(x, y, color=self._ball_color, s=250, zorder=10)

        # Label text: the ball label is optional; players show their jersey
        # number, or their row index when anonymised.
        if r[Column.OBJECT_ID] == Constant.BALL:
            label = Constant.BALL if self._show_ball_label else ""
        elif anonymous:
            label = str(i)
        else:
            label = self.get_player_by_id(r[Column.OBJECT_ID])["jersey_no"]

        # Text with white border
        text = ax.text(
            x + (-1.2 if is_ball else 0.0),
            y + (-1.2 if is_ball else 0.0),
            label,
            color=self._ball_color if is_ball else color,
            fontsize=12,
            ha="center",
            va="center",
            zorder=15 if is_ball else 5,
        )
        text.set_path_effects(
            [
                path_effects.Stroke(linewidth=6, foreground="white"),
                path_effects.Normal(),
            ]
        )

    # Add label and timestamp to pitch if enabled. The label is read from the
    # converter's configured label column rather than a hard-coded "label".
    if self._show_label:
        ax.set_xlabel(f"Label: {frame_data[self.label_column][0]}", fontsize=22)
    if self._show_pitch_timestamp:
        ax.set_title(self._gameclock, fontsize=22)
def frame_plot(self, frame_data):
    """Render a single frame into the current figure according to ``self._plot_type``.

    For ``"graph_only"`` and ``"full"`` plots, the frame goes through
    ``self._compute``. The result is reshaped into adjacency ``a`` (sparse),
    node features ``x``, edge features ``e`` and label ``y``, and stored in
    ``self._graph``.

    The following frame metadata is then stored on ``self`` for the nested
    plotting helpers:
    - the ball carrier row index;
    - the ball-owning team id;
    - a formatted game-clock string.

    Finally the pitch and/or graph panels are drawn.

    Args:
        self: The converter instance. It is passed explicitly because this is a
            nested function, not a bound method.
        frame_data: Rows belonging to exactly one frame.
    """

    def timestamp_to_gameclock(timestamp, period_id):
        """Format a timedelta as ``"[period] - M:SS:mmm"``.

        Integer timedelta division is used instead of float arithmetic on
        ``total_seconds()``. The float form truncates wrongly: for 1.001 s,
        ``(1.001 % 1) * 1000`` is ~0.9999, which ``int()`` turns into 0 ms.
        """
        from datetime import timedelta

        total_ms = timestamp // timedelta(milliseconds=1)
        total_seconds, milliseconds = divmod(total_ms, 1000)
        minutes, seconds = divmod(total_seconds, 60)
        return f"[{period_id}] - {minutes}:{seconds:02d}:{milliseconds:03d}"

    # Setup GridSpec based on plot_type
    self._gs = setup_gridspec()

    # Only process graph data if we need to show graphs
    if self._plot_type in ["graph_only", "full"]:
        features = self._compute(
            [frame_data[col] for col in self._exprs_variables]
        )
        a = make_sparse(
            reshape_from_size(
                features["a"], features["a_shape_0"], features["a_shape_1"]
            )
        )
        x = reshape_from_size(
            features["x"], features["x_shape_0"], features["x_shape_1"]
        )
        e = reshape_from_size(
            features["e"], features["e_shape_0"], features["e_shape_1"]
        )
        y = np.asarray([features[self.label_column]])
        self._graph = {
            "a": a,
            "x": x,
            "e": e,
            "y": y,
            "frame_id": features["frame_id"],
            "object_ids": frame_data[Column.OBJECT_ID],
            "ball_owning_team_id": frame_data[Column.BALL_OWNING_TEAM_ID][0],
        }

    # Index of the first row flagged as ball carrier.
    # NOTE(review): raises IndexError if no row in the frame is flagged.
    self._ball_carrier_idx = np.where(
        frame_data[Column.IS_BALL_CARRIER] == True
    )[0][0]
    self._ball_owning_team_id = frame_data[Column.BALL_OWNING_TEAM_ID][0]
    self._gameclock = timestamp_to_gameclock(
        timestamp=frame_data["timestamp"][0],
        period_id=frame_data["period_id"][0],
    )

    # Plot based on plot_type
    if self._plot_type == "pitch_only":
        plot_vertical_pitch(frame_data)
    elif self._plot_type == "graph_only":
        plot_graph()
    else:  # "full"
        plot_vertical_pitch(frame_data)
        plot_graph()
    plt.tight_layout()
# Adjust figure size based on plot_type
if self._plot_type == "pitch_only":
self._fig = plt.figure(figsize=(8, 12))
elif self._plot_type == "graph_only":
self._fig = plt.figure(figsize=(14, 10))
else: # "full"
self._fig = plt.figure(figsize=(25, 18))
self._fig.subplots_adjust(left=0.06, right=1.0, bottom=0.05)
if sort:
df = self._sort(df)
if generate_video:
writer = animation.FFMpegWriter(fps=fps, bitrate=1500)
with writer.saving(self._fig, file_path, dpi=300):
for group_id, frame_data in df.group_by(
Group.BY_FRAME, maintain_order=True
):
self._fig.clear()
frame_plot(self, frame_data)
writer.grab_frame()
else:
frame_plot(self, frame_data=df)
plt.savefig(file_path, dpi=300)