# Source code for Garfield.nn.decoders

"""
This module contains the decoder used by the Garfield model.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, GATv2Conv
from .utils import DSBatchNorm, compute_cosine_similarity


class GATDecoder(nn.Module):
    """
    Graph Attention Network (GAT) Decoder class.

    This class implements a GAT-based decoder for reconstructing node features
    from latent representations. It supports domain-specific batch
    normalization (DSBN) and edge weights.

    Parameters
    ----------
    in_channels : int
        Number of input feature dimensions.
    hidden_dims : int or list[int]
        Output dimensions of the hidden GAT layers, in ascending order; the
        decoder traverses them in reverse. A single int is treated as a
        one-element list.
    out_channels : int
        Number of output feature dimensions.
    conv_type : str
        Type of GAT convolution layer to use ('GAT' or 'GATv2Conv').
    num_heads : int
        Number of attention heads for each GAT layer (including the final
        reconstruction layer, whose heads are averaged).
    dropout : float
        Dropout rate passed to every GAT layer.
    concat : bool
        Whether the hidden layers concatenate the outputs of all attention
        heads (output width ``dim * num_heads``) or average them (``dim``).
    num_domains : int or str, optional
        ``1`` builds a regular ``BatchNorm1d`` per hidden layer; an int
        ``> 1`` together with ``used_DSBN=True`` builds a ``DSBatchNorm`` per
        hidden layer. Any other value (including the default ``""``) builds
        no normalization.
    used_edge_weight : bool, optional
        Whether to feed edge weights (column 2 of ``data.edge_attr``) to the
        GAT layers as a one-dimensional edge feature. Default is False.
    used_DSBN : bool, optional
        Whether to use domain-specific batch normalization (DSBN).
        Normalization layers are only applied in ``forward`` when this is
        True. Default is False.

    Raises
    ------
    ValueError
        If ``conv_type`` is not one of the supported names.
    """

    _CONV_TYPES = ("GAT", "GATv2Conv")

    def __init__(
        self,
        in_channels,
        hidden_dims,
        out_channels,
        conv_type,
        num_heads,
        dropout,
        concat,
        num_domains="",
        used_edge_weight=False,
        used_DSBN=False,
    ):
        """
        Initializes the GATDecoder: a stack of GAT layers (with optional
        per-layer normalization) followed by a reconstruction GAT layer.
        """
        super().__init__()
        # Validate before touching the conv classes so a typo fails loudly
        # instead of surfacing later as an unbound local variable.
        if conv_type not in self._CONV_TYPES:
            raise ValueError(
                f"Unsupported conv_type {conv_type!r}; "
                f"expected one of {self._CONV_TYPES}."
            )
        GATLayer = GATConv if conv_type == "GAT" else GATv2Conv

        if isinstance(hidden_dims, int):
            hidden_dims = [hidden_dims]

        self.used_DSBN = used_DSBN
        self.used_edge_weight = used_edge_weight
        self.layers = nn.ModuleList()
        self.norm = nn.ModuleList()
        # Kept as an attribute; dropout itself is applied inside each GAT layer.
        self.dropout = dropout

        # hidden_dims is given in ascending order; the decoder walks it in
        # reverse.
        decoder_dims = list(hidden_dims)[::-1]
        edge_dim = 1 if used_edge_weight else None

        def layer_width(dim):
            # Concatenated heads multiply the output width by num_heads.
            return dim * num_heads if concat else dim

        # Norms are built before the conv layers (same order as before), so
        # parameter-initialization RNG consumption is unchanged.
        for dim in decoder_dims:
            self.norm.append(
                self._make_norm(layer_width(dim), num_domains, used_DSBN)
            )

        current_dim = in_channels
        for dim in decoder_dims:
            self.layers.append(
                GATLayer(
                    in_channels=current_dim,
                    out_channels=dim,
                    heads=num_heads,
                    dropout=dropout,
                    concat=concat,
                    edge_dim=edge_dim,
                )
            )
            current_dim = layer_width(dim)

        # Final layer reconstructing the full feature space (heads averaged).
        self.conv_recon = GATLayer(
            in_channels=current_dim,
            out_channels=out_channels,
            heads=num_heads,
            concat=False,
            edge_dim=edge_dim,
            dropout=dropout,
        )

    @staticmethod
    def _make_norm(dim, num_domains, used_DSBN):
        """Return the normalization module for a layer of width ``dim``, or None."""
        # bool is a subclass of int but is not a valid domain count.
        if not isinstance(num_domains, int) or isinstance(num_domains, bool):
            return None
        if num_domains == 1:
            return nn.BatchNorm1d(dim)
        if num_domains > 1 and used_DSBN:
            return DSBatchNorm(dim, num_domains)
        return None

    @staticmethod
    def _apply_norm(norm, x, y):
        """Apply ``norm`` to ``x`` (DSBatchNorm also receives labels ``y``)."""
        # Skip missing norms, and skip single-row inputs (batch norm cannot
        # compute batch statistics from one sample).
        if norm is None or len(x) == 1:
            return x
        if isinstance(norm, DSBatchNorm):
            return norm(x, y)
        return norm(x)

    def forward(self, x, data):
        """
        Performs a forward pass through the GAT decoder layers and
        reconstructs the node features.

        Parameters
        ----------
        x : torch.Tensor
            Node features (shape: [num_nodes, feature_dim]).
        data : Data
            PyTorch Geometric Data object providing ``edge_index``, ``y``
            (domain labels passed to DSBatchNorm) and, when edge weights are
            used, ``edge_attr`` whose column 2 holds the edge weight.

        Returns
        -------
        torch.Tensor
            Reconstructed node features.
        """
        edge_index, y = data.edge_index, data.y
        # Only read edge_attr when the layers were built with edge_dim=1, so
        # graphs without edge attributes work when edge weights are disabled.
        edge_weight = data.edge_attr[:, 2] if self.used_edge_weight else None

        for layer, norm in zip(self.layers, self.norm):
            x = layer(x, edge_index, edge_attr=edge_weight)
            if self.used_DSBN:
                x = self._apply_norm(norm, x, y)
            x = F.relu(x)

        return self.conv_recon(x, edge_index, edge_attr=edge_weight)
### GCN decoder
class GCNDecoder(nn.Module):
    """
    Graph Convolutional Network (GCN) Decoder class.

    This class implements a GCN-based decoder for reconstructing node features
    from latent representations. It supports domain-specific batch
    normalization (DSBN) and edge weights.

    Parameters
    ----------
    in_channels : int
        Number of input feature dimensions.
    hidden_dims : int or list[int]
        Output dimensions of the hidden GCN layers, in ascending order; the
        decoder traverses them in reverse. A single int is treated as a
        one-element list. An empty list yields a single reconstruction layer.
    out_channels : int
        Number of output feature dimensions.
    dropout : float, optional
        Dropout rate applied to the hidden-layer activations during
        training. Default is 0.2.
    num_domains : int or str, optional
        ``1`` builds a regular ``BatchNorm1d`` per hidden layer; an int
        ``> 1`` together with ``used_DSBN=True`` builds a ``DSBatchNorm`` per
        hidden layer. Any other value (including the default ``""``) builds
        no normalization.
    used_edge_weight : bool, optional
        Whether to pass edge weights (column 2 of ``data.edge_attr``) to the
        GCN layers. Default is False.
    used_DSBN : bool, optional
        Whether to use domain-specific batch normalization (DSBN).
        Normalization layers are only applied in ``forward`` when this is
        True. Default is False.
    """

    def __init__(
        self,
        in_channels,
        hidden_dims,
        out_channels,
        dropout=0.2,
        num_domains="",
        used_edge_weight=False,
        used_DSBN=False,
    ):
        """
        Initializes the GCNDecoder: a stack of GCN layers (with optional
        per-layer normalization) followed by a reconstruction GCN layer.
        """
        super().__init__()
        # Allow a single hidden dimension to be given as a plain int.
        if isinstance(hidden_dims, int):
            hidden_dims = [hidden_dims]

        self.used_DSBN = used_DSBN
        self.used_edge_weight = used_edge_weight
        # Unused (always empty); kept so existing references to `.layers`
        # keep working. The convolutions live in `gcn_layers`.
        self.layers = nn.ModuleList()
        self.norm = nn.ModuleList()
        # GCNConv has no dropout argument, so dropout is applied
        # functionally in forward().
        self.dropout = dropout

        # hidden_dims is given in ascending order; the decoder walks it in
        # reverse, starting from the latent width.
        decoder_dims = list(hidden_dims)[::-1]
        widths = [in_channels] + decoder_dims

        self.gcn_layers = nn.ModuleList(
            GCNConv(d_in, d_out) for d_in, d_out in zip(widths[:-1], widths[1:])
        )
        # Output layer; widths[-1] is in_channels when there are no hidden layers.
        self.gcn_recon = GCNConv(widths[-1], out_channels)

        for dim in decoder_dims:
            self.norm.append(self._make_norm(dim, num_domains, used_DSBN))

    @staticmethod
    def _make_norm(dim, num_domains, used_DSBN):
        """Return the normalization module for a layer of width ``dim``, or None."""
        # bool is a subclass of int but is not a valid domain count.
        if not isinstance(num_domains, int) or isinstance(num_domains, bool):
            return None
        if num_domains == 1:
            return nn.BatchNorm1d(dim)
        if num_domains > 1 and used_DSBN:
            return DSBatchNorm(dim, num_domains)
        return None

    @staticmethod
    def _apply_norm(norm, x, y):
        """Apply ``norm`` to ``x`` (DSBatchNorm also receives labels ``y``)."""
        # Skip missing norms, and skip single-row inputs (batch norm cannot
        # compute batch statistics from one sample).
        if norm is None or len(x) == 1:
            return x
        if isinstance(norm, DSBatchNorm):
            return norm(x, y)
        return norm(x)

    def forward(self, x, data):
        """
        Performs a forward pass through the GCN decoder layers and
        reconstructs the node features.

        Parameters
        ----------
        x : torch.Tensor
            Node features (shape: [num_nodes, feature_dim]).
        data : Data
            PyTorch Geometric Data object providing ``edge_index``, ``y``
            (domain labels passed to DSBatchNorm) and, when edge weights are
            used, ``edge_attr`` whose column 2 holds the edge weight.

        Returns
        -------
        torch.Tensor
            Reconstructed node features.
        """
        edge_index, y = data.edge_index, data.y
        # Only read edge_attr when edge weights are enabled, so graphs
        # without edge attributes still work otherwise.
        edge_weight = data.edge_attr[:, 2] if self.used_edge_weight else None

        for layer, norm in zip(self.gcn_layers, self.norm):
            x = layer(x, edge_index, edge_weight=edge_weight)
            if self.used_DSBN:
                x = self._apply_norm(norm, x, y)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        return self.gcn_recon(x, edge_index, edge_weight=edge_weight)
class CosineSimGraphDecoder(nn.Module):
    """
    Cosine similarity graph decoder class.

    Receives the latent vectors of source and target nodes stacked along the
    first dimension (all sources first, then all targets), applies dropout,
    and returns the row-wise cosine similarity between each source and its
    matching target as reconstructed edge logits. The sigmoid that turns
    logits into probabilities is expected to live in the loss (binary cross
    entropy with logits), for numerical efficiency.

    Parameters
    ----------
    dropout_rate:
        Probability of nodes to be dropped during training.
    """

    def __init__(self, dropout_rate: float = 0.0):
        super().__init__()
        print("COSINE SIM GRAPH DECODER -> " f"dropout_rate: {dropout_rate}")
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the cosine similarity graph decoder.

        Parameters
        ----------
        z:
            Latent vectors of source nodes followed by latent vectors of
            target nodes, stacked along dim 0 (per the original docs:
            4 * edge_batch_size x n_gps, due to negative edges).

        Returns
        ----------
        edge_recon_logits:
            Reconstructed edge logits (per the original docs:
            2 * edge_batch_size, due to negative edges).
        """
        dropped = self.dropout(z)
        # First half: source nodes (edge_label_index[0]);
        # second half: target nodes (edge_label_index[1]).
        half = dropped.shape[0] // 2
        source_z = dropped[:half]
        target_z = dropped[half:]
        return compute_cosine_similarity(source_z, target_z)