Source code for unravel.classifiers.crystal_graph

from spektral.layers import GlobalAvgPool, CrystalConv
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model


class CrystalGraphClassifier(Model):
    """
    Graph-level classifier built from CrystalConv layers, following
    Sahasrabudhe & Bekkers (2023).

    The network applies a stack of ``CrystalConv`` layers, pools the result
    with ``GlobalAvgPool`` and feeds it through two ReLU ``Dense`` layers
    (each followed by dropout) before a final sigmoid ``Dense`` output.
    """

    def __init__(
        self,
        n_layers: int = 3,
        channels: int = 128,
        drop_out: float = 0.5,
        n_out: int = 1,
        **kwargs
    ):
        """
        Create the layers of the classifier.

        Args:
            n_layers: Total number of CrystalConv layers. One layer is always
                created, so values below 1 still result in a single layer.
            channels: Number of units in each of the two hidden Dense layers.
                It is not passed to the CrystalConv layers, which are built
                with their default arguments.
            drop_out: Value handed to the shared Dropout layer.
            n_out: Number of units in the sigmoid output layer.
            **kwargs: Forwarded unchanged to ``Model.__init__``.
        """
        super().__init__(**kwargs)

        # Keep the hyper-parameters available on the instance.
        self.n_layers = n_layers
        self.channels = channels
        self.drop_out = drop_out
        self.n_out = n_out

        # Convolution stack: one fixed first layer plus (n_layers - 1) extras.
        self.conv1 = CrystalConv()
        self.convs = [CrystalConv() for _ in range(n_layers - 1)]

        # Graph-level readout.
        self.pool = GlobalAvgPool()

        # Classification head; a single Dropout layer is reused after both
        # hidden Dense layers.
        self.dense1 = Dense(units=channels, activation="relu")
        self.dropout = Dropout(rate=drop_out)
        self.dense2 = Dense(units=channels, activation="relu")
        self.dense3 = Dense(units=n_out, activation="sigmoid")

    def call(self, inputs):
        """
        Run the forward pass.

        Args:
            inputs: A sequence of exactly four elements ``(x, a, e, i)``.
                NOTE(review): presumably node features, adjacency matrix,
                edge features and the graph-index vector of a Spektral
                disjoint batch -- confirm against the data loader.

        Returns:
            The output of the final sigmoid Dense layer.
        """
        features, adjacency, edges, index = inputs

        # Every conv layer receives the running features together with the
        # same, unmodified second and third inputs.
        for conv_layer in (self.conv1, *self.convs):
            features = conv_layer([features, adjacency, edges])

        # Pool using the running features and the fourth input.
        pooled = self.pool([features, index])

        hidden = self.dropout(self.dense1(pooled))
        hidden = self.dropout(self.dense2(hidden))
        return self.dense3(hidden)