Source code for unravel.soccer.graphs.graph_converter

import logging
import sys

from dataclasses import dataclass, field

from typing import List, Union, Dict, Literal, Any, Optional, Callable, TYPE_CHECKING

import inspect

import pathlib

from kloppy.domain import MetricPitchDimensions, Orientation

if TYPE_CHECKING:
    from spektral.data import Graph

from .graph_settings import GraphSettingsPolars
from ..dataset.kloppy_polars import KloppyPolarsDataset, Column, Group, Constant
from .features import (
    compute_node_features,
    add_global_features,
    compute_adjacency_matrix,
    compute_edge_features,
)

from ...utils import *

# Module logger: DEBUG level, echoed to stdout.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
stdout_handler = logging.StreamHandler(sys.stdout)
# getLogger() returns the same cached logger object on every import/reload of
# this module, so only attach the stdout handler once; otherwise each reload
# stacks another handler and every message is printed multiple times.
if not logger.handlers:
    logger.addHandler(stdout_handler)


@dataclass(repr=True)
class SoccerGraphConverter(DefaultGraphConverter):
    """Convert soccer tracking data from Polars DataFrame to graph structures for GNN training.

    This class transforms soccer tracking data into graph representations suitable
    for Graph Neural Networks. Each frame of tracking data becomes a graph with
    players and the ball as nodes, with edges representing spatial relationships
    or team affiliations.

    The converter supports two GNN frameworks:
        - PyTorch Geometric (recommended) via :meth:`to_pytorch_graphs`
        - Spektral (deprecated, Python 3.11 only) via :meth:`to_spektral_graphs`

    Graph Structure:
        - **Nodes**: Players (home team, away team) and ball
        - **Node Features**: Position, velocity, acceleration, distances, angles (12 default features)
        - **Edges**: Defined by adjacency_matrix_type (team-based, spatial, or dense)
        - **Edge Features**: Distances, angles, relative velocities (6-7 default features)
        - **Global Features**: Optional match-level features attached to ball node

    Key Features:
        - Configurable node and edge feature engineering
        - Multiple adjacency matrix types (split_by_team, delaunay, dense)
        - Custom feature functions via decorators
        - Automatic padding for fixed-size graphs
        - Ball connection strategies (all players, carrier only, none)
        - Permutation invariance via random node ordering

    Args:
        dataset (KloppyPolarsDataset): Polars dataset with tracking data. Must have
            been processed with :meth:`~unravel.soccer.KloppyPolarsDataset.add_graph_ids`
            and optionally :meth:`~unravel.soccer.KloppyPolarsDataset.add_dummy_labels`.
        chunk_size (int, optional): Number of graphs to process simultaneously.
            Higher values use more memory but may be faster. Defaults to 20000.
        non_potential_receiver_node_value (float, optional): Node feature value (0-1)
            assigned to defending team players. Used to distinguish attackers from
            defenders. Defaults to 0.1.
        edge_feature_funcs (List[Callable], optional): Custom edge feature functions
            decorated with ``@graph_feature(type="edge")``. If None, uses defaults.
            Defaults to None.
        node_feature_funcs (List[Callable], optional): Custom node feature functions
            decorated with ``@graph_feature(type="node")``. If None, uses defaults.
            Defaults to None.
        global_feature_cols (List[str], optional): Column names from the dataset to
            use as graph-level features (e.g., match score, team ratings). Must be
            constant within each graph_id group. ``None`` is treated as an empty
            list. Defaults to empty list.
        global_feature_type (Literal["ball", "all"], optional): Where to attach
            global features. "ball" attaches to ball node only, "all" attaches to
            all nodes. Defaults to "ball".
        additional_feature_cols (List[str], optional): Extra columns from dataset to
            make available to custom feature functions (e.g., player height,
            position). ``None`` is treated as an empty list. Defaults to empty list.

    Attributes:
        settings (GraphSettingsPolars): Configuration for graph conversion including
            adjacency matrix type, padding, and feature settings.
        n_node_features (int): Total number of node features per node.
        n_edge_features (int): Total number of edge features per edge.
        n_graph_features (int): Total number of global/graph-level features.

    Raises:
        ValueError: If dataset is not a KloppyPolarsDataset.
        ValueError: If required columns (graph_id, label) are missing.
        ValueError: If custom feature functions are not properly decorated.
        TypeError: If label_col / graph_id_col are not strings, chunk_size is not
            an int, or non_potential_receiver_node_value is not a float.

    Example:
        >>> from unravel.soccer import KloppyPolarsDataset, SoccerGraphConverter
        >>> from kloppy import sportec
        >>>
        >>> # Load and prepare data
        >>> kloppy_dataset = sportec.load_open_tracking_data(only_alive=True)
        >>> polars_dataset = KloppyPolarsDataset(kloppy_dataset=kloppy_dataset)
        >>> polars_dataset.add_dummy_labels(by=["frame_id"])
        >>> polars_dataset.add_graph_ids(by=["frame_id"])
        >>>
        >>> # Create converter
        >>> converter = SoccerGraphConverter(
        ...     dataset=polars_dataset,
        ...     self_loop_ball=True,
        ...     adjacency_matrix_connect_type="ball",
        ...     adjacency_matrix_type="split_by_team",
        ...     label_type="binary",
        ... )
        >>>
        >>> # Convert to PyTorch Geometric format
        >>> graphs = converter.to_pytorch_graphs()
        >>> print(f"Created {len(graphs)} graphs")
        >>> print(f"Node features: {converter.n_node_features}")
        >>> print(f"Edge features: {converter.n_edge_features}")

    Note:
        For detailed configuration options, see :class:`~unravel.soccer.GraphSettingsPolars`.
        For custom features, see :func:`~unravel.utils.features.graph_feature` decorator.

    Warning:
        If not using padding (``pad=False``), frames that do not contain rows for
        both teams and the ball are dropped. With ``pad=True`` every frame is
        padded to 1 ball row plus 11 rows per team, and frames that still cannot
        be completed (e.g. a team with no rows, or more than 11 rows) are dropped.

    See Also:
        :class:`~unravel.soccer.KloppyPolarsDataset`: Prepare tracking data.
        :class:`~unravel.utils.GraphDataset`: Wrap graphs for training.
        :func:`~unravel.utils.features.graph_feature`: Create custom features.
        :doc:`../tutorials/soccer_gnn`: Complete GNN training tutorial.
        `Graph FAQ <https://github.com/unravelsports/unravelsports/blob/main/examples/graphs_faq.md>`_:
        Detailed configuration guide.
    """

    # NOTE(review): attributes such as label_col, graph_id_col, pad, verbose,
    # prediction, sample_rate, feature_opts, settings and return_dtypes are read
    # below but not declared here; presumably they come from DefaultGraphConverter.
    dataset: KloppyPolarsDataset = None
    chunk_size: int = 20_000
    non_potential_receiver_node_value: float = 0.1
    edge_feature_funcs: List[Callable[[Dict[str, Any]], np.ndarray]] = field(
        repr=False, default_factory=list
    )
    node_feature_funcs: List[Callable[[Dict[str, Any]], np.ndarray]] = field(
        repr=False, default_factory=list
    )
    global_feature_cols: Optional[List[str]] = field(repr=False, default_factory=list)
    global_feature_type: Literal["ball", "all"] = "ball"
    additional_feature_cols: Optional[List[str]] = field(
        repr=False, default_factory=list
    )

    # Feature-name -> number-of-columns maps, filled in by _compute and used by
    # plot() to label feature heatmaps.
    _edge_feature_dims: Dict[str, int] = field(
        repr=False, default_factory=dict, init=False
    )
    _node_feature_dims: Dict[str, int] = field(
        repr=False, default_factory=dict, init=False
    )

    def __post_init__(self):
        """Validate inputs, resolve settings and prepare ``self.dataset`` for conversion.

        After this runs ``self.dataset`` is the underlying Polars DataFrame (not the
        ``KloppyPolarsDataset`` wrapper), padded or filtered to complete frames,
        optionally sampled and shuffled.

        Raises:
            ValueError: If ``dataset`` is not a ``KloppyPolarsDataset``.
        """
        if not isinstance(self.dataset, KloppyPolarsDataset):
            raise ValueError("dataset should be of type KloppyPolarsDataset...")

        # The annotations are Optional[...]; normalise None so every later
        # ``list + self.global_feature_cols`` concatenation works.
        if self.global_feature_cols is None:
            self.global_feature_cols = []
        if self.additional_feature_cols is None:
            self.additional_feature_cols = []

        self.pitch_dimensions: MetricPitchDimensions = (
            self.dataset.settings.pitch_dimensions
        )
        self._kloppy_settings = self.dataset.settings

        self.label_column: str = (
            self.label_col if self.label_col is not None else self.dataset._label_column
        )
        self.graph_id_column: str = (
            self.graph_id_col
            if self.graph_id_col is not None
            else self.dataset._graph_id_column
        )

        # From here on self.dataset is the raw Polars DataFrame.
        self.dataset = self.dataset.data

        if not self.edge_feature_funcs:
            self.edge_feature_funcs = self.default_edge_feature_funcs
        self._verify_feature_funcs(self.edge_feature_funcs, feature_type="edge")

        if not self.node_feature_funcs:
            self.node_feature_funcs = self.default_node_feature_funcs
        self._verify_feature_funcs(self.node_feature_funcs, feature_type="node")

        self._sport_specific_checks()
        self.settings = self._apply_graph_settings()

        if self.pad:
            self.dataset = self._apply_padding()
        else:
            self.dataset = self._remove_incomplete_frames()

        self._sample()
        self._shuffle()

    def _sample(self):
        """Keep only rows whose frame id is a multiple of ``1 / sample_rate``.

        Does nothing when ``sample_rate`` is None.
        """
        if self.sample_rate is None:
            return
        self.dataset = self.dataset.filter(
            pl.col(Column.FRAME_ID) % (1.0 / self.sample_rate) == 0
        )

    @staticmethod
    def _sort(df):
        """Sort rows per frame: ball-owning team first, then the other team, then the ball.

        Rows within each of those groups are ordered by object id.
        """
        # ball-owning team -> -1, other team -> 0, ball -> 2
        sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - (
            (pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID))
            & (pl.col(Column.TEAM_ID) != Constant.BALL)
        ).cast(int)
        return df.sort([*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)])

    def _remove_incomplete_frames(self) -> pl.DataFrame:
        """Keep only frames containing exactly three distinct team ids (both teams and the ball).

        Used when ``pad=False``. Emits a warning (when ``verbose``) if any frame is dropped.
        """
        df = self.dataset
        total_frames = len(df.unique(Group.BY_FRAME))

        valid_frames = (
            df.group_by(Group.BY_FRAME, maintain_order=True)
            .agg(pl.col(Column.TEAM_ID).n_unique().alias("unique_teams"))
            .filter(pl.col("unique_teams") == 3)
            .select(Group.BY_FRAME)
        )

        dropped_frames = total_frames - len(valid_frames.unique(Group.BY_FRAME))
        if dropped_frames > 0 and self.verbose:
            self.__warn_dropped_frames(dropped_frames, total_frames, pad=False)

        return df.join(valid_frames, on=Group.BY_FRAME)

    def _apply_padding(self) -> pl.DataFrame:
        """Pad every frame to 1 ball row plus 11 rows per team, then drop incomplete frames.

        Missing player rows are added per (game, period, frame, team, ball-owning
        team) group, and frames with no ball row at all get one ball row. Padded
        rows copy the frame-level ``keep_columns`` and global feature values, get
        ``create_default_expression`` values for the positional/kinematic columns,
        and ``null`` for any other user-defined column.

        The completeness filter (see :meth:`_filter_complete_padded_frames`) is
        applied in every case, including when nothing needed padding, so that the
        same frame is never kept or dropped depending on unrelated frames.
        """
        df = self.dataset

        keep_columns = [
            Column.TIMESTAMP,
            Column.BALL_STATE,
            self.label_column,
            self.graph_id_column,
        ]
        empty_columns = [
            Column.POSITION_NAME,
            Column.OBJECT_ID,
            Column.IS_BALL_CARRIER,
            Column.X,
            Column.Y,
            Column.Z,
            Column.VX,
            Column.VY,
            Column.VZ,
            Column.SPEED,
            Column.AX,
            Column.AY,
            Column.AZ,
            Column.ACCELERATION,
        ]
        group_by_columns = [
            Column.GAME_ID,
            Column.PERIOD_ID,
            Column.FRAME_ID,
            Column.TEAM_ID,
            Column.BALL_OWNING_TEAM_ID,
        ]
        # Built once instead of re-concatenating the lists for every column.
        reserved_columns = set(
            keep_columns + group_by_columns + empty_columns + self.global_feature_cols
        )
        user_defined_columns = [x for x in df.columns if x not in reserved_columns]

        counts = df.group_by(group_by_columns, maintain_order=True).agg(
            pl.len().alias("count"),
            *[
                pl.first(col).alias(col)
                for col in keep_columns + self.global_feature_cols
            ],
        )
        # One ball row per frame, eleven rows for either team.
        counts = counts.with_columns(
            pl.when(pl.col(Column.TEAM_ID) == Constant.BALL)
            .then(1)
            .otherwise(11)
            .alias("target_length")
        )

        groups_to_pad = counts.filter(
            pl.col("count") < pl.col("target_length")
        ).with_columns((pl.col("target_length") - pl.col("count")).alias("repeats"))

        padding_rows = []
        # Pad player groups. An existing ball group always has count >= 1 ==
        # target_length, so frames lacking a ball row are handled separately below.
        for row in groups_to_pad.iter_rows(named=True):
            base_row = {
                col: row[col]
                for col in keep_columns + group_by_columns + self.global_feature_cols
            }
            padding_rows.extend([base_row] * row["repeats"])

        all_frames = df.select(
            [
                Column.GAME_ID,
                Column.PERIOD_ID,
                Column.FRAME_ID,
                Column.BALL_OWNING_TEAM_ID,
            ]
            + keep_columns
            + self.global_feature_cols
        ).unique()

        frames_with_ball = (
            df.filter(pl.col(Column.TEAM_ID) == Constant.BALL)
            .select([Column.GAME_ID, Column.PERIOD_ID, Column.FRAME_ID])
            .unique()
        )

        frames_missing_ball = all_frames.join(
            frames_with_ball,
            on=[Column.GAME_ID, Column.PERIOD_ID, Column.FRAME_ID],
            how="anti",
        )

        if frames_missing_ball.height > 0:
            ball_rows_to_add = frames_missing_ball.with_columns(
                [
                    pl.lit(Constant.BALL).alias(Column.TEAM_ID),
                    pl.lit(Constant.BALL).alias(Column.POSITION_NAME),
                ]
            )
            for row in ball_rows_to_add.iter_rows(named=True):
                base_row = {
                    col: row[col]
                    for col in keep_columns
                    + group_by_columns
                    + [Column.POSITION_NAME]
                    + self.global_feature_cols
                    if col in row
                }
                padding_rows.append(base_row)

        if not padding_rows:
            # Nothing to add, but frames that are over-full (or lack a team
            # entirely) must still be removed consistently.
            return self._filter_complete_padded_frames(df)

        padding_df = pl.DataFrame(padding_rows)
        schema = df.schema
        padding_df = padding_df.with_columns(
            [create_default_expression(col, schema[col]) for col in empty_columns]
            + [
                pl.lit(None).cast(schema[col]).alias(col)
                for col in user_defined_columns
            ]
        )
        padding_df = padding_df.with_columns(
            [pl.col(col).cast(schema[col]).alias(col) for col in group_by_columns]
        )
        padding_df = padding_df.join(
            df.unique(group_by_columns).select(
                group_by_columns + self.global_feature_cols
            ),
            on=group_by_columns,
            how="left",
        )
        # Align dtypes and column order with the original frame before stacking;
        # the select also discards any join-suffixed duplicate columns.
        padding_df = padding_df.with_columns(
            [
                pl.col(col_name).cast(schema[col_name]).alias(col_name)
                for col_name in df.columns
            ]
        ).select(df.columns)

        result = pl.concat([df, padding_df], how="vertical")
        return self._filter_complete_padded_frames(result)

    def _filter_complete_padded_frames(self, df: pl.DataFrame) -> pl.DataFrame:
        """Keep frames with exactly 1 ball row, 11 ball-owning-team rows and 11 other rows.

        Only the frame keys are joined back, so no helper columns are added to ``df``.
        Emits a warning (when ``verbose``) if any frame is dropped.
        """
        total_frames = df.select(Group.BY_FRAME).unique().height

        complete_frames = (
            df.group_by(Group.BY_FRAME, maintain_order=True)
            .agg(
                [
                    (pl.col(Column.TEAM_ID).eq(Constant.BALL).sum() == 1).alias(
                        "has_ball"
                    ),
                    (
                        pl.col(Column.TEAM_ID)
                        .eq(pl.col(Column.BALL_OWNING_TEAM_ID))
                        .sum()
                        == 11
                    ).alias("has_owning_team"),
                    (
                        (
                            ~pl.col(Column.TEAM_ID).eq(Constant.BALL)
                            & ~pl.col(Column.TEAM_ID).eq(
                                pl.col(Column.BALL_OWNING_TEAM_ID)
                            )
                        ).sum()
                        == 11
                    ).alias("has_other_team"),
                ]
            )
            .filter(
                pl.col("has_ball")
                & pl.col("has_owning_team")
                & pl.col("has_other_team")
            )
            .select(Group.BY_FRAME)
        )

        dropped_frames = total_frames - complete_frames.height
        if dropped_frames > 0 and self.verbose:
            self.__warn_dropped_frames(dropped_frames, total_frames, pad=True)

        return df.join(complete_frames, on=Group.BY_FRAME, how="inner")

    @staticmethod
    def __warn_dropped_frames(dropped_frames, total_frames, pad=True):
        """Warn how many frames were dropped, with a reason matching the ``pad`` setting.

        Args:
            dropped_frames: Number of frames removed (callers only call this when > 0).
            total_frames: Number of frames before filtering.
            pad: Whether the drop happened on the ``pad=True`` path.
        """
        import warnings

        if pad:
            reason = (
                "Setting pad=True drops frames that cannot be completed to exactly "
                "1 ball, 11 ball owning team and 11 other team objects "
                "(e.g. a team has no objects or more than 11 objects)."
            )
        else:
            reason = (
                "Setting pad=False drops frames that do not have at least 1 object "
                "for the attacking team, defending team or ball."
            )
        warnings.warn(
            f"""{reason}
            This operation dropped {dropped_frames} incomplete frames out of {total_frames} total frames ({(dropped_frames/total_frames)*100:.2f}%)
            """
        )

    def _apply_graph_settings(self):
        """Build the :class:`GraphSettingsPolars` used for the conversion.

        Combines the kloppy dataset settings, this converter's options and the
        names of the selected feature functions. The max speed/acceleration values
        are read from the pre-existing ``self.settings`` (presumably supplied by
        the parent class) before it is replaced by the returned object.
        """
        return GraphSettingsPolars(
            home_team_id=str(self._kloppy_settings.home_team_id),
            away_team_id=str(self._kloppy_settings.away_team_id),
            players=self._kloppy_settings.players,
            features={
                "edge": [func.__name__ for func in self.edge_feature_funcs],
                "node": [func.__name__ for func in self.node_feature_funcs],
                "global": self.global_feature_cols,
            },
            orientation=self._kloppy_settings.orientation,
            pitch_dimensions=self.pitch_dimensions,
            max_player_speed=self.settings.max_player_speed,
            max_ball_speed=self.settings.max_ball_speed,
            max_player_acceleration=self.settings.max_player_acceleration,
            max_ball_acceleration=self.settings.max_ball_acceleration,
            self_loop_ball=self.self_loop_ball,
            adjacency_matrix_connect_type=self.adjacency_matrix_connect_type,
            adjacency_matrix_type=self.adjacency_matrix_type,
            label_type=self.label_type,
            defending_team_node_value=self.defending_team_node_value,
            non_potential_receiver_node_value=self.non_potential_receiver_node_value,
            random_seed=self.random_seed,
            pad=self.pad,
            verbose=self.verbose,
        )

    def _sport_specific_checks(self):
        """Validate column names and option types.

        When ``prediction`` is True and the label column is missing, a null label
        column is added to ``self.dataset``.

        Raises:
            TypeError: If label/graph id column names are not strings, ``chunk_size``
                is not an int, or ``non_potential_receiver_node_value`` is set but
                not a float.
            ValueError: If the label column (outside prediction mode) or the graph
                id column is missing from the dataset.
        """
        if not isinstance(self.label_column, str):
            raise TypeError("'label_col' should be of type string (str)")
        if not isinstance(self.graph_id_column, str):
            raise TypeError("'graph_id_col' should be of type string (str)")
        if not isinstance(self.chunk_size, int):
            raise TypeError("chunk_size should be of type integer (int)")

        label_present = self.label_column in self.dataset.columns
        if not label_present and not self.prediction:
            raise ValueError(
                "Please specify a 'label_col' and add that column to your 'dataset' or set 'prediction=True' if you want to use the converted dataset to make predictions on."
            )
        if not label_present and self.prediction:
            self.dataset = self.dataset.with_columns(
                pl.lit(None).alias(self.label_column)
            )

        if self.graph_id_column not in self.dataset.columns:
            raise ValueError(
                "Please specify a 'graph_id_col' and add that column to your 'dataset' ..."
            )
        if self.non_potential_receiver_node_value and not isinstance(
            self.non_potential_receiver_node_value, float
        ):
            raise TypeError(
                "'non_potential_receiver_node_value' should be of type float"
            )

    @property
    def _exprs_variables(self):
        """Ordered column names passed to :meth:`_compute` for every frame group.

        The order matters: ``_compute`` zips these names with its positional args.
        """
        exprs_variables = [
            Column.X,
            Column.Y,
            Column.Z,
            Column.SPEED,
            Column.VX,
            Column.VY,
            Column.VZ,
            Column.ACCELERATION,
            Column.AX,
            Column.AY,
            Column.AZ,
            Column.TEAM_ID,
            Column.POSITION_NAME,
            Column.BALL_OWNING_TEAM_ID,
            Column.IS_BALL_CARRIER,
            Column.OBJECT_ID,
            self.graph_id_column,
            self.label_column,
        ]
        return exprs_variables + self.global_feature_cols + self.additional_feature_cols

    @property
    def default_node_feature_funcs(self) -> list:
        """Node feature functions used when ``node_feature_funcs`` is empty."""
        return [
            x_normed,
            y_normed,
            speeds_normed,
            velocity_components_2d_normed,
            distance_to_goal_normed,
            distance_to_ball_normed,
            is_possession_team,
            is_gk,
            is_ball,
            angle_to_goal_components_2d_normed,
            angle_to_ball_components_2d_normed,
            is_ball_carrier,
        ]

    @property
    def default_edge_feature_funcs(self) -> list:
        """Edge feature functions used when ``edge_feature_funcs`` is empty."""
        return [
            distances_between_players_normed,
            speed_difference_normed,
            angle_between_players_normed,
            velocity_difference_normed,
        ]

    def __add_additional_kwargs(self, d):
        """Add derived per-frame entries to the frame dict ``d`` (mutated and returned).

        Added keys: ``ball_id``, ``possession_team_id`` (first ball-owning team id),
        ``is_gk`` (position name equals the settings' goalkeeper id), ``position`` and
        ``velocity`` (x/y/z stacked on the last axis, NaN -> 1e-10, +/-inf -> +/-1e3),
        ``ball_position``, ``ball_idx`` and ``ball_carrier_idx``.

        ``ball_idx`` is the array of indices of ball rows when a ball row exists,
        otherwise the int ``0`` (with ``ball_position`` set to the origin).
        ``ball_carrier_idx`` is the first row flagged as ball carrier, or None.
        """
        d["ball_id"] = Constant.BALL
        d["possession_team_id"] = d[Column.BALL_OWNING_TEAM_ID][0]
        d["is_gk"] = np.where(
            d[Column.POSITION_NAME] == self.settings.goalkeeper_id, True, False
        )
        d["position"] = np.nan_to_num(
            np.stack((d[Column.X], d[Column.Y], d[Column.Z]), axis=-1),
            nan=1e-10,
            posinf=1e3,
            neginf=-1e3,
        )
        d["velocity"] = np.nan_to_num(
            np.stack((d[Column.VX], d[Column.VY], d[Column.VZ]), axis=-1),
            nan=1e-10,
            posinf=1e3,
            neginf=-1e3,
        )

        ball_rows = np.where(d[Column.TEAM_ID] == d["ball_id"])[0]
        if len(ball_rows) >= 1:
            ball_index = ball_rows
            ball_position = d["position"][ball_index][0]
        else:
            ball_position = np.asarray([0.0, 0.0, 0.0])
            ball_index = 0

        # Elementwise comparison on purpose: the column may hold None values.
        ball_carriers = np.where(d[Column.IS_BALL_CARRIER] == True)[0]  # noqa: E712
        ball_carrier_idx = ball_carriers[0] if len(ball_carriers) else None

        d["ball_position"] = ball_position
        d["ball_idx"] = ball_index
        d["ball_carrier_idx"] = ball_carrier_idx
        return d

    def _compute(self, args: List[pl.Series]) -> dict:
        """Build the adjacency matrix, node and edge features for one frame group.

        Args:
            args: One series per name in :attr:`_exprs_variables`, in that order,
                optionally followed by the ``FRAME_ID`` series (as appended by
                :meth:`_convert`).

        Returns:
            dict: Flattened ``a``/``x``/``e`` lists plus their shapes, the graph id,
            label, frame id (None when no ``FRAME_ID`` series was passed), object ids
            and ball-owning team id.

        Raises:
            ValueError: If graph ids, labels (outside prediction mode) or global
                feature values are not constant within the group.
        """
        expr_names = self._exprs_variables
        frame_data: dict = {
            col: args[i].to_numpy() for i, col in enumerate(expr_names)
        }
        frame_data = self.__add_additional_kwargs(frame_data)

        # FRAME_ID is only present when appended after the expression columns;
        # callers that pass just the expressions must not get another column here.
        frame_id = args[len(expr_names)][0] if len(args) > len(expr_names) else None
        ball_owning_team_id = frame_data[Column.BALL_OWNING_TEAM_ID][0]

        if not np.all(
            frame_data[self.graph_id_column] == frame_data[self.graph_id_column][0]
        ):
            raise ValueError(
                "graph_id selection contains multiple different values. Make sure each graph_id is unique by at least game_id and frame_id..."
            )

        if not self.prediction and not np.all(
            frame_data[self.label_column] == frame_data[self.label_column][0]
        ):
            raise ValueError(
                """Label selection contains multiple different values for a single selection (group by) of game_id and frame_id, make sure this is not the case. Each group can only have 1 label."""
            )

        adjacency_matrix = compute_adjacency_matrix(settings=self.settings, **frame_data)
        edge_features, self._edge_feature_dims = compute_edge_features(
            adjacency_matrix=adjacency_matrix,
            funcs=self.edge_feature_funcs,
            opts=self.feature_opts,
            settings=self.settings,
            **frame_data,
        )
        node_features, self._node_feature_dims = compute_node_features(
            funcs=self.node_feature_funcs,
            opts=self.feature_opts,
            settings=self.settings,
            **frame_data,
        )

        if self.global_feature_cols:
            failed = [
                col
                for col in self.global_feature_cols
                if not np.all(frame_data[col] == frame_data[col][0])
            ]
            if failed:
                raise ValueError(
                    f"""graph_feature_cols contains multiple different values for a group in the groupby ({Group.BY_FRAME}) selection for the columns {failed}. Make sure each group has the same values per individual column."""
                )

        # One value per global column (the first row; validated constant above).
        global_features = (
            np.asarray([frame_data[col] for col in self.global_feature_cols]).T[0]
            if self.global_feature_cols
            else None
        )
        for col in self.global_feature_cols:
            self._node_feature_dims[col] = 1

        node_features = add_global_features(
            node_features=node_features,
            global_features=global_features,
            global_feature_type=self.global_feature_type,
            **frame_data,
        )

        # Plain Python lists (no pl.Series / ndarray wrappers) so the result fits
        # the struct dtype declared by self.return_dtypes.
        return {
            "e": edge_features.tolist(),
            "x": node_features.tolist(),
            "a": adjacency_matrix.tolist(),
            "e_shape_0": edge_features.shape[0],
            "e_shape_1": edge_features.shape[1],
            "x_shape_0": node_features.shape[0],
            "x_shape_1": node_features.shape[1],
            "a_shape_0": adjacency_matrix.shape[0],
            "a_shape_1": adjacency_matrix.shape[1],
            self.graph_id_column: frame_data[self.graph_id_column][0],
            self.label_column: frame_data[self.label_column][0],
            "frame_id": frame_id,
            "object_ids": frame_data[Column.OBJECT_ID].tolist(),
            "ball_owning_team_id": ball_owning_team_id,
        }

    def _convert(self):
        """Compute one graph per frame group and unpack the results into columns.

        Groups ``self.dataset`` by ``Group.BY_FRAME`` (keeping order), runs
        :meth:`_compute` on each group via ``pl.map_groups`` and expands the
        returned struct into the ``a``/``e``/``x`` data, shape, id and label columns.
        """
        struct_fields = [
            "a",
            "e",
            "x",
            self.graph_id_column,
            self.label_column,
            "frame_id",
            "object_ids",
            "ball_owning_team_id",
        ]
        shape_fields = [f"{m}_shape_{i}" for m in ["a", "e", "x"] for i in [0, 1]]

        return (
            self.dataset.group_by(Group.BY_FRAME, maintain_order=True)
            .agg(
                pl.map_groups(
                    # FRAME_ID goes last: _compute reads it after the expressions.
                    exprs=self._exprs_variables + [Column.FRAME_ID],
                    function=self._compute,
                    return_dtype=self.return_dtypes,
                    returns_scalar=True,
                ).alias("result_dict")
            )
            .with_columns(
                [
                    pl.col("result_dict").struct.field(name).alias(name)
                    for name in struct_fields + shape_fields
                ]
            )
            .drop("result_dict")
        )

    def get_players_by_team_id(self, team_id):
        """Return all player dicts in ``settings.players`` whose ``team_id`` matches."""
        return [
            player for player in self.settings.players if player["team_id"] == team_id
        ]

    def get_player_by_id(self, player_id):
        """Return the first player dict with the given ``player_id``, or None if absent."""
        return next(
            (
                player
                for player in self.settings.players
                if player["player_id"] == player_id
            ),
            None,
        )
[docs] def plot( self, file_path: str, fps: int = None, timestamp: pl.duration = None, end_timestamp: pl.duration = None, period_id: int = None, team_color_a: str = "#CD0E61", team_color_b: str = "#0066CC", ball_color: str = "black", sort: bool = True, color_by: Literal["ball_owning", "static_home_away"] = "ball_owning", anonymous: bool = False, plot_type: Literal["pitch_only", "graph_only", "full"] = "full", show_label: bool = True, show_ball_label: bool = False, show_timestamp: bool = True, next_closest_timestamp: bool = False, ): """ Plot tracking data as a static image or video file. This method visualizes tracking data for players and the ball. It can generate either: - A single PNG image (if either fps or end_timestamp is None, or both are None) - An MP4 video (if both fps and end_timestamp are provided) Parameters ---------- file_path : str The output path where the PNG or MP4 file will be saved fps : int, optional Frames per second for video output. If None, a static image is generated timestamp : pl.duration, optional The starting timestamp to plot. If None, starts from the beginning of available data end_timestamp : pl.duration, optional The ending timestamp for video output. If None, a static image is generated period_id : int, optional ID of the match period to visualize. 
If None, all periods are included team_color_a : str, default "#CD0E61" Hex color code for Team A visualization team_color_b : str, default "#0066CC" Hex color code for Team B visualization ball_color : str, default "black" Color for ball visualization color_by : Literal["ball_owning", "static_home_away"], default "ball_owning" Method for coloring the teams: - "ball_owning": Colors teams based on ball possession - "static_home_away": Uses static colors for home and away teams anonymous : bool, default False Whether to anonymize player labels plot_type : Literal["pitch_only", "graph_only", "full"], default "full" Type of plot to generate: - "pitch_only": Shows only the soccer pitch visualization - "graph_only": Shows only the graph features (node features, adjacency matrix, edge features) - "full": Shows both pitch and graph visualizations show_pitch_label : bool, default True Whether to show the label on the pitch visualization show_pitch_timestamp : bool, default True Whether to show the timestamp on the pitch visualization next_closest_timestamp : bool, default False When plotting a .png and the timestamp isn't 100% correct we find the next correct timestamp and use that to plot. Returns ------- None The function saves the output file to the specified file_path but doesn't return any value Notes ----- Output file type is determined by parameters: - PNG: Generated when either fps or end_timestamp is None, or both are None - MP4: Generated when both fps and end_timestamp are provided Raises ------ ValueError If file extension doesn't match the parameters provided (e.g., .mp4 extension but missing fps or end_timestamp, or .png extension with both fps and end_timestamp) """ try: import matplotlib.animation as animation import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpec except ImportError: raise ImportError( "Seems like you don't have matplotlib installed. 
Please" " install it using: pip install matplotlib" ) if (fps is None and end_timestamp is not None) or ( fps is not None and end_timestamp is None ): raise ValueError( "Both 'fps' and 'end_timestamp' must be provided together to generate a video. " ) # Determine the output type based on parameters generate_video = fps is not None and end_timestamp is not None # Get file extension if it exists path = pathlib.Path(file_path) file_extension = path.suffix.lower() if path.suffix else "" # If no extension, add the appropriate one based on parameters if not file_extension: suffix = ".mp4" if generate_video else ".png" file_path = str(path.with_suffix(suffix)) # Otherwise, validate that the extension matches the parameters else: if generate_video and file_extension != ".mp4": raise ValueError( f"Parameters fps and end_timestamp indicate video output, " f"but file extension is '{file_extension}'. Use '.mp4' extension for video output." ) elif not generate_video and file_extension == ".mp4": raise ValueError( "To generate an MP4 video, both 'fps' and 'end_timestamp' must be provided. " "For static image output, use a '.png' extension." ) elif not generate_video and file_extension != ".png": raise ValueError( f"For static image output, use '.png' extension instead of '{file_extension}'." 
) self._team_color_a = team_color_a self._team_color_b = team_color_b self._ball_color = ball_color self._color_by = color_by self._plot_type = plot_type self._show_label = show_label self._show_ball_label = show_ball_label self._show_pitch_timestamp = show_timestamp self._next_closest_timestamp = next_closest_timestamp self._ball_carrier_color = "black" if period_id is not None and not isinstance(period_id, int): raise TypeError("period_id should be of type integer") if all(x is None for x in [timestamp, end_timestamp, period_id]): # No filters specified, use the entire dataset df = self.dataset elif timestamp is not None and period_id is not None: if end_timestamp is not None: # Both timestamp and end_timestamp provided - filter for a range df = self.dataset.filter( (pl.col(Column.TIMESTAMP).is_between(timestamp, end_timestamp)) & (pl.col(Column.PERIOD_ID) == period_id) ) else: # Only timestamp provided (no end_timestamp) - filter for specific timestamp df = self.dataset.filter( (pl.col(Column.TIMESTAMP) == timestamp) & (pl.col(Column.PERIOD_ID) == period_id) ) # Handle the case where a single timestamp has multiple frame_ids df = ( df.with_columns( pl.col(Column.FRAME_ID) .rank(method="min") .over(Column.TIMESTAMP) .alias("frame_rank") ) # Keep only rows where the frame has rank = 1 (first frame for each timestamp) .filter(pl.col("frame_rank") == 1).drop("frame_rank") ) else: raise ValueError( "Please specify both timestamp and period_id, or specify all of timestamp, end_timestamp, and period_id, or none of them." 
) if df.is_empty(): if not generate_video and self._next_closest_timestamp: idx = self.dataset.sort(Column.FRAME_ID)[ Column.TIMESTAMP ].search_sorted(timestamp) result = self.dataset[idx] df = self.dataset.filter( (pl.col(Column.TIMESTAMP) == result[Column.TIMESTAMP][0]) & (pl.col(Column.PERIOD_ID) == result[Column.PERIOD_ID][0]) ) # Handle the case where a single timestamp has multiple frame_ids df = ( df.with_columns( pl.col(Column.FRAME_ID) .rank(method="min") .over(Column.TIMESTAMP) .alias("frame_rank") ) # Keep only rows where the frame has rank = 1 (first frame for each timestamp) .filter(pl.col("frame_rank") == 1).drop("frame_rank") ) else: if not generate_video: raise ValueError( "Selection is empty, please try different timestamp(s) or set next_closest_timestamp=True..." ) raise ValueError( "Selection is empty, please try different timestamp(s)..." ) def setup_gridspec(): """Setup GridSpec based on plot_type""" if self._plot_type == "pitch_only": return GridSpec(1, 1, left=0.05, right=0.95, bottom=0.05, top=0.95) elif self._plot_type == "graph_only": return GridSpec( 2, 2, width_ratios=[1.2, 0.8], height_ratios=[1, 1], wspace=0.2, # Increased spacing hspace=0.3, # Increased spacing left=0.08, right=0.92, bottom=0.1, # More bottom margin top=0.9, ) # More top margin else: # "full" return GridSpec( 2, 3, width_ratios=[2, 1, 3], height_ratios=[1, 1], wspace=0.15, # Increased spacing hspace=0.1, # Increased spacing left=0.05, right=0.98, # Slightly reduced right margin bottom=0.08, # More bottom margin top=0.95, ) def plot_graph(): """Plot graph features (node features, adjacency matrix, edge features)""" import matplotlib.pyplot as plt num_rows = self._graph["x"].shape[0] labels = ( [ ( self.get_player_by_id(pid)["jersey_no"] if pid != Constant.BALL else Constant.BALL ) for pid in self._graph["object_ids"] ] if not anonymous else [str(i) for i in range(num_rows)] ) # Determine subplot positions based on plot_type if self._plot_type == "graph_only": node_pos 
= (0, 0) adj_pos = (1, 0) edge_pos = (slice(None), 1) else: # "full" node_pos = (0, 0) adj_pos = (1, 0) edge_pos = (slice(None), 1) # Plot node features ax1 = self._fig.add_subplot(self._gs[node_pos]) ax1.imshow(self._graph["x"], aspect="auto", cmap="YlOrRd") ax1.set_xlabel(f"Node Features {self._graph['x'].shape}") # Set y labels to integers ax1.set_yticks(range(num_rows)) ax1.set_yticklabels(labels) node_feature_yticklabels = feature_ticklabels(self._node_feature_dims) ax1.xaxis.set_ticks_position("top") ax1.set_xticks(range(len(node_feature_yticklabels))) ax1.set_xticklabels(node_feature_yticklabels, rotation=45, ha="left") # Plot adjacency matrix ax2 = self._fig.add_subplot(self._gs[adj_pos]) ax2.imshow(self._graph["a"].toarray(), aspect="auto", cmap="YlOrRd") ax2.set_xlabel(f"Adjacency Matrix {self._graph['a'].shape}") # Set both x and y labels to integers num_rows_a = self._graph["a"].toarray().shape[0] num_cols_a = self._graph["a"].toarray().shape[1] ax2.set_yticks(range(num_rows_a)) ax2.set_yticklabels(labels) ax2.xaxis.set_ticks_position("top") ax2.set_xticks(range(num_cols_a)) ax2.set_xticklabels(labels) # Plot Edge Features ax3 = self._fig.add_subplot(self._gs[edge_pos]) _, size_a = non_zeros( self._graph["a"].toarray()[0 : self._ball_carrier_idx] ) ball_carrier_edge_idx, num_rows_e = non_zeros( np.asarray( [list(x) for x in self._graph["a"].toarray()][ self._ball_carrier_idx ] ) ) im3 = ax3.imshow( self._graph["e"][size_a : num_rows_e + size_a, :], aspect="auto", cmap="YlOrRd", ) ax3.set_yticks(range(num_rows_e)) ax3.set_yticklabels(list(ball_carrier_edge_idx[0]), fontsize=18) ball_carrier_edge_idxs = list(ball_carrier_edge_idx[0]) ax3.set_xlabel(f"Edge Features {self._graph['e'].shape}") ax3_labels = ax3.get_yticklabels() if self._ball_carrier_idx in ball_carrier_edge_idx[0]: idx_position = list(ball_carrier_edge_idx[0]).index( self._ball_carrier_idx ) # Modify just that specific label ax3_labels[idx_position].set_color(self._ball_carrier_color) 
ax3_labels[idx_position].set_fontweight("bold") # Set the modified labels back ax3.set_yticklabels([labels[i] for i in ball_carrier_edge_idxs]) # Set x labels to edge function names at the top, rotated 45 degrees edge_feature_xticklabels = feature_ticklabels(self._edge_feature_dims) ax3.xaxis.set_ticks_position("top") ax3.set_xticks(range(len(edge_feature_xticklabels))) ax3.set_xticklabels(edge_feature_xticklabels, rotation=45, ha="left") plt.colorbar(im3, ax=ax3, fraction=0.1, pad=0.2) def plot_vertical_pitch(frame_data: pl.DataFrame): """Plot the soccer pitch visualization""" try: from mplsoccer import VerticalPitch except ImportError: raise ImportError( "Seems like you don't have mplsoccer installed. Please" " install it using: pip install mplsoccer" ) # Determine subplot position based on plot_type if self._plot_type == "pitch_only": pitch_pos = (0, 0) else: # "full" pitch_pos = (slice(None), 2) ax4 = self._fig.add_subplot(self._gs[pitch_pos]) pitch = VerticalPitch( pitch_type="secondspectrum", pitch_length=self.pitch_dimensions.pitch_length, pitch_width=self.pitch_dimensions.pitch_width, pitch_color="#ffffff", pad_top=-0.05, ) pitch.draw(ax=ax4) player_and_ball(frame_data=frame_data, ax=ax4) direction_of_play_arrow(ax=ax4) def feature_ticklabels(feature_dims): _feature_ticklabels = [] for key, value in feature_dims.items(): if value == 1: _feature_ticklabels.append(key) else: _feature_ticklabels.extend([key] + [None] * (value - 1)) return _feature_ticklabels def direction_of_play_arrow(ax): arrow_x = -30 arrow_y = -7.5 arrow_dx = 0 arrow_dy = 15 if self.settings.orientation == Orientation.STATIC_HOME_AWAY: if self._ball_owning_team_id != self.settings.home_team_id: arrow_y = arrow_y * -1 arrow_dy = arrow_dy * -1 elif self.settings.orientation == Orientation.BALL_OWNING_TEAM: pass else: raise ValueError(f"Unsupported orientation {self.settings.orientation}") # Create the arrow to indicate direction of play ax.arrow( arrow_x, arrow_y, arrow_dx, arrow_dy, 
head_width=3, head_length=2, fc="#c2c2c2", ec="#c2c2c2", width=0.5, length_includes_head=True, zorder=-1, ) def player_and_ball(frame_data, ax): if self._color_by == "ball_owning": team_id = self._ball_owning_team_id elif self._color_by == "static_home_away": team_id = self.settings.home_team_id else: raise ValueError(f"Unsupported color_by {self._color_by}") for i, r in enumerate(frame_data.iter_rows(named=True)): v, vy, vx, y, x = ( r[Column.SPEED], r[Column.VX], r[Column.VY], r[Column.X], r[Column.Y], ) is_ball = True if r[Column.TEAM_ID] == self.settings.ball_id else False if not is_ball: if team_id is None: team_id = r[Column.TEAM_ID] color = ( self._team_color_a if r[Column.TEAM_ID] == team_id else self._team_color_b ) if r[Column.IS_BALL_CARRIER] == True: self._ball_carrier_color = color ax.scatter(x, y, color=color, s=450) if v > 1.0: ax.annotate( "", xy=(x + vx, y + vy), xytext=(x, y), arrowprops=dict(arrowstyle="->", color=color, lw=3), ) else: ax.scatter(x, y, color=self._ball_color, s=250, zorder=10) # Text with white border text = ax.text( x + (-1.2 if is_ball else 0.0), y + (-1.2 if is_ball else 0.0), ( ( self.get_player_by_id(r[Column.OBJECT_ID])["jersey_no"] if r[Column.OBJECT_ID] != Constant.BALL else Constant.BALL if self._show_ball_label else "" ) if not anonymous else ( str(i) if r[Column.OBJECT_ID] != Constant.BALL else Constant.BALL if self._show_ball_label else "" ) ), color=self._ball_color if is_ball else color, fontsize=12, ha="center", va="center", zorder=15 if is_ball else 5, ) import matplotlib.patheffects as path_effects text.set_path_effects( [ path_effects.Stroke(linewidth=6, foreground="white"), path_effects.Normal(), ] ) # Add label and timestamp to pitch if enabled if self._show_label: ax.set_xlabel(f"Label: {frame_data['label'][0]}", fontsize=22) if self._show_pitch_timestamp: ax.set_title(self._gameclock, fontsize=22) def frame_plot(self, frame_data): def timestamp_to_gameclock(timestamp, period_id): total_seconds = 
timestamp.total_seconds() minutes = int(total_seconds // 60) seconds = int(total_seconds % 60) milliseconds = int((total_seconds % 1) * 1000) return f"[{period_id}] - {minutes}:{seconds:02d}:{milliseconds:03d}" # Setup GridSpec based on plot_type self._gs = setup_gridspec() # Only process graph data if we need to show graphs if self._plot_type in ["graph_only", "full"]: # Process the current frame features = self._compute( [frame_data[col] for col in self._exprs_variables] ) a = make_sparse( reshape_from_size( features["a"], features["a_shape_0"], features["a_shape_1"] ) ) x = reshape_from_size( features["x"], features["x_shape_0"], features["x_shape_1"] ) e = reshape_from_size( features["e"], features["e_shape_0"], features["e_shape_1"] ) y = np.asarray([features[self.label_column]]) self._graph = { "a": a, "x": x, "e": e, "y": y, "frame_id": features["frame_id"], "object_ids": frame_data[Column.OBJECT_ID], "ball_owning_team_id": frame_data[Column.BALL_OWNING_TEAM_ID][0], } self._ball_carrier_idx = np.where( frame_data[Column.IS_BALL_CARRIER] == True )[0][0] self._ball_owning_team_id = list(frame_data[Column.BALL_OWNING_TEAM_ID])[0] self._gameclock = timestamp_to_gameclock( timestamp=list(frame_data["timestamp"])[0], period_id=list(frame_data["period_id"])[0], ) # Plot based on plot_type if self._plot_type == "pitch_only": plot_vertical_pitch(frame_data) elif self._plot_type == "graph_only": plot_graph() else: # "full" plot_vertical_pitch(frame_data) plot_graph() plt.tight_layout() # Adjust figure size based on plot_type if self._plot_type == "pitch_only": self._fig = plt.figure(figsize=(8, 12)) elif self._plot_type == "graph_only": self._fig = plt.figure(figsize=(14, 10)) else: # "full" self._fig = plt.figure(figsize=(25, 18)) self._fig.subplots_adjust(left=0.06, right=1.0, bottom=0.05) if sort: df = self._sort(df) if generate_video: writer = animation.FFMpegWriter(fps=fps, bitrate=1500) with writer.saving(self._fig, file_path, dpi=300): for group_id, frame_data 
in df.group_by( Group.BY_FRAME, maintain_order=True ): self._fig.clear() frame_plot(self, frame_data) writer.grab_frame() else: frame_plot(self, frame_data=df) plt.savefig(file_path, dpi=300)