"""
This module contains the encoder used by the Garfield model.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import dropout_adj
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, GATv2Conv
from .utils import DSBatchNorm, drop_feature
class Projection(nn.Module):
    """Fully connected projection: a single linear layer followed by ReLU."""

    def __init__(self, in_dim: int, encoder_dim: int):
        """
        Initializes the Projection layer.

        Parameters
        ----------
        in_dim : int
            The dimension of the input features.
        encoder_dim : int
            The dimension of the encoded features.
        """
        super().__init__()
        self.layer = nn.Linear(in_dim, encoder_dim, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the Projection layer.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor whose last dimension is ``in_dim``
            (typically ``(batch_size, in_dim)``).

        Returns
        -------
        torch.Tensor
            ``relu(layer(x))``; same leading dimensions as ``x`` with the
            last dimension equal to ``encoder_dim``.
        """
        projected = self.layer(x)
        return self.relu(projected)
class GATEncoder(nn.Module):
    """
    Graph Attention Network (GAT) encoder with stacked attention layers,
    optional normalization, and an optional fully connected (FC) projection
    in front of the graph layers. Two attention heads produce the latent mean
    and log standard deviation.

    Methods
    ----------
    forward(data, decoder_type, augment_type=None)
        Encodes the input for either the 'omics' decoder (original + augmented
        view) or the 'graph' decoder (single view).
    _forward_through_layers(x, edge_index, edge_weight, y)
        Passes node features through the hidden GAT layers and the two
        latent heads.
    _apply_normalization(x, y, idx)
        Applies the idx-th normalization layer when normalization is enabled.

    Parameters
    ----------
    in_channels : int
        Number of input feature dimensions (length of each node's feature vector).
    hidden_dims : list[int]
        List of output dimensions for each hidden layer in the GAT.
    latent_dim : int
        Dimension of the latent feature representation produced by the encoder.
    conv_type : str
        ``"GAT"`` selects ``GATConv``; any other value selects ``GATv2Conv``.
    use_FCencoder : bool
        Whether to use an additional fully connected encoder before the GAT
        layers. Its output width is ``hidden_dims[0] * 2``.
    drop_feature_rate : float
        Dropout rate for node features during augmentation.
    drop_edge_rate : float
        Dropout rate for edges during augmentation.
    svd_q : int
        Rank for the low-rank SVD approximation used in augmentations.
    num_heads : int
        Number of attention heads for each GAT layer.
    dropout : float
        Dropout rate passed to the GAT layers.
    concat : bool
        Whether to concatenate the outputs of all attention heads in the
        hidden layers (the latent heads always average).
    num_domains : int, optional
        Number of domains. If `1`, ``nn.BatchNorm1d`` layers are created,
        otherwise ``DSBatchNorm`` layers. Default is 1.
    used_edge_weight : bool, optional
        Whether to use edge weights in the GAT layers. Default is False.
    used_DSBN : bool, optional
        Whether the normalization layers are applied at all. When False the
        normalization layers are created but skipped. Default is False.
    """

    def __init__(
        self,
        in_channels,
        hidden_dims,
        latent_dim,
        conv_type,
        use_FCencoder,
        drop_feature_rate,
        drop_edge_rate,
        svd_q,
        num_heads,
        dropout,
        concat,
        num_domains=1,
        used_edge_weight=False,
        used_DSBN=False,
    ):
        """
        Initializes the GATEncoder with multiple Graph Attention Network (GAT)
        layers, normalization layers, and an optional fully connected (FC)
        encoder.
        """
        super().__init__()
        self.use_FCencoder = use_FCencoder
        self.drop_feature_rate = drop_feature_rate
        self.drop_edge_rate = drop_edge_rate
        self.svd_q = svd_q
        self.used_DSBN = used_DSBN
        self.used_edge_weight = used_edge_weight
        self.layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()

        # Choose GAT layer type based on `conv_type`
        GATLayer = GATConv if conv_type == "GAT" else GATv2Conv
        edge_dim = 1 if self.used_edge_weight else None

        # Normalization layers are built first so that module creation order
        # (and therefore parameter initialization order) stays unchanged.
        for hidden_dim in hidden_dims:
            current_dim = hidden_dim * num_heads if concat else hidden_dim
            if num_domains == 1:
                norm_layer = nn.BatchNorm1d(current_dim)
            else:
                norm_layer = DSBatchNorm(current_dim, num_domains)
            self.norm_layers.append(norm_layer)

        # Initialize projection layer if `use_FCencoder` is True
        if self.use_FCencoder:
            encoder_dim = hidden_dims[0] * 2
            self.proj = Projection(in_channels, encoder_dim)
            current_dim = encoder_dim
        else:
            current_dim = in_channels

        # Initialize GAT layers
        for hidden_dim in hidden_dims:
            self.layers.append(
                GATLayer(
                    in_channels=current_dim,
                    out_channels=hidden_dim,
                    heads=num_heads,
                    dropout=dropout,
                    concat=concat,
                    edge_dim=edge_dim,
                )
            )
            current_dim = hidden_dim * num_heads if concat else hidden_dim

        # Initialize the final mean and log standard deviation layers
        self.conv_mean = GATLayer(
            in_channels=current_dim,
            out_channels=latent_dim,
            heads=num_heads,
            concat=False,
            edge_dim=edge_dim,
            dropout=dropout,
        )
        self.conv_log_std = GATLayer(
            in_channels=current_dim,
            out_channels=latent_dim,
            heads=num_heads,
            concat=False,
            edge_dim=edge_dim,
            dropout=dropout,
        )

    def forward(self, data, decoder_type, augment_type=None):
        """
        Performs the forward pass through the GAT encoder for either the
        'omics' or the 'graph' decoder.

        Parameters
        ----------
        data : Data
            Object exposing ``x`` (node features), ``edge_index`` and ``y``
            (domain labels).
            NOTE(review): connectivity is read as ``data.edge_index[:, :2]``
            and edge weights as ``data.edge_index[:, 2]``; confirm this layout
            against the data-preparation code.
        decoder_type : str
            Either 'omics' or 'graph'.
        augment_type : str, optional
            Either 'dropout' or 'svd'. Required for the 'omics' decoder and
            ignored for the 'graph' decoder. Default is None.

        Returns
        -------
        Tuple of torch.Tensor
            For 'omics' decoder type: (z_mean1, z_log_std1, z_mean2, z_log_std2), where:
            - z_mean1: Mean of the latent space from original input.
            - z_log_std1: Log standard deviation of the latent space from original input.
            - z_mean2: Mean of the latent space from augmented input.
            - z_log_std2: Log standard deviation of the latent space from augmented input.
            For 'graph' decoder type: (z_mean1, z_log_std1), where:
            - z_mean1: Mean of the latent space.
            - z_log_std1: Log standard deviation of the latent space.

        Raises
        ------
        NotImplementedError
            If ``decoder_type`` is unknown, or if ``decoder_type`` is 'omics'
            and ``augment_type`` is not 'dropout' or 'svd' (including None).
        """
        x, edge_index_all, y = data.x, data.edge_index, data.y
        edge_index = edge_index_all[:, :2]

        if decoder_type == "omics":
            edge_weight, x_aug, edge_index_aug, edge_weight_aug = self._augment(
                x, edge_index, edge_index_all, augment_type
            )
            if self.use_FCencoder:
                x = self.proj(x)
                x_aug = self.proj(x_aug)
            z_mean1, z_log_std1 = self._forward_through_layers(
                x, edge_index, edge_weight, y
            )
            z_mean2, z_log_std2 = self._forward_through_layers(
                x_aug, edge_index_aug, edge_weight_aug, y
            )
            return z_mean1, z_log_std1, z_mean2, z_log_std2

        if decoder_type == "graph":
            # Unit weights, created directly on the device of the edge index.
            edge_weight = (
                torch.ones(edge_index.shape[1], 1, device=edge_index.device)
                if self.used_edge_weight
                else None
            )
            if self.use_FCencoder:
                x = self.proj(x)
            return self._forward_through_layers(x, edge_index, edge_weight, y)

        raise NotImplementedError(f"Unknown decoder type: {decoder_type}")

    def _augment(self, x, edge_index, edge_index_all, augment_type):
        """
        Builds the augmented view used by the 'omics' decoder.

        Returns
        -------
        tuple
            ``(edge_weight, x_aug, edge_index_aug, edge_weight_aug)`` where
            ``edge_weight`` belongs to the original view.

        Raises
        ------
        NotImplementedError
            If ``augment_type`` is neither 'dropout' nor 'svd'.
        """
        if augment_type == "dropout":
            edge_weight = edge_index_all[:, 2] if self.used_edge_weight else None
            x_aug = drop_feature(x=x, drop_prob=self.drop_feature_rate)
            # Pass the weights along so they are filtered with the same mask
            # as the edges; otherwise weights and edges disagree in length.
            edge_index_aug, edge_weight_aug = dropout_adj(
                edge_index, edge_attr=edge_weight, p=self.drop_edge_rate
            )
            return edge_weight, x_aug, edge_index_aug, edge_weight_aug

        if augment_type == "svd":
            edge_weight = edge_index_all[:, 2].float()
            num_nodes = int(edge_index.max().item()) + 1
            sparse_adj = torch.sparse_coo_tensor(
                indices=edge_index.t(),
                values=edge_weight,
                size=(num_nodes, num_nodes),
            )
            # Ensure q <= number of columns of the matrix
            q = min(self.svd_q, sparse_adj.shape[1])
            u, s, v = torch.svd_lowrank(sparse_adj, q=q)
            # u * s scales the columns of u, i.e. u @ diag(s)
            recon_adj = (u * s) @ v.T
            recon_sparse = recon_adj.to_sparse()
            edge_index_aug = recon_sparse.indices()
            edge_weight_aug = recon_sparse.values().unsqueeze(1)
            # NOTE(review): GCNEncoder keeps the features untouched in its SVD
            # view; confirm that the difference is intentional.
            x_aug = drop_feature(x=x, drop_prob=self.drop_feature_rate)
            if not self.used_edge_weight:
                # Weights were only needed to build the adjacency for the SVD.
                edge_weight = None
                edge_weight_aug = None
            return edge_weight, x_aug, edge_index_aug, edge_weight_aug

        raise NotImplementedError(f"Unknown augment type: {augment_type}")

    def _forward_through_layers(self, x, edge_index, edge_weight, y):
        """
        Helper function to pass the input features through multiple GAT layers
        and apply normalization.

        Parameters
        ----------
        x : torch.Tensor
            Node features.
        edge_index : torch.Tensor
            Edge indices handed to the GAT layers.
        edge_weight : torch.Tensor or None
            Edge weights; only used when ``used_edge_weight`` is True.
        y : torch.Tensor
            Domain labels forwarded to ``DSBatchNorm`` layers.

        Returns
        -------
        Tuple of torch.Tensor
            - z_mean: Output of the mean head.
            - z_log_std: Output of the log standard deviation head.
        """
        # The edge attributes are identical for every layer, so cast once.
        use_weights = self.used_edge_weight and edge_weight is not None
        edge_attr = edge_weight.float() if use_weights else None

        for idx, layer in enumerate(self.layers):
            x = layer(x.float(), edge_index, edge_attr=edge_attr)
            x = self._apply_normalization(x, y, idx)
            x = F.relu(x)

        z_mean = self.conv_mean(x.float(), edge_index, edge_attr=edge_attr)
        z_log_std = self.conv_log_std(x.float(), edge_index, edge_attr=edge_attr)
        return z_mean, z_log_std

    def _apply_normalization(self, x, y, idx):
        """
        Applies the ``idx``-th normalization layer when ``used_DSBN`` is set
        and ``x`` has more than one row; otherwise returns ``x`` unchanged.

        ``DSBatchNorm`` layers additionally receive the domain labels ``y``.
        """
        if not self.used_DSBN or len(x) <= 1:
            return x
        norm_layer = self.norm_layers[idx]
        if isinstance(norm_layer, DSBatchNorm):
            return norm_layer(x, y)
        return norm_layer(x)
### GCN encoder
class GCNEncoder(nn.Module):
    """
    Graph Convolutional Network (GCN) encoder with stacked GCN layers,
    optional normalization, and an optional fully connected (FC) projection
    in front of the graph layers. Two GCN heads produce the latent mean and
    log standard deviation.

    Methods
    ----------
    forward(data, decoder_type, augment_type=None)
        Encodes the input for either the 'omics' decoder (original + augmented
        view) or the 'graph' decoder (single view).
    _forward_through_layers(x, edge_index, edge_weight, y)
        Passes node features through the hidden GCN layers and the two
        latent heads.
    _apply_normalization(x, y, idx)
        Applies the idx-th normalization layer when normalization is enabled.

    Parameters
    ----------
    in_channels : int
        Number of input feature dimensions (length of each node's feature vector).
    hidden_dims : list[int]
        List of output dimensions for each hidden layer in the GCN.
    latent_dim : int
        Dimension of the latent feature representation produced by the encoder.
    use_FCencoder : bool
        Whether to use a fully connected encoder (FC encoder) before the GCN
        layers. Its output width is ``hidden_dims[0] * 2``.
    drop_feature_rate : float
        Dropout rate for node features during augmentation.
    drop_edge_rate : float
        Dropout rate for edges during augmentation.
    svd_q : int
        Rank for the low-rank SVD approximation used in augmentations.
    dropout : float, optional
        Feature dropout rate applied after each hidden GCN layer while the
        module is in training mode, default is 0.2.
    num_domains : int, optional
        Number of domains. If `1`, ``nn.BatchNorm1d`` layers are created,
        otherwise ``DSBatchNorm`` layers. Default is 1.
    used_edge_weight : bool, optional
        Whether to use edge weights in the GCN layers. Default is False.
    used_DSBN : bool, optional
        Whether the normalization layers are applied at all. When False the
        normalization layers are created but skipped. Default is False.
    """

    def __init__(
        self,
        in_channels,
        hidden_dims,
        latent_dim,
        use_FCencoder,
        drop_feature_rate,
        drop_edge_rate,
        svd_q,
        dropout=0.2,
        num_domains=1,
        used_edge_weight=False,
        used_DSBN=False,
    ):
        """
        Initializes the GCNEncoder with configurable options for feature
        projection, dropout, and domain-specific batch normalization (DSBN).
        """
        super().__init__()
        self.use_FCencoder = use_FCencoder
        self.drop_feature_rate = drop_feature_rate
        self.drop_edge_rate = drop_edge_rate
        self.svd_q = svd_q
        self.used_DSBN = used_DSBN
        self.used_edge_weight = used_edge_weight
        # GCNConv has no `dropout` argument (unknown keyword arguments are
        # forwarded to MessagePassing and rejected), so the rate is kept here
        # and applied functionally in `_forward_through_layers`.
        self.dropout = dropout
        self.norm = nn.ModuleList()

        # Apply feature projection if specified
        if self.use_FCencoder:
            encoder_dim = hidden_dims[0] * 2
            self.proj = Projection(in_channels, encoder_dim)
            current_dim = encoder_dim
        else:
            current_dim = in_channels

        # Initialize GCN layers
        self.gcn_layers = nn.ModuleList()
        layer_dims = [current_dim] + list(hidden_dims)
        for dim_in, dim_out in zip(layer_dims[:-1], layer_dims[1:]):
            self.gcn_layers.append(GCNConv(dim_in, dim_out))

        # Initialize mean and log standard deviation layers; their input is
        # the width of whatever layer precedes them.
        head_in_dim = layer_dims[-1]
        self.gcn_mu = GCNConv(head_in_dim, latent_dim)
        self.gcn_logvar = GCNConv(head_in_dim, latent_dim)

        # Initialize normalization layers
        for hidden_dim in hidden_dims:
            if num_domains == 1:
                norm = nn.BatchNorm1d(hidden_dim)
            else:
                norm = DSBatchNorm(hidden_dim, num_domains)
            self.norm.append(norm)

    def forward(self, data, decoder_type, augment_type=None):
        """
        Forward pass through the GCN encoder for either the 'omics' or the
        'graph' decoder.

        Parameters
        ----------
        data : Data
            Object exposing ``x`` (node features), ``edge_index`` and ``y``
            (domain labels).
            NOTE(review): connectivity is read as ``data.edge_index[:, :2]``
            and edge weights as ``data.edge_index[:, 2]``; confirm this layout
            against the data-preparation code.
        decoder_type : str
            Either 'omics' for gene expression data or 'graph' for graph structure.
        augment_type : str, optional
            Either 'dropout' or 'svd'. Required for the 'omics' decoder and
            ignored for the 'graph' decoder. Default is None.

        Returns
        -------
        Tuple of torch.Tensor
            For 'omics' decoder type: (z_mean1, z_log_std1, z_mean2, z_log_std2), where:
            - z_mean1: Mean of the latent space from original input.
            - z_log_std1: Log standard deviation of the latent space from original input.
            - z_mean2: Mean of the latent space from augmented input.
            - z_log_std2: Log standard deviation of the latent space from augmented input.
            For 'graph' decoder type: (z_mean1, z_log_std1), where:
            - z_mean1: Mean of the latent space.
            - z_log_std1: Log standard deviation of the latent space.

        Raises
        ------
        NotImplementedError
            If ``decoder_type`` is unknown, or if ``decoder_type`` is 'omics'
            and ``augment_type`` is not 'dropout' or 'svd' (including None).
        """
        x, edge_index_all, y = data.x, data.edge_index, data.y
        edge_index = edge_index_all[:, :2]

        if decoder_type == "omics":
            edge_weight, x_aug, edge_index_aug, edge_weight_aug = self._augment(
                x, edge_index, edge_index_all, augment_type
            )
            if self.use_FCencoder:
                x = self.proj(x)
                x_aug = self.proj(x_aug)
            z_mean1, z_log_std1 = self._forward_through_layers(
                x, edge_index, edge_weight, y
            )
            z_mean2, z_log_std2 = self._forward_through_layers(
                x_aug, edge_index_aug, edge_weight_aug, y
            )
            return z_mean1, z_log_std1, z_mean2, z_log_std2

        if decoder_type == "graph":
            # Unit weights, created directly on the device of the edge index.
            edge_weight = (
                torch.ones(edge_index.shape[1], 1, device=edge_index.device)
                if self.used_edge_weight
                else None
            )
            if self.use_FCencoder:
                x = self.proj(x)
            return self._forward_through_layers(x, edge_index, edge_weight, y)

        raise NotImplementedError(f"Unknown decoder type: {decoder_type}")

    def _augment(self, x, edge_index, edge_index_all, augment_type):
        """
        Builds the augmented view used by the 'omics' decoder.

        Returns
        -------
        tuple
            ``(edge_weight, x_aug, edge_index_aug, edge_weight_aug)`` where
            ``edge_weight`` belongs to the original view.

        Raises
        ------
        NotImplementedError
            If ``augment_type`` is neither 'dropout' nor 'svd'.
        """
        if augment_type == "dropout":
            edge_weight = edge_index_all[:, 2] if self.used_edge_weight else None
            x_aug = drop_feature(x=x, drop_prob=self.drop_feature_rate)
            # Pass the weights along so they are filtered with the same mask
            # as the edges; otherwise weights and edges disagree in length.
            edge_index_aug, edge_weight_aug = dropout_adj(
                edge_index, edge_attr=edge_weight, p=self.drop_edge_rate
            )
            return edge_weight, x_aug, edge_index_aug, edge_weight_aug

        if augment_type == "svd":
            edge_weight = edge_index_all[:, 2].float()
            num_nodes = int(edge_index.max().item()) + 1
            sparse_adj = torch.sparse_coo_tensor(
                indices=edge_index.t(),
                values=edge_weight,
                size=(num_nodes, num_nodes),
            )
            # Ensure q <= number of columns of the matrix
            q = min(self.svd_q, sparse_adj.shape[1])
            u, s, v = torch.svd_lowrank(sparse_adj, q=q)
            # u * s scales the columns of u, i.e. u @ diag(s)
            recon_adj = (u * s) @ v.T
            recon_sparse = recon_adj.to_sparse()
            edge_index_aug = recon_sparse.indices()
            edge_weight_aug = recon_sparse.values().unsqueeze(1)
            # The SVD view keeps the node features untouched.
            # NOTE(review): GATEncoder additionally drops features here;
            # confirm that the difference is intentional.
            x_aug = x
            if not self.used_edge_weight:
                # Weights were only needed to build the adjacency for the SVD.
                edge_weight = None
                edge_weight_aug = None
            return edge_weight, x_aug, edge_index_aug, edge_weight_aug

        raise NotImplementedError(f"Unknown augment type: {augment_type}")

    def _forward_through_layers(self, x, edge_index, edge_weight, y):
        """
        Helper function that passes the node features through GCN layers and
        applies normalization, ReLU and dropout.

        Parameters
        ----------
        x : torch.Tensor
            Node features.
        edge_index : torch.Tensor
            Edge indices handed to the GCN layers.
        edge_weight : torch.Tensor or None
            Edge weights; flattened to a 1-D float tensor before use.
        y : torch.Tensor
            Domain labels forwarded to ``DSBatchNorm`` layers.

        Returns
        -------
        Tuple of torch.Tensor
            - z_mean: Output of the mean head.
            - z_log_std: Output of the log standard deviation head.
        """
        if edge_weight is not None:
            # GCNConv documents edge weights as a 1-D tensor; callers build
            # them as (E, 1) and possibly with an integer dtype.
            edge_weight = edge_weight.float().reshape(-1)

        for idx, layer in enumerate(self.gcn_layers):
            x = layer(x, edge_index, edge_weight=edge_weight)
            x = self._apply_normalization(x, y, idx)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        z_mean = self.gcn_mu(x, edge_index, edge_weight=edge_weight)
        z_log_std = self.gcn_logvar(x, edge_index, edge_weight=edge_weight)
        return z_mean, z_log_std

    def _apply_normalization(self, x, y, idx):
        """
        Applies the ``idx``-th normalization layer when ``used_DSBN`` is set
        and ``x`` has more than one row; otherwise returns ``x`` unchanged.

        Parameters
        ----------
        x : torch.Tensor
            Node features.
        y : torch.Tensor
            Domain labels, passed only to ``DSBatchNorm`` layers.
        idx : int
            Index of the current layer for normalization.

        Returns
        -------
        torch.Tensor
            Normalized node features, or ``x`` itself when normalization is skipped.
        """
        if not self.used_DSBN or len(x) <= 1:
            return x
        norm_layer = self.norm[idx]
        if isinstance(norm_layer, DSBatchNorm):
            return norm_layer(x, y)
        return norm_layer(x)