from typing import Dict, List, Literal, Optional
import polars as pl
import random
from kloppy.domain import TrackingDataset
[docs]
def dummy_labels(dataset: TrackingDataset) -> Dict:
    """
    Build placeholder labels to feed into GraphNeuralNetworkConverter.

    Each frame in ``dataset`` is mapped, by its ``frame_id``, to a boolean
    picked at random.

    Raises:
        TypeError: if ``dataset`` is not a kloppy ``TrackingDataset``.
    """
    if isinstance(dataset, TrackingDataset):
        # One draw from the module-level RNG per frame, in iteration order.
        return {
            frame.frame_id: random.choice([True, False]) for frame in dataset
        }
    raise TypeError("dataset should be of type TrackingDataset (from kloppy)")
[docs]
def dummy_graph_ids(dataset: TrackingDataset) -> Dict:
    """
    Build placeholder graph_ids to feed into GraphNeuralNetworkConverter.

    A single random UUID stands in for the match id, and a fake possession id
    cycles through 0-9 based on each frame's position in ``dataset``. Every
    ``frame_id`` is mapped to the string "<match id>-<possession id>".

    Raises:
        TypeError: if ``dataset`` is not a kloppy ``TrackingDataset``.
    """
    if not isinstance(dataset, TrackingDataset):
        raise TypeError("dataset should be of type TrackingDataset (from kloppy)")

    from uuid import uuid4

    # Generated once so all frames share the same fake match id.
    match_token = str(uuid4())
    return {
        frame.frame_id: f"{match_token}-{position % 10}"
        for position, frame in enumerate(dataset)
    }
[docs]
def add_dummy_label_column(
    dataset: pl.DataFrame,
    by: List[str] = ["gameId", "playId", "frameId"],
    column_name: str = "label",
    random_seed: Optional[float] = None,
) -> pl.DataFrame:
    """
    Add a column of random 0/1 labels, one per unique combination of ``by``.

    All rows of ``dataset`` sharing the same values in the ``by`` columns
    receive the same label.

    Args:
        dataset: Frame containing at least the ``by`` columns.
        by: Columns whose unique combinations each receive one label. The
            default list is only read, never mutated.
        column_name: Name of the label column to add.
        random_seed: When not None, the module-level ``random`` generator is
            seeded with this value before drawing, making the labels
            reproducible.

    Returns:
        ``dataset`` left-joined (on ``by``) with the generated label column.
    """
    # De-duplicate first, then sort once: the sorted order is what fixes
    # which combination receives which random draw.
    groups = dataset.select(by).unique().sort(by)

    if random_seed is not None:
        random.seed(random_seed)
    draws = [random.choice([0, 1]) for _ in range(groups.height)]

    # Attach the draws positionally as a regular column instead of
    # broadcasting the whole list onto every row and indexing into it.
    # The explicit dtype keeps the column Int64 even when there are no rows.
    labelled = groups.with_columns(pl.Series(column_name, draws, dtype=pl.Int64))

    return dataset.join(labelled, on=by, how="left")
[docs]
def add_graph_id_column(
    dataset: pl.DataFrame,
    by: List[str] = ["game_id", "play_id"],
    column_name: str = "graph_id",
):
    """
    Add a graph id column built from the ``by`` columns.

    The values of the ``by`` columns are concatenated as strings with "-" as
    separator and stored under ``column_name``.

    Returns:
        ``dataset`` with the extra column added.
    """
    graph_id_expr = pl.concat_str(by, separator="-").alias(column_name)
    return dataset.with_columns([graph_id_expr])
def create_default_expression(col_name, dtype):
    """
    Return a literal polars expression with a default value for ``dtype``.

    Defaults: Boolean -> False, Int32/Int64 -> 0, Float32/Float64 -> 0.0,
    Utf8 -> "". Any other dtype yields a null literal cast to that dtype.
    The expression is aliased to ``col_name``.
    """
    # (dtype to match, default literal, whether the literal is cast to it).
    # Order matters: entries are compared top to bottom, first match wins.
    known_defaults = (
        (pl.Boolean, False, False),
        (pl.Int32, 0, True),
        (pl.Int64, 0, True),
        (pl.Float32, 0.0, True),
        (pl.Float64, 0.0, False),
        (pl.Utf8, "", False),
    )
    for candidate, value, needs_cast in known_defaults:
        if dtype == candidate:
            expr = pl.lit(value)
            if needs_cast:
                expr = expr.cast(candidate)
            return expr.alias(col_name)

    # Unknown dtype: fall back to a typed null.
    return pl.lit(None).cast(dtype).alias(col_name)