from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np
import polars as pl
from kloppy.domain import Dimension, Unit, Orientation

from ...utils import (
    DefaultSettings,
    DefaultDataset,
    AmericanFootballPitchDimensions,
    add_dummy_label_column,
    add_graph_id_column,
)
from .objects import Column, Group, Constant
@dataclass
class BigDataBowlDataset(DefaultDataset):
    """Load and preprocess NFL Big Data Bowl tracking data into Polars DataFrame format.

    This class handles NFL tracking data from the Big Data Bowl competition, converting
    CSV files into a standardized Polars DataFrame with standardized coordinate
    systems and orientation normalization. It processes three input files:
    tracking data, player metadata, and play information.

    The loader performs:

    - Coordinate system standardization (centering at midfield)
    - Orientation normalization (attacking left-to-right)
    - Angle conversion (degrees → radians wrapped to the [-π, π) range)
    - Player metadata enrichment (height, weight, position)
    - Play-level information joining (possession team)
    - Metric conversion (imperial → metric for anthropometrics)

    The resulting dataset is ready for graph construction via
    :class:`~unravel.american_football.graphs.AmericanFootballGraphConverter`.

    Args:
        tracking_file_path (str): Path to tracking CSV file. Must contain columns:
            gameId, playId, nflId, frameId, x, y, s (speed), o (orientation),
            dir (direction), playDirection, team (or club).
        players_file_path (str): Path to players CSV file. Must contain: nflId,
            position (or officialPosition), height, weight.
        plays_file_path (str): Path to plays CSV file. Must contain: gameId, playId,
            possessionTeam.
        sample_rate (float, optional): Sampling rate for downsampling frames. For example,
            0.5 keeps every 2nd frame. The rate is turned into an integer frame stride
            ``round(1 / sample_rate)`` (minimum 1) and only frames whose ``frameId`` is
            a multiple of that stride are kept. Must be greater than 0.
            Defaults to None (no downsampling).
        max_player_speed (float, optional): Maximum physically plausible player speed (m/s)
            for filtering outliers. Defaults to 12.0 m/s (~27 mph).
        max_ball_speed (float, optional): Maximum physically plausible ball speed (m/s).
            Defaults to 28.0 m/s (~63 mph).
        max_player_acceleration (float, optional): Maximum player acceleration (m/s²).
            Defaults to 6.0 m/s².
        max_ball_acceleration (float, optional): Maximum ball acceleration (m/s²).
            Defaults to 13.5 m/s².
        orient_ball_owning (bool, optional): Whether to normalize coordinate system so
            the offense always attacks left-to-right. Defaults to True (recommended).
        **kwargs: Additional arguments passed to DefaultDataset.

    Attributes:
        data (pl.DataFrame): Processed tracking data with columns:

            - game_id, play_id, frame_id: Identifiers
            - object_id: Player NFL ID as a float; rows whose team_id equals the
              ball constant get the sentinel ``-9999.9`` so the ball sorts first
            - team_id: Team abbreviation or "football"
            - x, y: Position in yards (centered at midfield)
            - s: Speed in yards/second
            - o: Orientation angle in radians, wrapped to [-π, π)
            - dir: Direction of movement in radians, wrapped to [-π, π)
            - position_name: Player position (e.g., "QB", "WR", "CB")
            - height_cm: Player height in centimeters (rounded to nearest 10cm)
            - weight_kg: Player weight in kilograms (rounded to nearest 10kg)
            - ball_owning_team_id: Team with possession
        settings (DefaultSettings): Configuration object with pitch dimensions,
            orientation settings, and speed thresholds.

    Raises:
        NotImplementedError: If orient_ball_owning=False (currently unsupported).
        ValueError: If sample_rate is not greater than 0.

    Example:
        >>> from unravel.american_football.dataset import BigDataBowlDataset
        >>>
        >>> # Load Big Data Bowl 2024 data
        >>> dataset = BigDataBowlDataset(
        ...     tracking_file_path="tracking_week_1.csv",
        ...     players_file_path="players.csv",
        ...     plays_file_path="plays.csv",
        ...     sample_rate=1.0,  # Use all frames
        ...     orient_ball_owning=True
        ... )
        >>>
        >>> # Access processed data
        >>> print(dataset.data)
        >>> print(f"Total frames: {dataset.data['frame_id'].n_unique()}")
        >>> print(f"Total plays: {dataset.data['play_id'].n_unique()}")
        >>>
        >>> # Downsample to 5 Hz (every other frame from 10 Hz)
        >>> dataset_5hz = BigDataBowlDataset(
        ...     tracking_file_path="tracking_week_1.csv",
        ...     players_file_path="players.csv",
        ...     plays_file_path="plays.csv",
        ...     sample_rate=0.5  # Keep every 2nd frame
        ... )
        >>>
        >>> # Add dummy labels for GNN training
        >>> dataset.add_dummy_labels()
        >>> dataset.add_graph_ids()

    Note:
        - Big Data Bowl data uses yards as the unit. The coordinate system is centered
          at midfield (x=0) with y=0 at the center of the field.
        - Player heights and weights are rounded to the nearest 10 cm / 10 kg to protect
          player privacy while retaining useful anthropometric information.
        - The orientation normalization (orient_ball_owning=True) ensures offensive
          players always attack from left to right, simplifying model training.
        - Frame IDs are computed as: play_id * 100,000 + frameId to ensure global uniqueness.
        - The "football" object has team_id="football" and is included in every frame.

    Warning:
        NFL Big Data Bowl data format can vary by year. This loader is tested on
        2023-2024 formats. Older competitions may require modifications.

    See Also:
        :class:`~unravel.american_football.graphs.AmericanFootballGraphConverter`:
            Convert to graph format for GNN training.
        :meth:`add_dummy_labels`: Add placeholder labels for testing.
        :meth:`add_graph_ids`: Add graph identifiers for batching.
        :doc:`../tutorials/american_football`: Tutorial on NFL tracking data analysis.
    """

    def __init__(
        self,
        tracking_file_path: str,
        players_file_path: str,
        plays_file_path: str,
        sample_rate: Optional[float] = None,
        max_player_speed: float = 12.0,
        max_ball_speed: float = 28.0,
        max_player_acceleration: float = 6.0,
        max_ball_acceleration: float = 13.5,
        orient_ball_owning: bool = True,
        **kwargs,
    ):
        """Store the configuration and immediately load the data via :meth:`load`."""
        super().__init__(**kwargs)
        self.tracking_file_path = tracking_file_path
        self.players_file_path = players_file_path
        self.plays_file_path = plays_file_path
        # None means "keep every frame", which is a rate of 1.
        self.sample_rate = 1 if sample_rate is None else sample_rate
        self._max_player_speed = max_player_speed
        self._max_ball_speed = max_ball_speed
        self._max_player_acceleration = max_player_acceleration
        self._max_ball_acceleration = max_ball_acceleration
        self._orient_ball_owning = orient_ball_owning
        self.load()

    def __apply_settings(
        self,
    ):
        """Build the ``DefaultSettings`` object from the constructor arguments."""
        return DefaultSettings(
            provider="nfl",
            home_team_id=None,
            away_team_id=None,
            pitch_dimensions=AmericanFootballPitchDimensions(),
            orientation=(
                Orientation.BALL_OWNING_TEAM
                if self._orient_ball_owning
                else Orientation.NOT_SET
            ),
            max_player_speed=self._max_player_speed,
            max_ball_speed=self._max_ball_speed,
            max_player_acceleration=self._max_player_acceleration,
            max_ball_acceleration=self._max_ball_acceleration,
            ball_carrier_threshold=None,
        )

    def load(self):
        """Read the three CSV files and populate ``self.data`` and ``self.settings``.

        Returns:
            tuple: ``(self.data, self.settings)``.

        Raises:
            NotImplementedError: If ``orient_ball_owning`` is False.
            ValueError: If ``sample_rate`` is not greater than 0.
        """
        self.settings = self.__apply_settings()

        # Fail before touching any file when the configuration cannot be honoured.
        if not self._orient_ball_owning:
            raise NotImplementedError(
                "Currently, BigDataBowlDataset only allows Orientation.BALL_OWNING"
            )
        if self.sample_rate <= 0:
            raise ValueError(
                f"sample_rate must be greater than 0, got {self.sample_rate!r}"
            )

        # Frame ids are integers, so the stride has to be one as well: a fractional
        # stride (e.g. 1 / 0.3) would match almost no frame and silently return an
        # (almost) empty dataset.
        frame_stride = max(1, int(round(1.0 / self.sample_rate)))

        pitch_length = self.settings.pitch_dimensions.pitch_length
        pitch_width = self.settings.pitch_dimensions.pitch_width

        tracking = self.__load_tracking(frame_stride, pitch_length, pitch_width)
        players = self.__load_players()
        plays = self.__load_plays()

        # Enrich each tracking row with the player's position and anthropometrics.
        tracking = tracking.join(
            players.select(
                [
                    "nflId",
                    Column.POSITION_NAME,
                    Column.HEIGHT_CM,
                    Column.WEIGHT_KG,
                ]
            ),
            on="nflId",
            how="left",
        ).rename(
            {
                "nflId": Column.OBJECT_ID,
                "gameId": Column.GAME_ID,
                "playId": Column.PLAY_ID,
                "s": Column.SPEED,
            }
        )

        # Attach the play-level columns listed in Group.BY_PLAY_BALL_OWNING.
        tracking = tracking.join(
            plays.select(Group.BY_PLAY_BALL_OWNING),
            on=[Column.GAME_ID, Column.PLAY_ID],
            how="left",
        )

        # The raw frameId is folded together with the play id into a single frame id.
        tracking = tracking.with_columns(
            [
                (pl.col(Column.PLAY_ID) * 100_000 + pl.col("frameId")).alias(
                    Column.FRAME_ID
                )
            ]
        ).drop(["frameId"])

        self.data = tracking.sort(
            [Column.GAME_ID, Column.PLAY_ID, Column.FRAME_ID, Column.OBJECT_ID]
        )

        # update pitch dimensions to how it looks after loading
        self.settings.pitch_dimensions = AmericanFootballPitchDimensions(
            x_dim=Dimension(min=-pitch_length / 2, max=pitch_length / 2),
            y_dim=Dimension(min=-pitch_width / 2, max=pitch_width / 2),
            standardized=False,
            unit=Unit.YARDS,
            pitch_length=pitch_length,
            pitch_width=pitch_width,
        )
        return self.data, self.settings

    def __load_tracking(self, frame_stride, pitch_length, pitch_width):
        """Scan the tracking CSV, normalise coordinates/angles and downsample frames."""
        tracking = pl.scan_csv(
            self.tracking_file_path,
            separator=",",
            encoding="utf8",
            null_values=["NA", "NULL", ""],
            try_parse_dates=True,
        )

        # The team column is named "club" or "team" depending on the file;
        # resolve the schema once and rename whichever is present.
        column_names = tracking.collect_schema().names()
        if "club" in column_names:
            tracking = tracking.rename({"club": Column.TEAM_ID})
        elif "team" in column_names:
            tracking = tracking.rename({"team": Column.TEAM_ID})

        play_direction = "left"
        # Rows flagged "left" are the ones rotated and mirrored below.
        moves_left = pl.col("playDirection") == play_direction

        def rotate_half_turn(name):
            # rotate 180 degrees for the flagged rows, leave the others untouched
            return (
                pl.when(moves_left)
                .then(pl.col(name) + 180)
                .otherwise(pl.col(name))
                .alias(name)
            )

        def degrees_to_wrapped_radians(name):
            # convert to radians and wrap into the [-pi, pi) range
            return (
                ((pl.col(name) * np.pi / 180) + np.pi) % (2 * np.pi) - np.pi
            ).alias(name)

        def mirror(name):
            # flip the sign of an already centred coordinate for the flagged rows
            return (
                pl.when(moves_left)
                .then(pl.col(name) * -1.0)
                .otherwise(pl.col(name))
                .alias(name)
            )

        return (
            tracking.with_columns(
                rotate_half_turn(Column.ORIENTATION),
                rotate_half_turn(Column.DIRECTION),
            )
            .with_columns(
                [
                    # move the origin to the centre of the field
                    (pl.col(Column.X) - (pitch_length / 2)).alias(Column.X),
                    (pl.col(Column.Y) - (pitch_width / 2)).alias(Column.Y),
                    degrees_to_wrapped_radians(Column.ORIENTATION),
                    degrees_to_wrapped_radians(Column.DIRECTION),
                ]
            )
            .with_columns(
                [
                    mirror(Column.X),
                    mirror(Column.Y),
                    # give the ball the sentinel id -9999.9 for ordering purposes
                    pl.when(pl.col(Column.TEAM_ID) == Constant.BALL)
                    .then(-9999.9)
                    .otherwise(pl.col("nflId"))
                    .alias("nflId"),
                ]
            )
            .with_columns(
                [
                    # NOTE(review): every row is relabelled "left" after the
                    # left-moving plays were mirrored; kept as-is, confirm that
                    # downstream consumers expect this label.
                    pl.lit(play_direction).alias("playDirection"),
                ]
            )
            .filter((pl.col("frameId") % frame_stride) == 0)
        ).collect()

    def __load_players(self):
        """Read the players CSV and return it with metric height/weight columns."""
        players = pl.read_csv(
            self.players_file_path,
            separator=",",
            encoding="utf8",
            null_values=["NA", "NULL", ""],
            schema_overrides={"birthDate": pl.Date},
            ignore_errors=True,
        )

        # The position column is named "position" or "officialPosition".
        if "position" in players.columns:
            players = players.rename({"position": Column.POSITION_NAME})
        elif "officialPosition" in players.columns:
            players = players.rename({"officialPosition": Column.POSITION_NAME})

        # Cast to float so the join key can match the tracking nflId, which holds
        # the float ball sentinel.
        players = players.with_columns(
            pl.col("nflId").cast(pl.Float64, strict=False).alias("nflId")
        )
        return self._convert_weight_height_to_metric(df=players)

    def __load_plays(self):
        """Read the plays CSV and rename its key columns to the shared column names."""
        return pl.read_csv(
            self.plays_file_path,
            separator=",",
            encoding="utf8",
            null_values=["NA", "NULL", ""],
            try_parse_dates=True,
        ).rename(
            {
                "gameId": Column.GAME_ID,
                "playId": Column.PLAY_ID,
                "possessionTeam": Column.BALL_OWNING_TEAM_ID,
            }
        )

    def add_dummy_labels(self, by: Optional[List[str]] = None) -> pl.DataFrame:
        """Add a label column to ``self.data`` via ``add_dummy_label_column``.

        Args:
            by: Columns handed to ``add_dummy_label_column``. Defaults to
                ``[Column.GAME_ID, Column.FRAME_ID]``.

        Returns:
            pl.DataFrame: The updated ``self.data``.
        """
        # Build the default per call instead of sharing one list between calls.
        if by is None:
            by = [Column.GAME_ID, Column.FRAME_ID]
        self.data = add_dummy_label_column(self.data, by, self._label_column)
        return self.data

    def add_graph_ids(self, by: Optional[List[str]] = None) -> pl.DataFrame:
        """Add a graph id column to ``self.data`` via ``add_graph_id_column``.

        Args:
            by: Columns handed to ``add_graph_id_column``. Defaults to
                ``[Column.GAME_ID]``.

        Returns:
            pl.DataFrame: The updated ``self.data``.
        """
        # Build the default per call instead of sharing one list between calls.
        if by is None:
            by = [Column.GAME_ID]
        self.data = add_graph_id_column(self.data, by, self._graph_id_column)
        return self.data

    @staticmethod
    def _convert_weight_height_to_metric(df: pl.DataFrame):
        """Replace the ``height``/``weight`` columns with rounded metric columns.

        ``height`` is parsed as a string: the first run of digits is taken as feet
        and the digits following a dash as inches (presumably values look like
        "6-2" - verify against the players file). ``weight`` is multiplied by the
        pound-to-kilogram factor.

        Args:
            df: Players frame containing ``height`` and ``weight`` columns.

        Returns:
            pl.DataFrame: ``df`` with ``Column.HEIGHT_CM`` and ``Column.WEIGHT_KG``
            added and the ``height`` and ``weight`` columns removed.
        """
        df = df.with_columns(
            [
                # Extract feet and cast to float
                pl.col("height").str.extract(r"(\d+)").cast(pl.Float64).alias("feet"),
                # Extract inches and cast to float
                pl.col("height")
                .str.extract(r"\d+-(\d+)", 1)
                .cast(pl.Float64)
                .alias("inches"),
            ]
        )

        # Convert height and weight to centimeters and kilograms.
        # Round them to the nearest 10 (divide by 10, round, multiply back) to make
        # sure we don't leak any player specific info.
        height_cm = (pl.col("feet") * 30.48 + pl.col("inches") * 2.54) / 10
        weight_kg = (pl.col("weight") * 0.453592) / 10
        return (
            df.with_columns(
                [
                    height_cm.round(0).alias(Column.HEIGHT_CM),
                    weight_kg.round(0).alias(Column.WEIGHT_KG),
                ]
            )
            .with_columns(
                [
                    (pl.col(Column.HEIGHT_CM) * 10).alias(Column.HEIGHT_CM),
                    (pl.col(Column.WEIGHT_KG) * 10).alias(Column.WEIGHT_KG),
                ]
            )
            .drop(["height", "feet", "inches", "weight"])
        )