Source code for unravel.utils.objects.graph_dataset

import logging
import sys
from typing import List, Tuple, Union, Optional, Literal

import numpy as np

import gzip
import pickle
from pathlib import Path

import warnings

from collections.abc import Sequence

from unravel.utils.exceptions import NoGraphIdsWarning, SpektralDependencyError


def load_pickle_gz(file_path):
    with gzip.open(file_path, "rb") as f:
        data = pickle.load(f)
    return data


class _GraphDatasetMixin:
    """
    Base mixin for graph dataset functionality.
    Framework-agnostic implementation that works with both Spektral and PyTorch Geometric.
    """

    def __init__(self, **kwargs):
        """
        Constructor to load parameters.

        Args:
            pickle_folder: Path to folder containing .pickle.gz files
            pickle_file: Path to single .pickle.gz file
            graphs: List of graph objects (Spektral Graph, PyG Data, or dicts)
            format: Optional explicit format specification ('spektral' or 'pyg')
            sample_rate: Sampling rate (1.0 = use all data)
        """
        self._kwargs = kwargs
        self._explicit_format = kwargs.get("format", None)

        sample_rate = kwargs.get("sample_rate", 1.0)
        self.sample = 1.0 / sample_rate

        if kwargs.get("pickle_folder", None):
            pickle_folder = Path(kwargs["pickle_folder"])
            self.graphs = None
            # Loop over all .pickle.gz files in the folder
            for pickle_file in pickle_folder.glob("*.pickle.gz"):
                data = load_pickle_gz(pickle_file)
                if not self.graphs:
                    self.graphs = self.__convert(data)
                else:
                    self.add(data)

        elif kwargs.get("pickle_file", None):
            pickle_file = Path(kwargs["pickle_file"])
            self.graphs = None
            data = load_pickle_gz(pickle_file)

            if not self.graphs:
                self.graphs = self.__convert(data)
            else:
                self.add(data)

        elif kwargs.get("graphs", None):
            if not isinstance(kwargs["graphs"], list):
                raise NotImplementedError("""data should be of type list""")

            self.graphs = self.__convert(kwargs["graphs"])
        else:
            raise NotImplementedError(
                "Please provide either 'pickle_folder', 'pickle_file' or 'graphs' as parameter to GraphDataset"
            )

        # Only call super().__init__ if there's a parent class that needs it
        # For PyGGraphDataset, Sequence doesn't take kwargs
        # For SpektralGraphDataset, Dataset does take kwargs
        try:
            super().__init__(**kwargs)
        except TypeError:
            # If super().__init__() doesn't accept kwargs (like Sequence), call it without args
            super().__init__()

    def __convert(self, data):
        """
        Convert incoming data to correct format.
        Must be implemented by subclasses.
        """
        raise NotImplementedError("Subclasses must implement __convert()")

    def read(self):
        """
        Overriding the read function - to return a list of Graph objects.
        Must be implemented by subclasses.
        """
        raise NotImplementedError("Subclasses must implement read()")

    def add(self, other, verbose: bool = False):
        """Add more graphs to the dataset"""
        other = self.__convert(other)

        if verbose:
            logging.info(f"Adding {len(other)} graphs to GraphDataset...")

        self.graphs = self.graphs + other

    def dimensions(self) -> Tuple[int, int, int, int, int]:
        """
        N = Max number of nodes
        F = Dimensions of Node Features
        S = Dimensions of Edge Features
        n_out = Dimension of the target
        n = Number of samples in dataset
        """
        raise NotImplementedError("Subclasses must implement dimensions()")

    def split_test_train(
        self,
        split_train: float,
        split_test: float,
        by_graph_id: bool = False,
        random_seed: Union[bool, int] = False,
        train_label_ratio: Optional[float] = None,
        test_label_ratio: Optional[float] = None,
    ):
        return self.split_test_train_validation(
            split_train=split_train,
            split_test=split_test,
            split_validation=0.0,
            by_graph_id=by_graph_id,
            random_seed=random_seed,
            train_label_ratio=train_label_ratio,
            test_label_ratio=test_label_ratio,
        )

    def split_test_train_validation(
        self,
        split_train: float,
        split_test: float,
        split_validation: float,
        by_graph_id: bool = False,
        random_seed: int = None,
        train_label_ratio: Optional[float] = None,
        test_label_ratio: Optional[float] = None,
        val_label_ratio: Optional[float] = None,
    ):
        """
        Split dataset into train, test, and validation sets with optional label balancing.
        """
        total = split_train + split_test + split_validation

        train_pct = split_train / total
        test_pct = split_test / total
        validation_pct = split_validation / total

        if by_graph_id and (
            (validation_pct > train_pct)
            or (test_pct > train_pct)
            or (validation_pct > test_pct)
        ):
            raise NotImplementedError(
                "Make sure split_train > split_test >= split_validation, other behaviour is not supported when by_graph_id is True..."
            )

        dataset_length = len(self)
        num_train = int(train_pct * dataset_length)

        if validation_pct > 0:
            num_test = int(test_pct * dataset_length)
            num_validation = dataset_length - num_train - num_test
        else:
            num_test = dataset_length - num_train
            num_validation = 0

        unique_graph_ids = set(
            [
                g.get("id") if hasattr(g, "id") else getattr(g, "graph_id", None)
                for g in self
            ]
        )
        if unique_graph_ids == {None}:
            by_graph_id = False

            warnings.warn(
                f"""No graph_ids available, continuing with by_graph_id=False... If you want to use graph_ids please specify in GraphConverter class""",
                NoGraphIdsWarning,
            )

        if not by_graph_id:
            if random_seed:
                idxs = np.random.RandomState(seed=random_seed).permutation(
                    dataset_length
                )
            else:
                idxs = np.arange(dataset_length)

            if num_validation > 0:
                train_idxs = idxs[:num_train]
                test_idxs = idxs[num_train : num_train + num_test]
                validation_idxs = idxs[
                    num_train + num_test : num_train + num_test + num_validation
                ]

                train_set = self[train_idxs]
                test_set = self[test_idxs]
                validation_set = self[validation_idxs]

                if train_label_ratio is not None:
                    train_set = self._balance_labels(
                        train_set, train_label_ratio, random_seed
                    )
                if test_label_ratio is not None:
                    test_set = self._balance_labels(
                        test_set, test_label_ratio, random_seed
                    )
                if val_label_ratio is not None:
                    validation_set = self._balance_labels(
                        validation_set, val_label_ratio, random_seed
                    )

                return train_set, test_set, validation_set
            else:
                train_idxs = idxs[:num_train]
                test_idxs = idxs[num_train:]

                train_set = self[train_idxs]
                test_set = self[test_idxs]

                if train_label_ratio is not None:
                    train_set = self._balance_labels(
                        train_set, train_label_ratio, random_seed
                    )
                if test_label_ratio is not None:
                    test_set = self._balance_labels(
                        test_set, test_label_ratio, random_seed
                    )

                return train_set, test_set
        else:
            # Get graph IDs in a framework-agnostic way
            graph_ids = np.asarray(
                [
                    (
                        g.get("id")
                        if hasattr(g, "get") and g.get("id") is not None
                        else getattr(g, "graph_id", None)
                    )
                    for g in self
                ]
            )

            if random_seed:
                np.random.seed(random_seed)

            unique_graph_ids_list = sorted(list(unique_graph_ids))
            np.random.shuffle(unique_graph_ids_list)

            test_idxs, train_idxs, validation_idxs = list(), list(), list()

            def __handle_graph_id(i):
                graph_id = unique_graph_ids_list[i]
                unique_graph_ids.remove(graph_id)
                graph_idxs = np.where(graph_ids == graph_id)[0]
                return graph_idxs

            i = 0
            if num_validation > 0:
                while len(validation_idxs) < num_validation:
                    graph_idxs = __handle_graph_id(i)
                    validation_idxs.extend(graph_idxs)
                    i += 1

            while len(test_idxs) < num_test:
                graph_idxs = __handle_graph_id(i)
                test_idxs.extend(graph_idxs)
                i += 1

            train_idxs = np.isin(graph_ids, np.asarray(list(unique_graph_ids)))
            train_idxs = np.where(train_idxs)[0]

            if validation_idxs:
                train_set = self[train_idxs]
                test_set = self[test_idxs]
                validation_set = self[validation_idxs]

                if train_label_ratio is not None:
                    train_set = self._balance_labels(
                        train_set, train_label_ratio, random_seed
                    )
                if test_label_ratio is not None:
                    test_set = self._balance_labels(
                        test_set, test_label_ratio, random_seed
                    )
                if val_label_ratio is not None:
                    validation_set = self._balance_labels(
                        validation_set, val_label_ratio, random_seed
                    )

                return train_set, test_set, validation_set
            else:
                train_set = self[train_idxs]
                test_set = self[test_idxs]

                if train_label_ratio is not None:
                    train_set = self._balance_labels(
                        train_set, train_label_ratio, random_seed
                    )
                if test_label_ratio is not None:
                    test_set = self._balance_labels(
                        test_set, test_label_ratio, random_seed
                    )

                return train_set, test_set

    def _balance_labels(self, dataset, target_ratio, random_seed):
        """Balance a dataset to achieve a target ratio of labels."""
        if random_seed:
            np.random.seed(random_seed)

        if not 0 <= target_ratio <= 1:
            raise ValueError("target_ratio must be between 0 and 1")

        indices_by_label = {0: [], 1: []}

        for i, g in enumerate(dataset):
            # Handle different types of label storage
            if hasattr(g, "y"):
                y_value = g.y
            elif hasattr(g, "get") and g.get("y", None) is not None:
                y_value = g["y"]
            else:
                raise ValueError("Graph has no attribute 'y'...")

            if isinstance(y_value, (np.ndarray, list)):
                if len(y_value) != 1:
                    raise ValueError(
                        f"Expected y to be a single value, but got array of length {len(y_value)}"
                    )
                label = 1 if y_value[0] > 0.5 else 0
            else:
                label = 1 if y_value > 0.5 else 0

            indices_by_label[label].append(i)

        n_zeros = len(indices_by_label[0])
        n_ones = len(indices_by_label[1])
        total = n_zeros + n_ones

        current_ratio = n_ones / total if total > 0 else 0

        if abs(current_ratio - target_ratio) < 0.01:
            return dataset

        if current_ratio > target_ratio:
            target_ones = int(n_zeros * target_ratio / (1 - target_ratio))
            target_zeros = n_zeros
        else:
            target_zeros = int(n_ones * (1 - target_ratio) / target_ratio)
            target_ones = n_ones

        indices_to_keep = []

        if n_zeros > target_zeros:
            sampled_zeros = np.random.choice(
                indices_by_label[0], target_zeros, replace=False
            )
            indices_to_keep.extend(sampled_zeros)
        else:
            indices_to_keep.extend(indices_by_label[0])

        if n_ones > target_ones:
            sampled_ones = np.random.choice(
                indices_by_label[1], target_ones, replace=False
            )
            indices_to_keep.extend(sampled_ones)
        else:
            indices_to_keep.extend(indices_by_label[1])

        np.random.shuffle(indices_to_keep)

        return dataset[indices_to_keep]


# =============================================================================
# SPEKTRAL IMPLEMENTATION
# =============================================================================

try:
    from spektral.data import Dataset, Graph
    from spektral.data.utils import get_spec
    import tensorflow as tf

    _HAS_SPEKTRAL = True

    class SpektralGraphDataset(_GraphDatasetMixin, Dataset, Sequence):
        """
        Spektral-specific GraphDataset implementation.
        """

        def _SpektralGraphDataset__convert(self, data) -> List:
            """Convert incoming data to Spektral Graph format"""
            from spektral.data import Graph

            if isinstance(data[0], Graph):
                return [g for i, g in enumerate(data) if i % self.sample == 0]
            elif isinstance(data[0], dict):
                return [
                    Graph(
                        x=g["x"],
                        a=g["a"],
                        e=g["e"],
                        y=g["y"],
                        id=g["id"],
                        frame_id=g.get("frame_id", None),
                        object_ids=g.get("object_ids", None),
                        ball_owning_team_id=g.get("ball_owning_team_id", None),
                    )
                    for i, g in enumerate(data)
                    if i % self.sample == 0
                ]
            else:
                raise ValueError(
                    f"Cannot convert type {type(data[0])} to Spektral Graph. "
                    "Expected Spektral Graph or dict."
                )

        _GraphDatasetMixin__convert = _SpektralGraphDataset__convert

        def read(self) -> List:
            """Return a list of Spektral Graph objects"""
            graphs = self._SpektralGraphDataset__convert(self.graphs)
            logging.info(f"Loading {len(graphs)} graphs into SpektralGraphDataset...")
            return graphs

        def dimensions(self) -> Tuple[int, int, int, int, int]:
            """N, F, S, n_out, n"""
            N = max(g.n_nodes for g in self)
            F = self.n_node_features
            S = self.n_edge_features
            n_out = self.n_labels
            n = len(self)
            return (N, F, S, n_out, n)

        @property
        def signature(self):
            """Compute TensorFlow signature for the dataset"""
            from spektral.data.utils import get_spec
            import tensorflow as tf

            if len(self.graphs) == 0:
                return None
            signature = {}
            graph = self.graphs[0]

            if graph.x is not None:
                signature["x"] = dict()
                signature["x"]["spec"] = get_spec(graph.x)
                signature["x"]["shape"] = (None, self.n_node_features)
                signature["x"]["dtype"] = tf.as_dtype(graph.x.dtype)

            if graph.a is not None:
                signature["a"] = dict()
                signature["a"]["spec"] = get_spec(graph.a)
                signature["a"]["shape"] = (None, None)
                signature["a"]["dtype"] = tf.as_dtype(graph.a.dtype)

            if graph.e is not None:
                signature["e"] = dict()
                signature["e"]["spec"] = get_spec(graph.e)
                signature["e"]["shape"] = (None, self.n_edge_features)
                signature["e"]["dtype"] = tf.as_dtype(graph.e.dtype)

            if graph.y is not None:
                signature["y"] = dict()
                signature["y"]["spec"] = get_spec(graph.y)
                signature["y"]["shape"] = (self.n_labels,)
                signature["y"]["dtype"] = tf.as_dtype(np.array(graph.y).dtype)

            if hasattr(graph, "g") and graph.g is not None:
                signature["g"] = dict()
                signature["g"]["spec"] = get_spec(graph.g)
                signature["g"]["shape"] = graph.g.shape
                signature["g"]["dtype"] = tf.as_dtype(np.array(graph.g).dtype)

            return signature

except ImportError:
    _HAS_SPEKTRAL = False

    # Create a dummy class that raises an informative error
    class SpektralGraphDataset:
        def __init__(self, *args, **kwargs):
            raise SpektralDependencyError()


# =============================================================================
# PYTORCH GEOMETRIC IMPLEMENTATION
# =============================================================================

try:
    import torch
    from torch_geometric.data import Data

    _HAS_TORCH_GEOMETRIC = True
except ImportError:
    _HAS_TORCH_GEOMETRIC = False


class PyGGraphDataset(_GraphDatasetMixin, Sequence):
    """
    PyTorch Geometric GraphDataset implementation.
    """

    def _PyGGraphDataset__convert(self, data) -> List:
        """Convert incoming data to PyG Data format"""
        if not _HAS_TORCH_GEOMETRIC:
            raise ImportError(
                "PyTorch Geometric is required for PyGGraphDataset. "
                "Install it using: pip install torch torch-geometric"
            )

        from torch_geometric.data import Data

        if isinstance(data[0], Data):
            return [g for i, g in enumerate(data) if i % self.sample == 0]
        elif isinstance(data[0], dict):
            pyg_graphs = []
            for i, d in enumerate(data):
                if i % self.sample != 0:
                    continue

                # Node features
                x = torch.tensor(d["x"], dtype=torch.float)

                # Get adjacency matrix and convert to edge_index
                a = d["a"].toarray() if hasattr(d["a"], "toarray") else d["a"]
                edge_indices = np.nonzero(a)
                edge_index = torch.tensor(np.vstack(edge_indices), dtype=torch.long)

                # Edge features (already aligned with edges)
                edge_attr = torch.tensor(d["e"], dtype=torch.float)

                # Labels
                y = torch.tensor(d["y"], dtype=torch.long)

                # Create Data object
                graph_data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=y,
                )

                # Add custom attributes
                graph_data.id = d.get("id", None)
                graph_data.frame_id = d.get("frame_id", None)
                graph_data.ball_owning_team_id = d.get("ball_owning_team_id", None)
                graph_data.object_ids = d.get("object_ids", None)

                pyg_graphs.append(graph_data)

            return pyg_graphs
        else:
            raise ValueError(
                f"Cannot convert type {type(data[0])} to PyG Data. "
                "Expected PyG Data or dict."
            )

    _GraphDatasetMixin__convert = _PyGGraphDataset__convert

    def read(self) -> List:
        """Return a list of PyG Data objects"""
        if not _HAS_TORCH_GEOMETRIC:
            raise ImportError(
                "PyTorch Geometric is required. "
                "Install it using: pip install torch torch-geometric"
            )

        graphs = self._PyGGraphDataset__convert(self.graphs)
        logging.info(f"Loading {len(graphs)} graphs into PyGGraphDataset...")
        return graphs

    def dimensions(self) -> Tuple[int, int, int, int, int]:
        """N, F, S, n_out, n"""
        N = max(data.num_nodes for data in self)
        F = self[0].num_node_features if len(self) > 0 else 0
        S = self[0].num_edge_features if len(self) > 0 else 0
        n_out = self[0].y.shape[0] if len(self) > 0 else 0
        n = len(self)
        return (N, F, S, n_out, n)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        if isinstance(idx, (list, np.ndarray)):
            selected_graphs = [self.graphs[i] for i in idx]
            return PyGGraphDataset(graphs=selected_graphs, sample_rate=1.0)
        else:
            return self.graphs[idx]

    def __repr__(self):
        return f"PyGGraphDataset(n_graphs={len(self)})"

    @property
    def n_graphs(self):
        return len(self)


[docs] def GraphDataset( format: Optional[Literal["spektral", "pyg"]] = "spektral", **kwargs ) -> Union[SpektralGraphDataset, PyGGraphDataset]: """ Factory function that creates the appropriate dataset based on format. Args: format: Format specification ('spektral' or 'pyg'). Defaults to 'spektral'. **kwargs: Arguments passed to the dataset constructor Returns: SpektralGraphDataset or PyGGraphDataset depending on format Examples: # Spektral format (default) dataset = GraphDataset(graphs=spektral_graph_list, format='spektral') # PyG format dataset = GraphDataset(graphs=pyg_data_list, format='pyg') # From pickle files dataset = GraphDataset(pickle_file='graphs.pickle.gz', format='pyg') """ import warnings if format == "spektral": warnings.warn( """ unravelsports now supports PyTorch Geometric. The default "format" will change from 'spektral' to 'pyg' in a future version. \nNote: format='spektral' only really works on Python 3.11, due to very specific package requirements. PyTorch works on 3.11+. """, FutureWarning, ) def _create_dataset(fmt: str): """Helper function to create the appropriate dataset""" if fmt.lower() == "spektral": if not _HAS_SPEKTRAL: raise SpektralDependencyError() return SpektralGraphDataset(**kwargs) elif fmt.lower() == "pyg": if not _HAS_TORCH_GEOMETRIC: raise ImportError( "PyTorch Geometric is required. " "Install it using: pip install torch torch-geometric" ) return PyGGraphDataset(**kwargs) else: raise ValueError(f"format must be 'spektral' or 'pyg', got '{fmt}'") if ( kwargs.get("graphs") is None and kwargs.get("pickle_file") is None and kwargs.get("pickle_folder") is None ): raise ValueError( "Must provide either 'graphs', 'pickle_file', or 'pickle_folder'" ) if kwargs.get("graphs") is not None: graphs = kwargs["graphs"] if not isinstance(graphs, list) or len(graphs) == 0: raise ValueError("graphs must be a non-empty list") return _create_dataset(format)