Garfield.trainer.GarfieldTrainer
- class Garfield.trainer.GarfieldTrainer(adata, model: Module, label_name, used_pca_feat, adj_key, edge_val_ratio, edge_test_ratio, node_val_ratio, node_test_ratio, augment_type, num_neighbors, loaders_n_hops, edge_batch_size, node_batch_size, reload_best_model, use_early_stopping, early_stopping_kwargs, monitor, verbose, device_id, seed, **kwargs)
Initializes the GarfieldTrainer class, which handles data preparation, model initialization, and training of the Garfield model.
- Parameters:
adata (AnnData) – Annotated data matrix.
model (nn.Module) – The Garfield model to be trained.
label_name (str) – Column name for labels in the annotated data.
used_pca_feat (bool) – Whether to use PCA features as the node features.
adj_key (str) – Key for the adjacency matrix (e.g., spatial connectivity) in the data.
edge_val_ratio (float) – Proportion of edges to use for validation in edge-level tasks.
edge_test_ratio (float) – Proportion of edges to use for testing in edge-level tasks.
node_val_ratio (float) – Proportion of nodes to use for validation in node-level tasks.
node_test_ratio (float) – Proportion of nodes to use for testing in node-level tasks.
augment_type (str) – Type of data augmentation to apply (e.g., ‘dropout’, ‘svd’).
num_neighbors (int) – Number of neighbors to sample for each node in graph-based tasks.
loaders_n_hops (int) – Number of hops to consider for neighbors in graph-based tasks.
edge_batch_size (int) – Batch size for edge-level tasks.
node_batch_size (int or None) – Batch size for node-level tasks. If None, it will be determined automatically.
reload_best_model (bool) – Whether to reload the best model after training.
use_early_stopping (bool) – Whether to apply early stopping during training.
early_stopping_kwargs (dict) – Additional arguments for early stopping (e.g., patience, delta).
monitor (bool) – Whether to print monitoring logs during training.
verbose (bool) – Whether to print detailed logs during training.
seed (int) – Seed for random number generation to ensure reproducibility.
kwargs (dict) – Additional arguments for training configuration.
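The validation and test ratios above each describe a fraction of the edges or nodes to hold out. The sketch below shows the arithmetic such a split implies. It is illustrative only: `split_counts` is a hypothetical helper, not part of Garfield, and the rounding rule (round down) is an assumption, since the source does not say how the library rounds.

```python
def split_counts(n_items, val_ratio, test_ratio):
    """Return (n_train, n_val, n_test) for a ratio-based split.

    Hypothetical helper for illustration; not Garfield's implementation.
    """
    if not (0.0 <= val_ratio < 1.0 and 0.0 <= test_ratio < 1.0):
        raise ValueError("ratios must lie in [0, 1)")
    if val_ratio + test_ratio >= 1.0:
        raise ValueError("val_ratio + test_ratio must be below 1")
    # Rounding down is an assumption of this sketch.
    n_val = int(n_items * val_ratio)
    n_test = int(n_items * test_ratio)
    # Whatever is not held out remains available for training.
    n_train = n_items - n_val - n_test
    return n_train, n_val, n_test


# Example: how many of 1000 edges land in each subset.
n_train, n_val, n_test = split_counts(1000, val_ratio=0.25, test_ratio=0.25)
print(n_train, n_val, n_test)
```

The same arithmetic applies to node_val_ratio and node_test_ratio, with the number of nodes in place of the number of edges.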
- __init__(adata, model: Module, label_name, used_pca_feat, adj_key, edge_val_ratio, edge_test_ratio, node_val_ratio, node_test_ratio, augment_type, num_neighbors, loaders_n_hops, edge_batch_size, node_batch_size, reload_best_model, use_early_stopping, early_stopping_kwargs, monitor, verbose, device_id, seed, **kwargs)
Methods
__init__(adata, model, label_name, ...)
is_early_stopping() – Check whether to apply early stopping, update the learning rate, and save the best model state.
plot_loss_curves([title]) – Plot loss curves.
save(dir_path[, overwrite, save_anndata]) – Save the state of the model. Neither the trainer optimizer state nor the trainer history is saved. dir_path: path to a directory. overwrite: whether to overwrite existing data; if False and a directory already exists at dir_path, an error is raised. save_anndata: if True, also save the AnnData object. anndata_write_kwargs: keyword arguments for the AnnData write function.
test_metrics_epoch(lambda_edge_recon, ...) – Evaluates the Garfield model at the end of each epoch on validation data.
train(n_epochs, n_epochs_no_edge_recon, ...) – Trains the Garfield model for a specified number of epochs with joint edge-level and node-level tasks.
validate_end() – Evaluates the Garfield model after training on the validation dataset.
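As a rough illustration of the early-stopping behaviour that use_early_stopping, early_stopping_kwargs, and is_early_stopping() refer to, the sketch below implements a generic patience/delta rule. It is not Garfield's code: the class `EarlyStoppingSketch` is hypothetical, it assumes the monitored metric is a loss to be minimized, and it omits the learning-rate update and model-state saving that the real method is documented to perform.

```python
class EarlyStoppingSketch:
    """Generic patience/delta early stopping (hypothetical, for illustration)."""

    def __init__(self, patience=10, delta=0.0):
        self.patience = patience  # epochs to wait without improvement
        self.delta = delta        # minimum decrease that counts as improvement
        self.best_metric = float("inf")
        self.best_epoch = None
        self.epochs_without_improvement = 0

    def step(self, metric, epoch):
        """Record one epoch's metric (lower is better); return True to stop."""
        if metric < self.best_metric - self.delta:
            # Improvement: remember it and reset the counter.
            self.best_metric = metric
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Made-up validation losses, used only to exercise the rule.
losses = [1.0, 0.8, 0.79, 0.81, 0.80, 0.82]
stopper = EarlyStoppingSketch(patience=3, delta=0.05)
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss, epoch):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_epoch, stopper.best_metric)
```

A trainer that also honours reload_best_model would typically store a copy of the model weights whenever the improvement branch runs and restore that copy once training ends.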