Garfield.modules.GNNModelVAE
- class Garfield.modules.GNNModelVAE(encoder, bottle_neck_neurons, hidden_dims, feature_dim, num_heads, dropout, concat, n_domain, used_edge_weight, used_DSBN, conv_type, gnn_layer=2, cluster_num=20, include_edge_recon_loss=True, include_gene_expr_recon_loss=True, used_mmd=False)
Garfield model class. Implements the GNN-based variational graph autoencoder (GNNModelVAE).
- Parameters:
encoder (nn.Module) – The encoder module used in the variational graph autoencoder; documented options are ‘GAT’ and ‘GCN’.
bottle_neck_neurons (int) – Number of neurons in the bottleneck layer representing the latent dimension.
hidden_dims (int) – Number of hidden dimensions for the encoder.
feature_dim (int) – Number of feature dimensions in the input data.
num_heads (int) – Number of attention heads used in the GAT encoder.
dropout (float) – Dropout rate used in the encoder and decoder.
concat (bool) – Whether to concatenate outputs of different attention heads.
n_domain (int) – Number of domains for domain-specific batch normalization (DSBN).
used_edge_weight (bool) – Whether to use edge weights in the graph convolution operation.
used_DSBN (bool) – Whether to use domain-specific batch normalization (DSBN).
conv_type (str) – Type of graph convolution to use, e.g., ‘GAT’, ‘GATv2Conv’, ‘GCN’.
gnn_layer (int, optional) – Number of layers in the GNN encoder. Default is 2.
cluster_num (int, optional) – Number of clusters for the clustering layer. Default is 20.
include_edge_recon_loss (bool, optional) – Whether to include edge reconstruction loss in the model. Default is True.
include_gene_expr_recon_loss (bool, optional) – Whether to include gene expression reconstruction loss in the model. Default is True.
used_mmd (bool, optional) – Whether to use MMD (Maximum Mean Discrepancy) loss for domain adaptation. Default is False.
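The n_domain and used_DSBN parameters refer to domain-specific batch normalization (DSBN). This page does not document Garfield's own DSBN layer, so the NumPy sketch below only illustrates the general idea: it keeps separate normalization statistics and affine parameters for each domain and normalizes every sample with the statistics of its own domain. The class name DomainSpecificBatchNorm, the domain_ids argument, and the momentum/eps defaults are hypothetical choices borrowed from common BatchNorm conventions.

```python
import numpy as np


class DomainSpecificBatchNorm:
    """Hypothetical DSBN sketch: separate BatchNorm statistics and affine
    parameters for each domain (not Garfield's implementation)."""

    def __init__(self, num_features, n_domain, momentum=0.1, eps=1e-5):
        self.gamma = np.ones((n_domain, num_features))   # per-domain scale
        self.beta = np.zeros((n_domain, num_features))   # per-domain shift
        self.running_mean = np.zeros((n_domain, num_features))
        self.running_var = np.ones((n_domain, num_features))
        self.momentum = momentum
        self.eps = eps

    def __call__(self, x, domain_ids, training=True):
        out = np.empty(x.shape)
        for d in np.unique(domain_ids):
            mask = domain_ids == d
            if training:
                # Normalize with this domain's batch statistics and update its running averages.
                mean, var = x[mask].mean(axis=0), x[mask].var(axis=0)
                self.running_mean[d] = (1 - self.momentum) * self.running_mean[d] + self.momentum * mean
                self.running_var[d] = (1 - self.momentum) * self.running_var[d] + self.momentum * var
            else:
                # At inference time, use the running statistics of this domain.
                mean, var = self.running_mean[d], self.running_var[d]
            out[mask] = self.gamma[d] * (x[mask] - mean) / np.sqrt(var + self.eps) + self.beta[d]
        return out


rng = np.random.default_rng(0)
# Two hypothetical domains (e.g. batches) with different feature offsets and scales.
x = np.vstack([rng.normal(0.0, 1.0, (50, 3)), rng.normal(5.0, 2.0, (50, 3))])
domain_ids = np.array([0] * 50 + [1] * 50)
dsbn = DomainSpecificBatchNorm(num_features=3, n_domain=2)
out = dsbn(x, domain_ids)
print(out[domain_ids == 1].mean(axis=0))  # close to 0 for every feature
```

A single shared BatchNorm would normalize both domains with pooled statistics, leaving the offset between them in the output; per-domain statistics remove that offset before the features reach later layers.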
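With used_mmd=True the model adds a Maximum Mean Discrepancy (MMD) loss for domain adaptation. This page does not say which kernel or estimator Garfield uses, so the sketch below assumes a Gaussian (RBF) kernel with a fixed bandwidth and the biased estimator MMD² = mean k(x, x') + mean k(y, y') - 2 mean k(x, y). The function name gaussian_mmd and its sigma parameter are hypothetical. Applied to latent codes from two domains, the value is near zero when their distributions match and grows as they drift apart.

```python
import numpy as np


def gaussian_mmd(x, y, sigma=1.0):
    """Biased estimate of squared MMD between samples x (n, d) and y (m, d)
    under a Gaussian (RBF) kernel with bandwidth sigma (hypothetical helper)."""

    def kernel(a, b):
        # Pairwise squared Euclidean distances between the rows of a and b.
        sq_dists = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
        return np.exp(-sq_dists / (2.0 * sigma**2))

    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()


rng = np.random.default_rng(0)
z_domain_a = rng.normal(0.0, 1.0, size=(200, 8))  # latent codes from domain A
z_domain_b = rng.normal(0.0, 1.0, size=(200, 8))  # domain B, same distribution
z_shifted = rng.normal(2.0, 1.0, size=(200, 8))   # domain C, shifted distribution

mmd_same = gaussian_mmd(z_domain_a, z_domain_b)
mmd_shifted = gaussian_mmd(z_domain_a, z_shifted)
print(mmd_same < mmd_shifted)  # True: MMD grows as the distributions drift apart
```

Minimizing such a term during training pushes the encoder toward latent codes whose distribution does not depend on the domain.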
- __init__(encoder, bottle_neck_neurons, hidden_dims, feature_dim, num_heads, dropout, concat, n_domain, used_edge_weight, used_DSBN, conv_type, gnn_layer=2, cluster_num=20, include_edge_recon_loss=True, include_gene_expr_recon_loss=True, used_mmd=False)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Methods
__init__(encoder, bottle_neck_neurons, ...) – Initializes internal Module state, shared by both nn.Module and ScriptModule.
add_module(name, module) – Adds a child module to the current module.
apply(fn) – Applies fn recursively to every submodule (as returned by .children()) as well as self.
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Returns an iterator over module buffers.
children() – Returns an iterator over immediate children modules.
cpu() – Moves all model parameters and buffers to the CPU.
cuda([device]) – Moves all model parameters and buffers to the GPU.
decode(*args, **kwargs) – Runs the decoder and computes edge probabilities.
double() – Casts all floating point parameters and buffers to double datatype.
encode(*args, **kwargs) – Runs the encoder and computes node-wise latent variables.
eval() – Sets the module in evaluation mode.
extra_repr() – Sets the extra representation of the module.
float() – Casts all floating point parameters and buffers to float datatype.
forward(data_batch, decoder_type, augment_type) – Passes the input data through the encoder to obtain latent representations, then uses the decoder to reconstruct features or edges, depending on the task.
get_buffer(target) – Returns the buffer given by target if it exists, otherwise throws an error.
get_extra_state() – Returns any extra state to include in the module's state_dict.
get_latent_representation(node_batch[, ...]) – Encodes the input data into latent space and returns either the latent features (z) or the distribution parameters (mu and std), depending on the input option.
get_parameter(target) – Returns the parameter given by target if it exists, otherwise throws an error.
get_submodule(target) – Returns the submodule given by target if it exists, otherwise throws an error.
half() – Casts all floating point parameters and buffers to half datatype.
ipu([device]) – Moves all model parameters and buffers to the IPU.
load_state_dict(state_dict[, strict]) – Copies parameters and buffers from state_dict into this module and its descendants.
loss(edge_model_output, node_model_output, ...) – Computes the total loss by combining loss components such as KL divergence, edge reconstruction loss, gene expression reconstruction loss, and contrastive losses.
modules() – Returns an iterator over all modules in the network.
named_buffers([prefix, recurse]) – Returns an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
named_children() – Returns an iterator over immediate children modules, yielding both the name of the module and the module itself.
named_modules([memo, prefix, remove_duplicate]) – Returns an iterator over all modules in the network, yielding both the name of the module and the module itself.
named_parameters([prefix, recurse]) – Returns an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
parameters([recurse]) – Returns an iterator over module parameters.
recon_loss(z, pos_edge_index[, neg_edge_index]) – Given latent variables z, computes the binary cross-entropy loss for positive edges pos_edge_index and negative sampled edges.
register_backward_hook(hook) – Registers a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Adds a buffer to the module.
register_forward_hook(hook) – Registers a forward hook on the module.
register_forward_pre_hook(hook) – Registers a forward pre-hook on the module.
register_full_backward_hook(hook) – Registers a backward hook on the module.
register_load_state_dict_post_hook(hook) – Registers a post hook to be run after the module's load_state_dict is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Adds a parameter to the module.
reparameterize(mu, logstd[, eps]) – Applies the reparameterization trick to sample a latent vector from the latent distribution during training.
requires_grad_([requires_grad]) – Changes whether autograd should record operations on parameters in this module.
reset_parameters() – Resets all learnable parameters of the module.
set_extra_state(state) – Called from load_state_dict() to handle any extra state found within the state_dict.
share_memory() – See torch.Tensor.share_memory_().
state_dict(*args[, destination, prefix, ...]) – Returns a dictionary containing references to the whole state of the module.
test(z, pos_edge_index, neg_edge_index) – Given latent variables z, positive edges pos_edge_index, and negative edges neg_edge_index, computes the area under the ROC curve (AUC) and average precision (AP) scores.
to(*args, **kwargs) – Moves and/or casts the parameters and buffers.
to_empty(*, device) – Moves the parameters and buffers to the specified device without copying storage.
train([mode]) – Sets the module in training mode.
type(dst_type) – Casts all parameters and buffers to dst_type.
xpu([device]) – Moves all model parameters and buffers to the XPU.
zero_grad([set_to_none]) – Sets gradients of all model parameters to zero.
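reparameterize(mu, logstd[, eps]) and the KL divergence term in loss() follow the standard VAE recipe. Assuming a diagonal Gaussian posterior with standard deviation exp(logstd) and a standard normal prior (the page does not spell out Garfield's formulas), a sample is z = mu + eps * exp(logstd) with eps ~ N(0, I), and the KL term has the closed form KL = -0.5 * sum(1 + 2 * logstd - mu^2 - exp(2 * logstd)). The NumPy sketch below is illustrative only: its rng and training arguments are hypothetical, returning mu outside training is a common convention (PyTorch Geometric's VGAE does this) rather than something this page confirms for Garfield, and the documented optional eps argument is not modeled because its meaning is not specified here.

```python
import numpy as np


def reparameterize(mu, logstd, rng, training=True):
    """Reparameterization trick: z = mu + eps * exp(logstd), eps ~ N(0, I).

    Hypothetical sketch; outside training it returns the mean (a common
    convention, e.g. in PyTorch Geometric's VGAE).
    """
    if not training:
        return mu
    eps = rng.standard_normal(mu.shape)
    return mu + eps * np.exp(logstd)


def kl_divergence(mu, logstd):
    """KL(N(mu, exp(logstd)^2) || N(0, I)), summed over latent dims, averaged over nodes."""
    per_node = -0.5 * np.sum(1.0 + 2.0 * logstd - mu**2 - np.exp(2.0 * logstd), axis=1)
    return per_node.mean()


rng = np.random.default_rng(0)
mu = np.array([[1.0, -2.0]])             # one node, 2-D latent space
logstd = np.log(np.array([[0.5, 2.0]]))  # standard deviations 0.5 and 2.0

z = reparameterize(mu, logstd, rng)      # one stochastic sample
print(z.shape)                           # (1, 2)
print(kl_divergence(mu, logstd))         # approximately 3.625
```

Because the randomness enters only through eps, gradients can flow through mu and logstd during training, which is the reason for sampling this way.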
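The descriptions of decode() and recon_loss() match those of PyTorch Geometric's GAE class, whose default decoder scores an edge (i, j) as sigmoid(z_i · z_j). Assuming GNNModelVAE's edge decoder behaves similarly (this page does not confirm it), the NumPy sketch below shows the resulting loss: binary cross-entropy that pushes scores of observed edges toward 1 and scores of negative edges toward 0. The helper edge_logits and the rng argument are hypothetical, and the uniform negative sampling is deliberately naive (it can occasionally draw a true edge).

```python
import numpy as np


def edge_logits(z, edge_index):
    """Inner-product decoder: the logit of edge (i, j) is z_i . z_j."""
    src, dst = edge_index
    return np.sum(z[src] * z[dst], axis=1)


def recon_loss(z, pos_edge_index, neg_edge_index=None, rng=None):
    """Binary cross-entropy for positive edges plus negative edges.

    Hypothetical sketch loosely modeled on PyTorch Geometric's GAE.recon_loss.
    """
    if neg_edge_index is None:
        # Uniform random node pairs as negatives (may occasionally hit a true edge).
        if rng is None:
            rng = np.random.default_rng()
        neg_edge_index = rng.integers(0, z.shape[0], size=(2, pos_edge_index.shape[1]))
    # Stable BCE: -log(sigmoid(x)) = log(1 + e^-x) and -log(1 - sigmoid(x)) = log(1 + e^x).
    pos_loss = np.logaddexp(0.0, -edge_logits(z, pos_edge_index)).mean()
    neg_loss = np.logaddexp(0.0, edge_logits(z, neg_edge_index)).mean()
    return pos_loss + neg_loss


# Four nodes; true edges (0, 1) and (2, 3), negatives (0, 2) and (1, 3).
pos_edge_index = np.array([[0, 2], [1, 3]])
neg_edge_index = np.array([[0, 1], [2, 3]])
z_good = np.array([[2.0, 0.0], [2.0, 0.0], [0.0, 2.0], [0.0, 2.0]])  # linked nodes aligned
z_bad = np.array([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0], [0.0, 2.0]])   # linked nodes orthogonal

print(recon_loss(z_good, pos_edge_index, neg_edge_index))  # approximately 0.711
print(recon_loss(z_bad, pos_edge_index, neg_edge_index))   # approximately 4.711
```

Computing the loss from logits with logaddexp avoids taking log(0) when a sigmoid saturates, which is why the sketch does not apply the sigmoid explicitly.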
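test(z, pos_edge_index, neg_edge_index) reports link-prediction quality as ROC AUC and average precision (AP). The sketch below computes both metrics directly in NumPy, without scikit-learn: AUC as the fraction of (positive, negative) pairs in which the positive edge scores higher (ties count half), and AP as the mean precision at the rank of each positive edge, which matches the usual definition when no scores are tied. Scoring edges with sigmoid(z_i · z_j) is an assumption borrowed from PyTorch Geometric's GAE, and all function names here are hypothetical.

```python
import numpy as np


def auc_score(pos_scores, neg_scores):
    """ROC AUC: fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    diff = pos_scores[:, None] - neg_scores[None, :]
    return np.mean((diff > 0) + 0.5 * (diff == 0))


def average_precision(pos_scores, neg_scores):
    """Mean precision@k over the ranks k of the positives (assumes no tied scores)."""
    scores = np.concatenate([pos_scores, neg_scores])
    labels = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
    hits = labels[np.argsort(-scores, kind="stable")]  # labels sorted by descending score
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    return precision_at_k[hits == 1].mean()


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def evaluate_link_prediction(z, pos_edge_index, neg_edge_index):
    """Score edges with sigmoid(z_i . z_j), then return (AUC, AP)."""
    pos = sigmoid(np.sum(z[pos_edge_index[0]] * z[pos_edge_index[1]], axis=1))
    neg = sigmoid(np.sum(z[neg_edge_index[0]] * z[neg_edge_index[1]], axis=1))
    return auc_score(pos, neg), average_precision(pos, neg)


pos_scores = np.array([0.9, 0.8, 0.3])
neg_scores = np.array([0.7, 0.2])
print(auc_score(pos_scores, neg_scores))          # 5/6, approximately 0.833
print(average_precision(pos_scores, neg_scores))  # (1 + 1 + 0.75) / 3, approximately 0.917
```

In the example, the only misordered pair is the positive at 0.3 against the negative at 0.7, which costs one of six pairs in AUC and lowers the precision at that positive's rank to 3/4 in AP.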
Attributes
T_destination, dump_patches, training