import numpy as np
from typing import List, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import normalize
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj
from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops
from torch_geometric.nn import GAE, InnerProductDecoder
from .utils import extract_subgraph, to_float_tensor
from ..nn.decoders import GATDecoder, GCNDecoder, CosineSimGraphDecoder
from .loss import (
compute_omics_recon_mse_loss,
compute_omics_recon_mmd_loss,
compute_edge_recon_loss,
compute_kl_reg_loss,
compute_contrastive_instanceloss,
compute_contrastive_clusterloss,
compute_adj_recon_loss,
)
# Based on VGAE class in PyTorch Geometric
[docs]
class GNNModelVAE(GAE):
"""
Garfield model class. This class contains the implementation of GNNModel Variational Auto-encoder.
Parameters
----------
encoder : nn.Module
The encoder module used in the variational graph autoencoder. 'GAT' or 'GCN'.
bottle_neck_neurons : int
Number of neurons in the bottleneck layer representing the latent dimension.
hidden_dims : int
Number of hidden dimensions for the encoder.
feature_dim : int
Number of feature dimensions in the input data.
num_heads : int
Number of attention heads used in the GAT encoder.
dropout : float
Dropout rate used in the encoder and decoder.
concat : bool
Whether to concatenate outputs of different attention heads.
n_domain : int
Number of domains for domain-specific batch normalization (DSBN).
used_edge_weight : bool
Whether to use edge weights in the graph convolution operation.
used_DSBN : bool
Whether to use domain-specific batch normalization (DSBN).
conv_type : str
Type of graph convolution to use, e.g., 'GAT', 'GATv2Conv', 'GCN'.
gnn_layer : int, optional
Number of layers in the GNN encoder. Default is 2.
cluster_num : int, optional
Number of clusters for the clustering layer. Default is 20.
include_edge_recon_loss : bool, optional
Whether to include edge reconstruction loss in the model. Default is True.
include_gene_expr_recon_loss : bool, optional
Whether to include gene expression reconstruction loss in the model. Default is True.
used_mmd : bool, optional
Whether to use MMD (Maximum Mean Discrepancy) loss for domain adaptation. Default is False.
"""
[docs]
def __init__(
self,
encoder,
bottle_neck_neurons,
hidden_dims,
feature_dim,
num_heads,
dropout,
concat,
n_domain,
used_edge_weight,
used_DSBN,
conv_type,
gnn_layer=2,
cluster_num=20,
include_edge_recon_loss=True,
include_gene_expr_recon_loss=True,
used_mmd=False,
):
super(GNNModelVAE, self).__init__(encoder)
# model configurations
self.encoder = encoder
self.latent = bottle_neck_neurons
self.hidden_dims = hidden_dims
self.feature_dim = feature_dim
self.num_heads = num_heads
self.dropout = dropout
self.concat = concat
self.n_domain = n_domain
self.used_edge_weight = used_edge_weight
self.used_DSBN = used_DSBN
self.include_edge_recon_loss = include_edge_recon_loss
self.include_gene_expr_recon_loss = include_gene_expr_recon_loss
self.used_mmd = used_mmd
self.conv_type = conv_type
self.cluster_num = cluster_num
assert self.conv_type in [
"GAT",
"GATv2Conv",
"GCN",
], 'Convolution must be "GCN", "GAT" or "GATv2Conv".'
self.gnn_layer = gnn_layer
# 使用 Xavier 初始化权重
self.eps_weight = nn.Parameter(
nn.init.xavier_uniform_(torch.empty(self.latent, self.latent))
)
self.eps_bias = nn.Parameter(torch.zeros(self.latent))
self.instance_projector = nn.Sequential(
nn.Linear(self.latent, self.latent),
nn.LayerNorm(self.latent, elementwise_affine=False),
nn.ReLU(),
nn.Linear(self.latent, self.feature_dim),
)
self.cluster_projector = nn.Sequential(
nn.Linear(self.latent, self.latent),
nn.LayerNorm(self.latent, elementwise_affine=False),
nn.ReLU(),
nn.Linear(self.latent, self.cluster_num),
nn.Softmax(dim=1),
)
# Initialize graph decoder module
self.graph_decoder = CosineSimGraphDecoder(dropout_rate=self.dropout)
# Initialize adj decoder module
self.adj_decoder = InnerProductDecoder()
## 重构表达谱
if self.conv_type in ["GAT", "GATv2Conv"]:
self.GAT_decoder = GATDecoder(
in_channels=self.latent,
hidden_dims=self.hidden_dims,
out_channels=self.feature_dim,
conv_type=self.conv_type,
num_heads=self.num_heads,
dropout=self.dropout,
concat=self.concat,
num_domains=self.n_domain, # DSBN
used_edge_weight=self.used_edge_weight,
used_DSBN=self.used_DSBN,
)
elif self.conv_type == "GCN":
self.GCN_decoder = GCNDecoder(
in_channels=self.latent,
hidden_dims=self.hidden_dims,
out_channels=self.feature_dim,
dropout=self.dropout,
num_domains=self.n_domain, # DSBN
used_edge_weight=self.used_edge_weight,
used_DSBN=self.used_DSBN,
)
else:
raise NotImplementedError("Unknown GNN-Operator.")
def reparameterize(self, mu: Tensor, logstd: Tensor, eps=None) -> Tensor:
"""
Applies the reparameterization trick to sample a latent vector from the latent distribution during training.
Parameters
----------
mu : torch.Tensor
Mean of the latent distribution (shape: [batch_size, latent_dim]).
logstd : torch.Tensor
Logarithm of the standard deviation of the latent distribution (shape: [batch_size, latent_dim]).
eps : torch.Tensor, optional
Noise tensor used for sampling. If not provided, a standard normal distribution will be used (shape: [batch_size, latent_dim]).
Returns
-------
torch.Tensor
A reparameterized latent vector sampled from the distribution (shape: [batch_size, latent_dim]).
"""
if self.training:
if eps is not None:
return mu + eps * torch.randn_like(logstd) * torch.exp(logstd)
else:
return mu + torch.randn_like(logstd) * torch.exp(logstd)
else:
return mu
def forward(self, data_batch, decoder_type, augment_type):
"""
Processes the input data through the encoder to obtain the latent representations and uses the decoder to reconstruct features or edges, depending on the task.
Parameters
----------
data_batch : Data
A PyTorch Geometric Data object containing node features, edge information, and any other relevant data.
decoder_type : str
Specifies which type of decoder to use, either 'omics' for gene expression data or 'graph' for edge reconstruction tasks.
augment_type : str
Specifies the type of data augmentation used during encoding, e.g., 'svd' for singular value decomposition or 'dropout' for regular dropout.
Returns
-------
dict
A dictionary containing the following keys:
- "recon_features" or "edge_recon_logits": Reconstructed features or edge logits, depending on the decoder type.
- "z": The latent representation of the input data.
- "mu": The mean of the latent distribution.
- "logstd": The log standard deviation of the latent distribution.
- "truth_x": Ground truth input features (for omics tasks).
- "truth_y": Ground truth labels (for MMD or classification tasks).
"""
# Get index of sampled nodes for current batch (neighbors of sampled
# nodes are also part of the batch for message passing layers but
# should be excluded in backpropagation)
if decoder_type == "omics":
# ´data_batch´ will be a node batch and first node_batch_size
# elements are the sampled nodes, leading to a dim of ´batch_idx´ of
# ´node_batch_size´
# batch_idx = slice(None, data_batch.batch_size)
all_mu1 = []
all_mu2 = []
for _ in range(self.gnn_layer):
data_batch.x = to_float_tensor(data_batch.x)
encoder_outputs = self.encoder(data_batch, decoder_type, augment_type)
mu1 = encoder_outputs[0] # [batch_idx, :]
mu2 = encoder_outputs[2] # [batch_idx, :]
all_mu1.append(mu1)
all_mu2.append(mu2)
mean1 = torch.stack(all_mu1).mean(dim=0) # sum
logstd1 = torch.matmul(mean1, self.eps_weight) + self.eps_bias
z1 = self.reparameterize(mean1, logstd1)
mean2 = torch.stack(all_mu2).mean(dim=0)
logstd2 = torch.matmul(mean2, self.eps_weight) + self.eps_bias
z2 = self.reparameterize(mean2, logstd2)
z_1 = normalize(self.instance_projector(z1), dim=1)
z_2 = normalize(self.instance_projector(z2), dim=1)
c_1 = self.cluster_projector(z1)
c_2 = self.cluster_projector(z2)
## 重构邻接矩阵
pos_adj = self.adj_decoder(z1, data_batch.edge_index, sigmoid=True)
# Do not include self-loops in negative samples
pos_edge_index = data_batch.edge_index
pos_edge_index, _ = remove_self_loops(pos_edge_index)
pos_edge_index, _ = add_self_loops(pos_edge_index)
# negative_sampling
neg_edge_index = negative_sampling(pos_edge_index, z1.size(0))
neg_adj = self.adj_decoder(z1, neg_edge_index.long(), sigmoid=True)
## 重构表达矩阵
output = {}
# with torch.no_grad(): ## TODO
# data_batch = extract_subgraph(data_batch, batch_idx)
if self.conv_type in ["GAT", "GATv2Conv"]:
recon_features = self.GAT_decoder(z1, data_batch)
else:
recon_features = self.GCN_decoder(z1, data_batch)
output["truth_x"] = to_float_tensor(data_batch.x)
output["truth_y"] = data_batch.y
output["recon_features"] = to_float_tensor(recon_features)
output["z"] = z1
output["z_1"] = z_1
output["z_2"] = z_2
output["c_1"] = c_1
output["c_2"] = c_2
output["mu"] = mean1
output["logstd"] = logstd1
output["pos_adj"] = pos_adj
output["neg_adj"] = neg_adj
return output
elif decoder_type == "graph":
# ´data_batch´ will be an edge batch with sampled positive and
# negative edges of size ´edge_batch_size´ respectively. Each edge
# has a source and target node, leading to a dim of ´batch_idx´ of
# 4 * ´edge_batch_size´
batch_idx = torch.cat(
(data_batch.edge_label_index[0], data_batch.edge_label_index[1]), 0
)
all_mu1 = []
for _ in range(self.gnn_layer):
data_batch.x = to_float_tensor(data_batch.x)
encoder_outputs = self.encoder(data_batch, decoder_type, augment_type)
mu1 = encoder_outputs[0][batch_idx, :]
all_mu1.append(mu1)
mean1 = torch.stack(all_mu1).mean(dim=0) # sum
logstd1 = torch.matmul(mean1, self.eps_weight) + self.eps_bias
z1 = self.reparameterize(mean1, logstd1)
## 重构表达矩阵
output = {}
# Store edge labels in output for loss computation
output["edge_recon_labels"] = data_batch.edge_label
# Use decoder to get the edge reconstruction logits
output["edge_recon_logits"] = self.graph_decoder(z1)
output["z"] = z1
output["mu"] = mean1
output["logstd"] = logstd1
return output
def loss(
self,
edge_model_output: dict,
node_model_output: dict,
lambda_edge_recon,
lambda_gene_expr_recon,
lambda_latent_adj_recon_loss,
lambda_latent_contrastive_instanceloss,
lambda_latent_contrastive_clusterloss,
lambda_omics_recon_mmd_loss,
) -> dict:
"""
Computes the total loss for the model by combining different loss components such as KL divergence, edge reconstruction loss, gene expression reconstruction loss, and contrastive losses.
Parameters
----------
edge_model_output : dict
A dictionary containing outputs from the forward pass for edge reconstruction, including edge logits and latent variables.
node_model_output : dict
A dictionary containing outputs from the forward pass for node reconstruction, including gene expression reconstruction and latent variables.
lambda_edge_recon : float
A scaling factor to adjust the contribution of edge reconstruction loss.
lambda_gene_expr_recon : float
A scaling factor to adjust the contribution of gene expression reconstruction loss.
lambda_latent_adj_recon_loss : float
A scaling factor to adjust the contribution of adjacency reconstruction loss in the latent space.
lambda_latent_contrastive_instanceloss : float
A scaling factor to adjust the contribution of instance-level contrastive loss between different latent representations.
lambda_latent_contrastive_clusterloss : float
A scaling factor to adjust the contribution of cluster-level contrastive loss between different clusters in the latent space.
lambda_omics_recon_mmd_loss : float
A scaling factor to adjust the contribution of Maximum Mean Discrepancy (MMD) loss in omics data reconstruction.
Returns
-------
dict
A dictionary containing individual loss terms and the total loss:
- "kl_reg_loss": KL divergence loss between the latent distributions.
- "edge_recon_loss": Binary cross-entropy loss for edge reconstruction (if applicable).
- "gene_expr_recon_loss": Mean squared error loss for gene expression reconstruction (if applicable).
- "lambda_latent_contrastive_instanceloss": Contrastive loss between instance-level latent vectors.
- "lambda_latent_contrastive_clusterloss": Contrastive loss between clusters in the latent space.
- "gene_expr_mmd_loss": MMD loss for omics data (if applicable).
- "global_loss": Sum of all the individual losses used for model optimization.
- "optim_loss": Sum of the losses used for backpropagation.
"""
loss_dict = {}
# 1. Compute Kullback-Leibler divergence loss for edge and node batch
loss_dict["kl_reg_loss"] = compute_kl_reg_loss(
mu=node_model_output["mu"], logstd=node_model_output["logstd"]
) # * 1 / node_model_output["mu"].size(0)
loss_dict["kl_reg_loss"] += compute_kl_reg_loss(
mu=edge_model_output["mu"], logstd=edge_model_output["logstd"]
) # * 1 / edge_model_output["mu"].size(0)
# 2. Compute edge reconstruction binary cross entropy loss for edge batch
loss_dict["edge_recon_loss"] = (
(
lambda_edge_recon
* compute_edge_recon_loss(
edge_recon_logits=edge_model_output["edge_recon_logits"],
edge_recon_labels=edge_model_output["edge_recon_labels"],
)
)
* edge_model_output["mu"].size(0)
/ 10
)
# 3. Compute gene expression reconstruction with MSE loss for node batch
loss_dict["gene_expr_recon_loss"] = (
lambda_gene_expr_recon
* compute_omics_recon_mse_loss(
recon_x=node_model_output["recon_features"],
x=node_model_output["truth_x"],
)
) * 20000 # node_model_output['truth_x'].size(-1)
# 4. compute reconstructed adj loss through node feedforward
loss_dict["lambda_latent_adj_recon_loss"] = (
compute_adj_recon_loss(
node_model_output["pos_adj"],
node_model_output["neg_adj"],
lambda_latent_adj_recon_loss,
)
* 100
) # * node_model_output['truth_x'].size(-1)
# 5. compute Contrastive instance losses
loss_dict[
"lambda_latent_contrastive_instanceloss"
] = compute_contrastive_instanceloss(
node_model_output["z_1"],
node_model_output["z_2"],
lambda_latent_contrastive_instanceloss,
)
# 6. compute Contrastive cluster losses
loss_dict[
"lambda_latent_contrastive_clusterloss"
] = compute_contrastive_clusterloss(
node_model_output["c_1"],
node_model_output["c_2"],
self.cluster_num,
lambda_latent_contrastive_clusterloss,
)
# 7. compute MMD loss
if self.used_mmd:
cell_batch = node_model_output["truth_y"]
device = cell_batch.device
cell_batch = cell_batch.detach().cpu()
unique_groups, group_indices = np.unique(cell_batch, return_inverse=True)
grouped_z_cell = {
group: node_model_output["z"][group_indices == i]
for i, group in enumerate(unique_groups)
}
group_labels = list(unique_groups)
num_groups = len(group_labels)
loss_dict["gene_expr_mmd_loss"] = torch.tensor(0, dtype=torch.float).to(
device
)
for i in range(num_groups):
for j in range(i + 1, num_groups):
z_i = grouped_z_cell[group_labels[i]]
z_j = grouped_z_cell[group_labels[j]]
mmd_loss_tmp = compute_omics_recon_mmd_loss(
z_i, z_j
) * node_model_output["z"].size(0)
loss_dict["gene_expr_mmd_loss"] += (
mmd_loss_tmp * lambda_omics_recon_mmd_loss
)
# Compute optimization loss used for backpropagation as well as global
# loss used for early stopping of model training and best model saving
loss_dict["global_loss"] = 0
loss_dict["optim_loss"] = 0
loss_dict["global_loss"] += loss_dict["kl_reg_loss"]
loss_dict["optim_loss"] += loss_dict["kl_reg_loss"]
if self.include_edge_recon_loss:
loss_dict["global_loss"] += loss_dict["edge_recon_loss"]
loss_dict["optim_loss"] += loss_dict["edge_recon_loss"]
if self.include_gene_expr_recon_loss:
loss_dict["global_loss"] += loss_dict["gene_expr_recon_loss"]
loss_dict["optim_loss"] += loss_dict["gene_expr_recon_loss"]
loss_dict["global_loss"] += loss_dict["lambda_latent_adj_recon_loss"]
loss_dict["optim_loss"] += loss_dict["lambda_latent_adj_recon_loss"]
loss_dict["global_loss"] += loss_dict[
"lambda_latent_contrastive_instanceloss"
]
loss_dict["optim_loss"] += loss_dict[
"lambda_latent_contrastive_instanceloss"
]
loss_dict["global_loss"] += loss_dict[
"lambda_latent_contrastive_clusterloss"
]
loss_dict["optim_loss"] += loss_dict[
"lambda_latent_contrastive_clusterloss"
]
if self.used_mmd:
loss_dict["global_loss"] += loss_dict["gene_expr_mmd_loss"]
loss_dict["optim_loss"] += loss_dict["gene_expr_mmd_loss"]
return loss_dict
def get_latent_representation(
self,
node_batch: Data,
augment_type: Literal["svd", "dropout"] = "svd",
return_mu_std: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Encodes the input data into latent space, either returning the latent features (z) or
the distribution parameters (mu and std) based on the input option.
Parameters
----------
node_batch : Data
A PyTorch Geometric Data object containing features and graph structure for the node-level batch.
augment_type : str, optional
Specifies the type of augmentation used in the encoder, e.g., 'svd' (default) or 'dropout'.
return_mu_std : bool, optional
If True, the function returns the mean (mu) and standard deviation (std) of the latent distribution.
Otherwise, it returns the reparameterized latent features (z). Default is False.
Returns
-------
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
- If `return_mu_std` is False, it returns the reparameterized latent features (z) (shape: [batch_size, latent_dim]).
- If `return_mu_std` is True, it returns a tuple of the mean (mu) and standard deviation (std) of the latent distribution (each of shape: [batch_size, latent_dim]).
"""
# Get latent distribution parameters
encoder_outputs = self.encoder(
node_batch, augment_type=augment_type, decoder_type="omics"
) # z_mean1, z_log_std1, z_mean2, z_log_std2
mu = encoder_outputs[0][: node_batch.batch_size, :]
logstd = encoder_outputs[1][: node_batch.batch_size, :]
if return_mu_std:
std = torch.exp(logstd)
return mu, std
else:
z = self.reparameterize(mu, logstd)
return z