Garfield: Graph-based Contrastive Learning enable Fast Single-Cell Embedding

API

Import Garfield as:

import Garfield

Configuration for Garfield

settings.set_figure_params([context, style, ...])

Set global parameters for figures.

settings.set_gf_params([config])

Set Garfield parameters.

settings.set_workdir([workdir])

Set working directory.

Reading

data.initialize_dataloaders(node_masked_data)

Initialize edge-level and node-level training and validation dataloaders.

data.edge_level_split(data, edge_label_adj)

Split a PyG Data object into training, validation and test PyG Data objects using an edge-level split.
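Here is a minimal NumPy sketch of what an edge-level split does. The columns of a 2 x E edge_index array are shuffled and partitioned into training, validation and test edges. The function name split_edges and its ratio parameters are hypothetical and are not Garfield's signature. A full link-prediction split, such as PyG's RandomLinkSplit transform, also handles undirected edge pairs and negative edges, which this sketch ignores.

```python
import numpy as np


def split_edges(edge_index, val_ratio=0.1, test_ratio=0.05, seed=0):
    """Randomly partition the columns of a 2 x E edge_index array into
    training, validation and test edge sets (illustrative sketch only)."""
    rng = np.random.default_rng(seed)
    n_edges = edge_index.shape[1]
    perm = rng.permutation(n_edges)  # shuffle edge positions once
    n_val = int(round(val_ratio * n_edges))
    n_test = int(round(test_ratio * n_edges))
    val_idx = perm[:n_val]
    test_idx = perm[n_val:n_val + n_test]
    train_idx = perm[n_val + n_test:]  # everything left over is training
    return edge_index[:, train_idx], edge_index[:, val_idx], edge_index[:, test_idx]
```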

data.node_level_split_mask(data[, ...])

Split data at the node level into training, validation and test sets by adding node-level masks (train_mask, val_mask, test_mask) to the PyG Data object.
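The node-level counterpart can be sketched the same way. Node indices are shuffled once, and disjoint slices become boolean train, validation and test masks. This is an illustrative NumPy sketch. The function name and ratio parameters are hypothetical, and the ratios Garfield uses by default are not stated here.

```python
import numpy as np


def node_level_masks(n_nodes, val_ratio=0.1, test_ratio=0.0, seed=0):
    """Return disjoint boolean train/val/test masks over n_nodes nodes."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_nodes)
    n_val = int(round(val_ratio * n_nodes))
    n_test = int(round(test_ratio * n_nodes))
    train_mask = np.zeros(n_nodes, dtype=bool)
    val_mask = np.zeros(n_nodes, dtype=bool)
    test_mask = np.zeros(n_nodes, dtype=bool)
    val_mask[perm[:n_val]] = True
    test_mask[perm[n_val:n_val + n_test]] = True
    train_mask[perm[n_val + n_test:]] = True  # remaining nodes are for training
    return train_mask, val_mask, test_mask
```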

data.prepare_data(adata[, label_name, ...])

Prepares the dataset for training and evaluation by performing node-level and edge-level splits and returns a dictionary containing the processed data.

data.GraphAnnTorchDataset(adata[, ...])

Spatially annotated torch dataset class to extract node features, node labels, adjacency matrix and edge indices in a standardized format from an AnnData object.

See the anndata documentation for more on AnnData objects.

Preprocessing

preprocessing.gene_scores(adata, genome[, ...])

Calculate gene scores from scATAC-seq data.

preprocessing.get_nearest_neighbors(...[, ...])

For each row in query_arr, compute its nearest neighbor in target_arr.
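Below is a brute-force sketch of the nearest-neighbor lookup. It assumes Euclidean distance and returns both the index of each match and its distance. The metric and return values of Garfield's own function are not documented here. Memory grows as n_query x n_target, so a KD-tree or an approximate index would be preferable for large data.

```python
import numpy as np


def nearest_neighbors(query_arr, target_arr):
    """For each row of query_arr, return the index of its nearest row in
    target_arr and the Euclidean distance to it (brute force)."""
    # Pairwise squared distances via broadcasting: shape (n_query, n_target)
    diff = query_arr[:, None, :] - target_arr[None, :, :]
    sq_dist = np.sum(diff ** 2, axis=-1)
    idx = np.argmin(sq_dist, axis=1)
    dist = np.sqrt(sq_dist[np.arange(len(query_arr)), idx])
    return idx, dist
```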

preprocessing.preprocessing_rna(adata[, ...])

Preprocess single-cell RNA-seq data.

preprocessing.preprocessing_atac(adata[, ...])

Preprocess a scATAC-seq data matrix.

preprocessing.preprocessing_adt(adata[, ...])

Preprocess single-cell ADT (antibody-derived tag) data.

preprocessing.preprocessing(adata[, ...])

Preprocessing function for single-cell and multi-modal data.

preprocessing.DataProcess(adata_list, profile)

Processes single or multi-modal data (e.g., RNA, ATAC, ADT, spatial) with optional preprocessing steps such as normalization, feature selection, and dimensionality reduction.

Model

model.utils.weighted_knn_trainer(...[, ...])

Trains a weighted KNN classifier on train_adata.

model.utils.weighted_knn_transfer(...[, ...])

Annotate query_adata cells using a trained weighted KNN classifier.
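The two weighted KNN helpers can be pictured as follows. Each query cell looks up its k nearest reference cells in a shared embedding. The neighbors then vote for their labels, with weights that shrink as distance grows, and one minus the winning vote share serves as an uncertainty score. This NumPy sketch folds training and transfer into one hypothetical function, knn_label_transfer. It assumes inverse-distance weights; the kernel that Garfield's model.utils functions actually use is not specified here.

```python
import numpy as np


def knn_label_transfer(ref_emb, ref_labels, query_emb, k=5, eps=1e-8):
    """Predict a label and an uncertainty for each query row by an
    inverse-distance-weighted vote among its k nearest reference rows."""
    ref_labels = np.asarray(ref_labels)
    classes = np.unique(ref_labels)
    # (n_query, n_ref) Euclidean distances in the shared embedding
    dist = np.linalg.norm(query_emb[:, None, :] - ref_emb[None, :, :], axis=-1)
    knn_idx = np.argsort(dist, axis=1)[:, :k]
    preds, uncertainty = [], []
    for i, nbrs in enumerate(knn_idx):
        w = 1.0 / (dist[i, nbrs] + eps)  # closer neighbors get larger weights
        w = w / w.sum()
        nbr_labels = ref_labels[nbrs]
        votes = np.array([w[nbr_labels == c].sum() for c in classes])
        best = np.argmax(votes)
        preds.append(classes[best])
        uncertainty.append(1.0 - votes[best])  # 0 when all neighbors agree
    return np.array(preds), np.array(uncertainty)
```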

model.Garfield.Garfield(gf_params)

Garfield: Graph-based Contrastive Learning enable Fast Single-Cell Embedding

Loss

modules.compute_omics_recon_mse_loss(recon_x, x)

Compute the MSE loss between reconstructed data and ground-truth data.

modules.compute_edge_recon_loss(...[, edge_incl])

Compute edge reconstruction weighted binary cross entropy loss with logits using ground truth edge labels and predicted edge logits.
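Edge reconstruction is a binary classification over node pairs, so the loss can be sketched as binary cross entropy computed directly from logits, with a pos_weight factor on positive edges. The weighting follows the pos_weight convention of PyTorch's binary_cross_entropy_with_logits. Whether Garfield weights edges this way, and how it would set the weight (for example, from the ratio of negative to positive edges), is an assumption. The log-sigmoid terms go through np.logaddexp to avoid overflow for large logits.

```python
import numpy as np


def weighted_bce_with_logits(logits, labels, pos_weight=1.0):
    """Mean of -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    log_sig = -np.logaddexp(0.0, -logits)           # log(sigmoid(x)), stable
    log_one_minus_sig = -np.logaddexp(0.0, logits)  # log(1 - sigmoid(x)), stable
    loss = -(pos_weight * labels * log_sig + (1.0 - labels) * log_one_minus_sig)
    return loss.mean()
```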

modules.compute_kl_reg_loss(mu, logstd)

Compute Kullback-Leibler divergence as per Kingma, D.
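Consider a Gaussian posterior with mean mu and standard deviation exp(logstd), measured against a standard normal prior. Its KL divergence has the closed form -1/2 * sum(1 + 2*logstd - mu^2 - exp(logstd)^2), summed over latent dimensions. The sketch below averages this quantity over nodes, which is the convention used by PyTorch Geometric's VGAE.kl_loss. Garfield's exact scaling may differ.

```python
import numpy as np


def kl_reg_loss(mu, logstd):
    """KL(N(mu, exp(logstd)^2) || N(0, 1)), summed over latent dims, averaged over nodes."""
    mu = np.asarray(mu, dtype=float)
    logstd = np.asarray(logstd, dtype=float)
    kl_per_node = -0.5 * np.sum(1.0 + 2.0 * logstd - mu ** 2 - np.exp(2.0 * logstd), axis=1)
    return kl_per_node.mean()
```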

modules.compute_contrastive_instanceloss(...)

Compute the contrastive loss given two batches of feature vectors z_i and z_j.
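Instance-level contrastive losses of this kind are commonly implemented as NT-Xent (normalized temperature-scaled cross entropy). Row k of z_i and row k of z_j form a positive pair, and the remaining 2n - 2 samples in the combined batch act as negatives. The NumPy sketch below assumes cosine similarity and a temperature of 0.5; Garfield's own temperature and reduction are not stated here.

```python
import numpy as np


def nt_xent_loss(z_i, z_j, temperature=0.5):
    """NT-Xent loss between two views z_i, z_j of shape (n, d)."""
    n = z_i.shape[0]
    z = np.concatenate([z_i, z_j], axis=0).astype(float)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit rows -> cosine similarity
    sim = z @ z.T / temperature                       # (2n, 2n) similarity logits
    np.fill_diagonal(sim, -np.inf)                    # a sample is never its own candidate
    # Row-wise log-sum-exp, shifted by the row max for numerical stability
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max + np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    # The positive for row k is k + n (first view) or k - n (second view)
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    log_prob = sim - log_denom
    return -log_prob[np.arange(2 * n), pos].mean()
```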

modules.compute_contrastive_clusterloss(c_i, ...)

Compute the cluster-level contrastive loss.
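The cluster-level counterpart can be sketched by assuming that c_i and c_j are (batch, n_clusters) soft assignment matrices for two augmented views. Each cluster column is contrasted with the same column in the other view, and an entropy term discourages collapsing all cells into a few clusters. This follows the common contrastive-clustering formulation. The helper names, the temperature and the second argument c_j are assumptions and do not come from Garfield's signature.

```python
import numpy as np


def _nt_xent(a, b, temperature):
    """NT-Xent loss where row k of a and row k of b are the positive pair."""
    n = a.shape[0]
    x = np.concatenate([a, b], axis=0).astype(float)
    x = x / np.linalg.norm(x, axis=1, keepdims=True)
    sim = x @ x.T / temperature
    np.fill_diagonal(sim, -np.inf)
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max + np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    return -(sim - log_denom)[np.arange(2 * n), pos].mean()


def _neg_entropy(c, eps=1e-12):
    """log(K) minus the entropy of the average cluster assignment (0 when uniform)."""
    p = c.sum(axis=0) / c.sum()
    return np.log(c.shape[1]) + np.sum(p * np.log(p + eps))


def cluster_contrastive_loss(c_i, c_j, temperature=1.0):
    """c_i, c_j: (batch, n_clusters) soft cluster assignments for two views."""
    # Each cluster (a column) is treated as an instance and contrasted across views
    contrast = _nt_xent(c_i.T, c_j.T, temperature)
    # Entropy terms penalize collapsing all cells into a few clusters
    return contrast + _neg_entropy(c_i) + _neg_entropy(c_j)
```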

modules.compute_omics_recon_mmd_loss(...)

Compute the Maximum Mean Discrepancy (MMD) between source_features and target_features.
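Squared MMD compares two samples through kernel means: E[k(s, s')] + E[k(t, t')] - 2 E[k(s, t)]. The value is zero when the source and target distributions match. The sketch uses a single Gaussian kernel with bandwidth sigma and the biased estimator. Garfield may use several kernels or a different bandwidth, so treat the kernel choice as an assumption.

```python
import numpy as np


def gaussian_kernel(x, y, sigma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of x and the rows of y."""
    sq_dist = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd_loss(source_features, target_features, sigma=1.0):
    """Biased estimate of squared MMD with a single Gaussian kernel."""
    k_ss = gaussian_kernel(source_features, source_features, sigma).mean()
    k_tt = gaussian_kernel(target_features, target_features, sigma).mean()
    k_st = gaussian_kernel(source_features, target_features, sigma).mean()
    return k_ss + k_tt - 2.0 * k_st
```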

Modules

modules.GNNModelVAE(encoder, ...[, ...])

Garfield model class.

NN

nn.GATEncoder(in_channels, hidden_dims, ...)

The GATEncoder class implements a Graph Attention Network (GAT) encoder with multiple layers, normalization, and optional fully connected (FC) encoder.

nn.GCNEncoder(in_channels, hidden_dims, ...)

The GCNEncoder class implements a Graph Convolutional Network (GCN) encoder with multiple layers, normalization, and optional fully connected (FC) encoder.

nn.GATDecoder(in_channels, hidden_dims, ...)

Graph Attention Network (GAT) Decoder class.

nn.GCNDecoder(in_channels, hidden_dims, ...)

Graph Convolutional Network (GCN) Decoder class.

nn.DSBatchNorm(num_features, n_domain[, ...])

Domain-specific batch normalization.
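Domain-specific batch normalization keeps separate normalization statistics for each domain (for example, each batch or sample). Each domain is therefore centered and scaled on its own before the shared layers. The NumPy sketch below shows only training-mode normalization. A real layer, such as a PyTorch module holding one BatchNorm1d per domain, would typically also learn affine parameters and track running statistics. The function name and arguments are hypothetical.

```python
import numpy as np


def domain_specific_batch_norm(x, domain, n_domain, eps=1e-5):
    """Normalize each domain's rows with that domain's own mean and variance
    (training-mode statistics only; no affine parameters or running averages)."""
    out = np.empty_like(x, dtype=float)
    for d in range(n_domain):
        rows = domain == d
        if not np.any(rows):
            continue  # this domain is absent from the current batch
        xd = x[rows]
        out[rows] = (xd - xd.mean(axis=0)) / np.sqrt(xd.var(axis=0) + eps)
    return out
```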

Trainer

trainer.GarfieldTrainer(adata, model, ...)

Trainer class that handles data preparation, model initialization, and training of the Garfield model.

trainer.eval_metrics(edge_recon_probs, ...)

Get the evaluation metrics for a (balanced) sample of positive and negative edges and a sample of nodes.

trainer.plot_eval_metrics(eval_dict)

Plot evaluation metrics.

Tools

trainer.EarlyStopping([...])

EarlyStopping class for early stopping of Garfield training.
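The logic behind early stopping fits in a few lines. The class tracks the best validation loss seen so far and stops once that loss has failed to improve by more than min_delta for patience consecutive epochs. This standalone class is a hypothetical sketch: SimpleEarlyStopping and its step method are not Garfield's API, and the parameters of trainer.EarlyStopping itself are not listed here.

```python
class SimpleEarlyStopping:
    """Signal a stop when the monitored loss has not improved by more than
    min_delta for `patience` consecutive epochs (hypothetical sketch)."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        """Record one epoch's loss; return True if training should stop."""
        if loss < self.best - self.min_delta:
            self.best = loss      # new best: reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1  # no sufficient improvement this epoch
        return self.bad_epochs >= self.patience
```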

trainer.print_progress(epoch, logs, n_epochs)

Create the message for '_print_progress_bar()' and print it together with a progress bar.

Analysis

analysis.calc_marker_stats(ad, groupby[, ...])

Calculate marker statistics for grouped data.

analysis.filter_marker_stats(data[, ...])

Filter marker statistics based on thresholds.

analysis.aggregate_top_markers(ad, mks, groupby)

Aggregate top marker genes.

analysis.get_enrichr_geneset([organism])

analysis.get_niche_enrichr(mks, geneset[, ...])

Perform Enrichr analysis on top genes for each niche derived from aggregated marker statistics (non-parallel version).

analysis.get_fast_niche_enrichr(mks, geneset)

Perform Enrichr analysis on top genes for each niche derived from aggregated marker statistics.

analysis.get_niche_gsea(mks, geneset[, ...])

Perform GSEA analysis for each niche based on the full ranked gene list from marker statistics.

analysis.calc_neighbor_prop(adata[, ...])

Normalize the cell type abundance based on the nearest neighbors for each batch in the AnnData object.
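Here is a sketch of neighborhood composition. It assumes spatial coordinates and a k-nearest-neighbor search restricted to each batch. Every cell gets a vector giving the fraction of each cell type among its neighbors. The function name, k and the Euclidean metric are assumptions, and the exact normalization that calc_neighbor_prop applies is not described here.

```python
import numpy as np


def neighbor_type_proportions(coords, cell_types, batches, k=5):
    """For every cell, the fraction of each cell type among its k nearest
    neighbors, searching only within the cell's own batch."""
    coords = np.asarray(coords, dtype=float)
    cell_types = np.asarray(cell_types)
    batches = np.asarray(batches)
    types = np.unique(cell_types)
    props = np.zeros((len(coords), len(types)))
    for b in np.unique(batches):
        idx = np.where(batches == b)[0]
        if len(idx) < 2:
            continue  # no neighbors available in a single-cell batch
        pts = coords[idx]
        dist = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)
        np.fill_diagonal(dist, np.inf)  # a cell is not its own neighbor
        k_b = min(k, len(idx) - 1)
        nbrs = np.argsort(dist, axis=1)[:, :k_b]
        for row, nbr in zip(idx, nbrs):
            nbr_types = cell_types[idx[nbr]]
            props[row] = [(nbr_types == t).mean() for t in types]
    return types, props
```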

Plot

plot.plot_multi_patterns_spatial(adata, ...)

Plot multiple patterns in spatial coordinates. Adapted from cell2location (https://github.com/BayraktarLab/cell2location).

plot.plot_markers(adata, groupby, mks[, ...])

Plot markers for specific groups.

plot.niches_enrichment_barplot(enrichments, ...)

Create a barplot for the enrichment results (either from Enrichr or GSEA).

plot.niches_enrichment_dotplot(enrichments, ...)

Create a dotplot for the enrichment results (either from Enrichr or GSEA).