Source code for Garfield.model.Garfield

import warnings
from types import SimpleNamespace
from typing import Literal, List, Optional, Tuple, Union

import torch
import numpy as np
from anndata import AnnData

from .._settings import settings
from .basemodelmixin import BaseModelMixin
from ..data.dataprocessors import prepare_data
from ..preprocessing.preprocess import DataProcess
from ..data.dataloaders import initialize_dataloaders
from .utils import weighted_knn_trainer, weighted_knn_transfer
from ..nn.encoders import GATEncoder, GCNEncoder
from ..modules.GNNModelVAE import GNNModelVAE
from ..trainer.trainer import GarfieldTrainer


class Garfield(torch.nn.Module, BaseModelMixin):
    """
    Garfield: Graph-based Contrastive Learning enables Fast Single-Cell Embedding.

    The model is configured through a single ``gf_params`` dict whose keys are
    listed below. Keys missing from a user-supplied dict fall back to
    ``settings.gf_params``.

    Parameters
    ----------
    adata_list : list
        List of AnnData objects containing data from multiple batches or samples.
    profile : str
        Specifies the data profile type (e.g., 'RNA', 'ATAC', 'ADT',
        'multi-modal', 'spatial').
    data_type : str
        Type of the multi-omics dataset (e.g., Paired, UnPaired) for
        preprocessing.
    sub_data_type : list[str]
        List of data types for multi-modal datasets (e.g., ['rna', 'atac'] or
        ['rna', 'adt']).
    sample_col : str
        Column in the dataset that indicates batch or sample identifiers
        (default: 'batch').
    weight : float or None
        Weighting factor that determines the contribution of different
        modalities or types of graphs in multi-omics or spatial data.

        - For non-spatial single-cell multi-omics data (e.g., RNA + ATAC),
          `weight` specifies the contribution of the graph constructed from
          RNA data. The remaining (1 - weight) represents the contribution
          from the other modality.
        - For spatial single-modality data, `weight` refers to the
          contribution of the graph constructed from the physical spatial
          information, while (1 - weight) reflects the contribution from the
          molecular graph.
    graph_const_method : str
        Method for constructing the graph (e.g., 'mu_std', 'Radius', 'KNN',
        'Squidpy').
    genome : str
        Reference genome to use during preprocessing (e.g., 'mm10', 'mm9',
        'hg38', 'hg19').
    use_gene_weight : bool
        Whether to apply gene weights in the preprocessing step.
    user_cache_path : str or None
        Cache location forwarded unchanged to ``DataProcess``.
    use_top_pcs : bool
        Whether to use the top principal components during gene score
        preprocessing step.
    used_hvg : bool
        Whether to use highly variable genes (HVGs) for analysis.
    min_features : int
        Minimum number of features required for a cell to be included in the
        dataset.
    min_cells : int
        Minimum number of cells required for a feature to be retained in the
        dataset.
    keep_mt : bool
        Whether to retain mitochondrial genes in the analysis.
    target_sum : float
        Target sum used for normalization (e.g., 1e4 for counts per cell).
    rna_n_top_features : int
        Number of top features to retain for RNA datasets (e.g., 3000).
    atac_n_top_features : int
        Number of top features to retain for ATAC datasets (e.g., 10000).
    n_components : int
        Number of components to use for dimensionality reduction (e.g., PCA).
    n_neighbors : int
        Number of neighbors to use in graph-based algorithms (e.g., KNN).
    metric : str
        Distance metric used during graph construction (e.g., 'correlation',
        'euclidean').
    svd_solver : str
        Solver for singular value decomposition (SVD), such as 'arpack' or
        'randomized'.
    used_pca_feat : bool
        Whether to use PCA or LSI features (``adata.obsm["feat"]``) for the
        encoder instead of ``adata.X``.
    adj_key : str
        Key in the AnnData object that holds the adjacency matrix.
    edge_val_ratio : float
        Ratio of edges to use for validation in edge-level tasks.
    edge_test_ratio : float
        Ratio of edges to use for testing in edge-level tasks.
    node_val_ratio : float
        Ratio of nodes to use for validation in node-level tasks.
    node_test_ratio : float
        Ratio of nodes to use for testing in node-level tasks.
    augment_type : str
        Type of augmentation to use (e.g., 'dropout', 'svd').
    svd_q : int
        Rank for the low-rank SVD approximation (used when
        ``augment_type == 'svd'``).
    use_FCencoder : bool
        Whether to use a fully connected encoder before the graph layers.
    hidden_dims : list[int]
        List of hidden layer dimensions for the encoder.
    bottle_neck_neurons : int
        Number of neurons in the bottleneck (latent) layer.
    num_heads : int
        Number of attention heads for each graph attention layer.
    dropout : float
        Dropout rate applied during training.
    concat : bool
        Whether to concatenate attention heads or not.
    drop_feature_rate : float
        Dropout rate applied to node features.
    drop_edge_rate : float
        Dropout rate applied to edges during augmentation.
    used_edge_weight : bool
        Whether to use edge weights in the graph layers.
    used_DSBN : bool
        Whether to use domain-specific batch normalization.
    conv_type : str
        Type of graph convolution to use ('GATv2Conv', 'GAT', 'GCN'). Any
        value other than 'GAT' / 'GATv2Conv' builds a GCN encoder.
    gnn_layer : int
        Number of times the encoder is repeated in the forward pass, not the
        number of GNN layers.
    cluster_num : int
        Number of clusters for latent feature clustering.
    num_neighbors : int
        Number of neighbors to sample for graph-based data loaders.
    loaders_n_hops : int
        Number of hops for neighbors during graph construction.
    edge_batch_size : int
        Batch size for edge-level tasks.
    node_batch_size : int
        Batch size for node-level tasks.
    include_edge_recon_loss : bool
        Whether to include edge reconstruction loss in the training objective.
    include_gene_expr_recon_loss : bool
        Whether to include gene expression reconstruction loss in the
        training objective.
    used_mmd : bool
        Whether to use maximum mean discrepancy (MMD) for domain adaptation.
    lambda_latent_contrastive_instanceloss : float
        Weight for the instance-level contrastive loss.
    lambda_latent_contrastive_clusterloss : float
        Weight for the cluster-level contrastive loss.
    lambda_gene_expr_recon : float
        Weight for the gene expression reconstruction loss.
    lambda_latent_adj_recon_loss : float
        Weight for the adjacency reconstruction loss.
    lambda_edge_recon : float
        Weight for the edge reconstruction loss.
    lambda_omics_recon_mmd_loss : float
        Weight for the MMD loss in omics reconstruction tasks.
    n_epochs : int
        Number of training epochs.
    n_epochs_no_edge_recon : int
        Number of epochs without edge reconstruction loss.
    learning_rate : float
        Learning rate for the optimizer.
    weight_decay : float
        Weight decay (L2 regularization) for the optimizer.
    gradient_clipping : float
        Maximum norm for gradient clipping.
    latent_key : str
        Key for storing latent features in the AnnData object (``obsm``).
    reload_best_model : bool
        Whether to reload the best model after training.
    use_early_stopping : bool
        Whether to use early stopping during training.
    early_stopping_kwargs : dict
        Arguments for configuring early stopping (e.g., patience, delta).
    monitor : bool
        Whether to print training progress.
    device_id : int
        Device ID for GPU training.
    seed : int
        Random seed for reproducibility.
    verbose : bool
        Whether to display detailed logs during training.
    """

    # Every configuration key. ``__init__`` stores each one on the instance
    # as ``<name>_`` and ``_get_init_params_from_dict`` reads them back, so
    # keeping a single list guarantees the two stay in sync.
    _INIT_PARAM_NAMES = (
        # Preprocessing options
        "adata_list",
        "profile",
        "data_type",
        "sub_data_type",
        "sample_col",
        "weight",
        "graph_const_method",
        "genome",
        "use_gene_weight",
        "user_cache_path",
        "use_top_pcs",
        "used_hvg",
        "min_features",
        "min_cells",
        "keep_mt",
        "target_sum",
        "rna_n_top_features",
        "atac_n_top_features",
        "n_components",
        "n_neighbors",
        "metric",
        "svd_solver",
        "used_pca_feat",
        "adj_key",
        # data split parameters
        "edge_val_ratio",
        "edge_test_ratio",
        "node_val_ratio",
        "node_test_ratio",
        # model parameters
        "augment_type",
        "svd_q",
        "use_FCencoder",
        "gnn_layer",
        "conv_type",
        "hidden_dims",
        "bottle_neck_neurons",
        "cluster_num",
        "num_heads",
        "dropout",
        "concat",
        "drop_feature_rate",
        "drop_edge_rate",
        "used_edge_weight",
        "used_DSBN",
        "used_mmd",
        # data loader parameters
        "num_neighbors",
        "loaders_n_hops",
        "edge_batch_size",
        "node_batch_size",
        # loss parameters
        "include_edge_recon_loss",
        "include_gene_expr_recon_loss",
        "lambda_latent_contrastive_instanceloss",
        "lambda_latent_contrastive_clusterloss",
        "lambda_gene_expr_recon",
        "lambda_latent_adj_recon_loss",
        "lambda_edge_recon",
        "lambda_omics_recon_mmd_loss",
        # train parameters
        "n_epochs",
        "n_epochs_no_edge_recon",
        "learning_rate",
        "weight_decay",
        "gradient_clipping",
        # other parameters
        "latent_key",
        "reload_best_model",
        "use_early_stopping",
        "early_stopping_kwargs",
        "monitor",
        "device_id",
        "seed",
        "verbose",
    )

    def __init__(self, gf_params: Optional[dict] = None):
        """
        Preprocess the data and build the underlying GNN model.

        Parameters
        ----------
        gf_params : dict or None
            Configuration dict (keys documented on the class). If ``None``, a
            copy of ``settings.gf_params`` is used; otherwise the given keys
            override those defaults, so a partial dict is accepted.

        Raises
        ------
        TypeError
            If ``gf_params`` is neither ``None`` nor a dict.
        """
        super().__init__()
        if gf_params is None:
            gf_params = settings.gf_params.copy()
        elif isinstance(gf_params, dict):
            # Overlay the user-supplied keys on the package defaults.
            gf_params = {**settings.gf_params, **gf_params}
        else:
            # Explicit raise instead of ``assert`` so the check survives -O.
            raise TypeError("`gf_params` must be dict")
        self.args = SimpleNamespace(**gf_params)

        # Mirror each configuration key onto the instance as ``<name>_``.
        for name in self._INIT_PARAM_NAMES:
            setattr(self, f"{name}_", getattr(self.args, name))

        # Set seed for reproducibility. ``torch.manual_seed`` seeds every
        # device (CPU and all CUDA devices), so no CUDA-specific branch is
        # needed.
        np.random.seed(self.seed_)
        torch.manual_seed(self.seed_)

        # Data load and preprocessing
        print("--- DATA LOADING AND PREPROCESSING ---")
        self.adata = DataProcess(
            adata_list=self.adata_list_,
            profile=self.profile_,
            data_type=self.data_type_,
            sub_data_type=self.sub_data_type_,
            sample_col=self.sample_col_,
            genome=self.genome_,
            weight=self.weight_,
            graph_const_method=self.graph_const_method_,
            use_gene_weight=self.use_gene_weight_,
            user_cache_path=self.user_cache_path_,
            use_top_pcs=self.use_top_pcs_,
            used_hvg=self.used_hvg_,
            min_features=self.min_features_,
            min_cells=self.min_cells_,
            keep_mt=self.keep_mt_,
            target_sum=self.target_sum_,
            rna_n_top_features=self.rna_n_top_features_,
            atac_n_top_features=self.atac_n_top_features_,
            n_components=self.n_components_,
            n_neighbors=self.n_neighbors_,
            metric=self.metric_,
            svd_solver=self.svd_solver_,
        )

        # Input dimensionality: raw features, or the reduced features in
        # ``obsm["feat"]`` when ``used_pca_feat`` is set.
        if self.used_pca_feat_:
            self.num_features_ = self.adata.obsm["feat"].shape[1]
        else:
            self.num_features_ = self.adata.n_vars

        # Number of distinct batches/samples, if a sample column is given.
        if self.sample_col_ is None:
            self.n_domain_ = None
        else:
            try:
                self.n_domain_ = len(self.adata.obs[self.sample_col_].unique())
            except KeyError:
                # Presumably the column is renamed with an "rna:" prefix when
                # modalities are merged during preprocessing — TODO confirm.
                self.n_domain_ = len(
                    self.adata.obs["rna:" + self.sample_col_].unique()
                )

        self.setup_layers()

    def setup_layers(self):
        """
        Create the encoder and wrap it in ``GNNModelVAE`` (``self.model``).

        A ``GATEncoder`` is built for ``conv_type`` 'GAT' / 'GATv2Conv', a
        ``GCNEncoder`` for anything else. Resets ``is_trained_`` to ``False``.
        """
        # Arguments shared by both encoder types.
        encoder_kwargs = dict(
            in_channels=self.num_features_,
            hidden_dims=self.hidden_dims_,
            latent_dim=self.bottle_neck_neurons_,
            use_FCencoder=self.use_FCencoder_,
            dropout=self.dropout_,
            # Always 1 for the encoder, independent of ``n_domain_``
            # (original note: "batch normalization").
            num_domains=1,
            drop_feature_rate=self.drop_feature_rate_,
            drop_edge_rate=self.drop_edge_rate_,
            used_edge_weight=self.used_edge_weight_,
            svd_q=self.svd_q_,
            used_DSBN=self.used_DSBN_,
        )
        if self.conv_type_ in ("GAT", "GATv2Conv"):
            # Attention encoders additionally take heads / concat / conv type.
            encoder = GATEncoder(
                conv_type=self.conv_type_,
                num_heads=self.num_heads_,
                concat=self.concat_,
                **encoder_kwargs,
            )
        else:
            encoder = GCNEncoder(**encoder_kwargs)

        self.model = GNNModelVAE(
            encoder=encoder,
            bottle_neck_neurons=self.bottle_neck_neurons_,
            hidden_dims=self.hidden_dims_,
            feature_dim=self.num_features_,
            num_heads=self.num_heads_,
            dropout=self.dropout_,
            concat=self.concat_,
            n_domain=self.n_domain_,
            used_edge_weight=self.used_edge_weight_,
            used_DSBN=self.used_DSBN_,
            conv_type=self.conv_type_,
            gnn_layer=self.gnn_layer_,
            cluster_num=self.cluster_num_,
            include_edge_recon_loss=self.include_edge_recon_loss_,
            include_gene_expr_recon_loss=self.include_gene_expr_recon_loss_,
            used_mmd=self.used_mmd_,
        )
        self.is_trained_ = False

    def train(self, mode: Optional[bool] = None, **trainer_kwargs):
        """
        Train the model and store latent means in ``adata.obsm[latent_key]``.

        Parameters
        ----------
        mode : bool or None
            Only for ``torch.nn.Module`` compatibility. ``Module.eval()``
            calls ``self.train(False)``; when ``mode`` is given, this method
            just toggles training mode via ``torch.nn.Module.train`` and
            returns ``self`` without running a training loop.
        **trainer_kwargs
            Extra keyword arguments forwarded to ``GarfieldTrainer``.
        """
        if mode is not None:
            return super().train(mode)

        self.trainer = GarfieldTrainer(
            adata=self.adata,
            model=self.model,
            label_name=self.sample_col_,
            used_pca_feat=self.used_pca_feat_,
            adj_key=self.adj_key_,
            # data split
            edge_val_ratio=self.edge_val_ratio_,
            edge_test_ratio=self.edge_test_ratio_,
            node_val_ratio=self.node_val_ratio_,
            node_test_ratio=self.node_test_ratio_,
            # data process
            augment_type=self.augment_type_,
            # data loader
            num_neighbors=self.num_neighbors_,
            loaders_n_hops=self.loaders_n_hops_,
            edge_batch_size=self.edge_batch_size_,
            node_batch_size=self.node_batch_size_,
            # other parameters
            reload_best_model=self.reload_best_model_,
            use_early_stopping=self.use_early_stopping_,
            early_stopping_kwargs=self.early_stopping_kwargs_,
            monitor=self.monitor_,
            device_id=self.device_id_,
            verbose=self.verbose_,
            seed=self.seed_,
            **trainer_kwargs,
        )

        self.trainer.train(
            n_epochs=self.n_epochs_,
            # Previously this was (wrongly) fed ``lambda_edge_recon_``.
            n_epochs_no_edge_recon=self.n_epochs_no_edge_recon_,
            learning_rate=self.learning_rate_,
            weight_decay=self.weight_decay_,
            gradient_clipping=self.gradient_clipping_,
            lambda_edge_recon=self.lambda_edge_recon_,
            lambda_gene_expr_recon=self.lambda_gene_expr_recon_,
            lambda_latent_adj_recon_loss=self.lambda_latent_adj_recon_loss_,
            lambda_latent_contrastive_instanceloss=self.lambda_latent_contrastive_instanceloss_,
            lambda_latent_contrastive_clusterloss=self.lambda_latent_contrastive_clusterloss_,
            lambda_omics_recon_mmd_loss=self.lambda_omics_recon_mmd_loss_,
        )

        # Reuse the trainer's node batch size for embedding (the trainer
        # presumably may adjust it).
        self.node_batch_size_ = self.trainer.node_batch_size_
        self.is_trained_ = True
        self.model.eval()
        # Store the posterior means (``mu``) as the embedding; ``std`` is
        # discarded.
        self.adata.obsm[self.latent_key_], _ = self.get_latent_representation(
            adata=self.adata,
            adj_key=self.adj_key_,
            return_mu_std=True,
            node_batch_size=self.node_batch_size_,
        )

    # embedding
    def get_latent_representation(
        self,
        adata: Optional[AnnData] = None,
        adj_key: str = "connectivities",
        return_mu_std: bool = False,
        node_batch_size: int = 64,
        dtype: type = np.float64,
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Get the latent representation from a trained model.

        Parameters
        ----------
        adata : AnnData or None
            Object to embed. If ``None``, uses ``self.adata``.
        adj_key : str
            Key under which the sparse adjacency matrix is stored in
            ``adata.obsp``.
        return_mu_std : bool
            If ``True``, return ``(mu, std)`` instead of latent features ``z``.
        node_batch_size : int
            Batch size used during data loading.
        dtype : type
            NumPy dtype of the returned arrays.

        Returns
        -------
        z : np.ndarray
            Latent features, shape ``(n_obs, bottle_neck_neurons)``; returned
            when ``return_mu_std`` is ``False``.
        (mu, std) : tuple of np.ndarray
            Latent posterior means and standard deviations, each of shape
            ``(n_obs, bottle_neck_neurons)``; returned when ``return_mu_std``
            is ``True``.
        """
        self._check_if_trained(warn=False)

        device = next(self.model.parameters()).device
        if adata is None:
            adata = self.adata

        # Create single dataloader containing entire dataset (no splits).
        data_dict = prepare_data(
            adata=adata,
            adj_key=adj_key,
            used_pca_feat=self.used_pca_feat_,
            edge_val_ratio=0.0,
            edge_test_ratio=0.0,
            node_val_ratio=0.0,
            node_test_ratio=0.0,
        )
        loader_dict = initialize_dataloaders(
            node_masked_data=data_dict["node_masked_data"],
            edge_train_data=None,
            edge_val_data=None,
            edge_batch_size=None,
            node_batch_size=node_batch_size,
            shuffle=False,  # keep row order aligned with ``adata``
        )
        node_loader = loader_dict["node_train_loader"]

        n_obs = adata.shape[0]
        out_shape = (n_obs, self.bottle_neck_neurons_)
        if return_mu_std:
            mu = np.empty(shape=out_shape, dtype=dtype)
            std = np.empty(shape=out_shape, dtype=dtype)
        else:
            z = np.empty(shape=out_shape, dtype=dtype)

        # Fill the output row by row; the offset accumulates the actual size
        # of each batch instead of assuming every batch is full.
        start = 0
        with torch.no_grad():  # inference only: no autograd graph needed
            for node_batch in node_loader:
                end = start + node_batch.batch_size
                node_batch = node_batch.to(device)
                if return_mu_std:
                    mu_batch, std_batch = self.model.get_latent_representation(
                        node_batch=node_batch,
                        augment_type=self.augment_type_,
                        return_mu_std=True,
                    )
                    mu[start:end, :] = mu_batch.detach().cpu().numpy()
                    std[start:end, :] = std_batch.detach().cpu().numpy()
                else:
                    z_batch = self.model.get_latent_representation(
                        node_batch=node_batch,
                        augment_type=self.augment_type_,
                        return_mu_std=False,
                    )
                    z[start:end, :] = z_batch.detach().cpu().numpy()
                start = end

        if start != n_obs:
            warnings.warn(
                f"Latent representation filled {start} of {n_obs} rows; "
                "the remaining rows are uninitialized."
            )

        if return_mu_std:
            return mu, std
        return z

    # Loss curve
    def plot_loss_curves(self, title="Losses Curve"):
        """Delegate to the trainer's loss-curve plot (requires ``train()``)."""
        return self.trainer.plot_loss_curves(title=title)

    @classmethod
    def _get_init_params_from_dict(cls, dct):
        """
        Rebuild the ``gf_params`` dict from stored ``<name>_`` attributes.

        Parameters
        ----------
        dct : mapping
            Attribute mapping containing a ``<name>_`` entry for every name in
            ``_INIT_PARAM_NAMES``.

        Returns
        -------
        dict
            ``{name: dct[name + "_"]}`` for every configuration key.

        Raises
        ------
        KeyError
            If any expected ``<name>_`` entry is missing.
        """
        return {name: dct[f"{name}_"] for name in cls._INIT_PARAM_NAMES}

    def label_transfer(
        self,
        ref_adata,
        ref_adata_emb,
        query_adata,
        query_adata_emb,
        ref_adata_obs,
        label_keys,
        n_neighbors=50,
        threshold=1,
        pred_unknown=False,
        mode="package",
    ):
        """
        Transfer labels from a reference to a query via weighted KNN.

        Parameters
        ----------
        ref_adata, query_adata : AnnData
            Reference and query objects.
        ref_adata_emb, query_adata_emb
            Embedding locations forwarded to ``weighted_knn_trainer`` /
            ``weighted_knn_transfer`` (presumably an ``obsm`` key or array —
            TODO confirm against those helpers).
        ref_adata_obs : pandas.DataFrame
            Reference annotations; columns whose names start with
            ``label_keys`` are transferred.
        label_keys : str or sequence of str
            Column-name prefix(es) selecting the labels to transfer.
        n_neighbors : int
            Number of neighbors for the KNN model.
        threshold, pred_unknown, mode
            Forwarded to ``weighted_knn_transfer``. ``pred_unknown`` also
            selects the "_filtered" vs "_unfiltered" column suffix.

        Returns
        -------
        AnnData
            ``query_adata`` with ``transferred_<col>_{filtered|unfiltered}``
            and ``transferred_<col>_uncert`` columns added to ``obs``.
            Existing columns with those names are overwritten, and columns
            that are entirely NA are dropped from ``obs``.
        """
        knn_transformer = weighted_knn_trainer(
            train_adata=ref_adata,
            train_adata_emb=ref_adata_emb,  # location of the joint embedding
            n_neighbors=n_neighbors,
        )

        labels, uncert = weighted_knn_transfer(
            query_adata=query_adata,
            query_adata_emb=query_adata_emb,  # location of the query embedding
            label_keys=label_keys,  # (start of) obs column name(s) to transfer
            knn_model=knn_transformer,
            ref_adata_obs=ref_adata_obs,
            threshold=threshold,
            pred_unknown=pred_unknown,
            mode=mode,
        )

        # Define the column-name mappings. ``str.startswith`` accepts a str
        # or a tuple of str, so normalize list-like prefixes to a tuple.
        prefixes = label_keys if isinstance(label_keys, str) else tuple(label_keys)
        cols = ref_adata_obs.columns[ref_adata_obs.columns.str.startswith(prefixes)]
        label_suffix = "filtered" if pred_unknown else "unfiltered"
        rename_mapping_labels = {
            col: f"transferred_{col}_{label_suffix}" for col in cols
        }
        # Define the uncertainty mapping
        rename_mapping_uncert = {col: f"transferred_{col}_uncert" for col in cols}

        new_labels = labels.rename(columns=rename_mapping_labels)
        new_uncert = uncert.rename(columns=rename_mapping_uncert)

        # Drop columns left over from a previous transfer so ``join`` does
        # not fail on overlapping column names when this is re-run.
        incoming = set(new_labels.columns) | set(new_uncert.columns)
        stale = [col for col in query_adata.obs.columns if col in incoming]
        obs = query_adata.obs.drop(columns=stale)

        # Rename columns and add them to 'query_adata.obs'
        query_adata.obs = obs.join(new_labels).join(new_uncert)

        # Remove columns that are entirely NA from query_adata.obs
        query_adata.obs = query_adata.obs.dropna(axis=1, how="all")
        return query_adata