Garfield.trainer.GarfieldTrainer

class Garfield.trainer.GarfieldTrainer(adata, model: Module, label_name, used_pca_feat, adj_key, edge_val_ratio, edge_test_ratio, node_val_ratio, node_test_ratio, augment_type, num_neighbors, loaders_n_hops, edge_batch_size, node_batch_size, reload_best_model, use_early_stopping, early_stopping_kwargs, monitor, verbose, device_id, seed, **kwargs)

Initializes the GarfieldTrainer class, which handles data preparation, model initialization, and training of the Garfield model.

Parameters:
  • adata (AnnData) – Annotated data matrix.

  • model (nn.Module) – The Garfield model to be trained.

  • label_name (str) – Column name for labels in the annotated data.

  • used_pca_feat (bool) – Whether to use PCA features as node features.

  • adj_key (str) – Key for the adjacency matrix (e.g., spatial connectivity) in the data.

  • edge_val_ratio (float) – Proportion of edges to use for validation in edge-level tasks.

  • edge_test_ratio (float) – Proportion of edges to use for testing in edge-level tasks.

  • node_val_ratio (float) – Proportion of nodes to use for validation in node-level tasks.

  • node_test_ratio (float) – Proportion of nodes to use for testing in node-level tasks.

  • augment_type (str) – Type of data augmentation to apply (e.g., ‘dropout’, ‘svd’).

  • num_neighbors (int) – Number of neighbors to sample for each node in graph-based tasks.

  • loaders_n_hops (int) – Number of hops to consider for neighbors in graph-based tasks.

  • edge_batch_size (int) – Batch size for edge-level tasks.

  • node_batch_size (int or None) – Batch size for node-level tasks. If None, it will be determined automatically.

  • reload_best_model (bool) – Whether to reload the best model after training.

  • use_early_stopping (bool) – Whether to apply early stopping during training.

  • early_stopping_kwargs (dict) – Additional arguments for early stopping (e.g., patience, delta).

  • monitor (bool) – Whether to print monitoring logs during training.

  • verbose (bool) – Whether to print detailed logs during training.

  • seed (int) – Seed for random number generation to ensure reproducibility.

  • kwargs (dict) – Additional arguments for training configuration.
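The four *_ratio parameters and seed describe train/validation/test splits of edges and nodes. As a minimal sketch, assuming a uniform random split of item indices, the hypothetical helper split_indices below shows how such ratios are typically applied. Garfield's actual splitting logic may differ; link-prediction splits, for instance, often also sample negative edges.

```python
import random


def split_indices(n_items, val_ratio, test_ratio, seed=0):
    """Shuffle item indices and split them into train/val/test lists.

    Hypothetical helper illustrating how *_val_ratio / *_test_ratio
    parameters are commonly interpreted; not Garfield's own code.
    """
    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio >= 1:
        raise ValueError("ratios must be non-negative and sum to less than 1")
    rng = random.Random(seed)  # seeded RNG so the split is reproducible
    indices = list(range(n_items))
    rng.shuffle(indices)
    n_val = int(n_items * val_ratio)
    n_test = int(n_items * test_ratio)
    val = indices[:n_val]
    test = indices[n_val:n_val + n_test]
    train = indices[n_val + n_test:]  # everything left over is training data
    return train, val, test


# Example: 100 edges, 10% for validation, 5% for testing
train_e, val_e, test_e = split_indices(100, 0.1, 0.05, seed=42)
print(len(train_e), len(val_e), len(test_e))  # 85 10 5
```

The same helper would apply to nodes (node_val_ratio, node_test_ratio) by passing the number of nodes instead of the number of edges.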

__init__(adata, model: Module, label_name, used_pca_feat, adj_key, edge_val_ratio, edge_test_ratio, node_val_ratio, node_test_ratio, augment_type, num_neighbors, loaders_n_hops, edge_batch_size, node_batch_size, reload_best_model, use_early_stopping, early_stopping_kwargs, monitor, verbose, device_id, seed, **kwargs)
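The num_neighbors and loaders_n_hops parameters together bound the size of each sampled subgraph. The sketch below assumes GraphSAGE-style sampling with the same neighbor count at every hop; sample_k_hop and the dictionary adjacency list are hypothetical and do not reflect Garfield's loaders, which may build on a graph library instead.

```python
import random


def sample_k_hop(adj, seeds, num_neighbors, n_hops, seed=0):
    """Return the nodes reached by sampling up to `num_neighbors`
    neighbors per node for `n_hops` hops, starting from `seeds`.

    Hypothetical helper (GraphSAGE-style sampling); not Garfield's loader.
    """
    rng = random.Random(seed)
    visited = set(seeds)
    frontier = list(seeds)
    for _ in range(n_hops):
        next_frontier = []
        for node in frontier:
            neighbors = adj.get(node, [])
            # Sample without replacement, capped at the available neighbors
            k = min(num_neighbors, len(neighbors))
            for nb in rng.sample(neighbors, k):
                if nb not in visited:
                    visited.add(nb)
                    next_frontier.append(nb)
        frontier = next_frontier  # only newly reached nodes expand next hop
    return visited


# Toy path graph 0-1-2-3-4
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
print(sorted(sample_k_hop(adj, [0], num_neighbors=2, n_hops=2)))  # [0, 1, 2]
```

Increasing either parameter enlarges the sampled neighborhood (and memory use) roughly multiplicatively, which is why the two are usually tuned together with the batch sizes.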

Methods

__init__(adata, model, label_name, ...)

Initializes the GarfieldTrainer class.

is_early_stopping()

Checks whether to apply early stopping, updates the learning rate, and saves the best model state.

plot_loss_curves([title])

Plots the loss curves.

save(dir_path[, overwrite, save_anndata])

Saves the state of the model. Neither the trainer optimizer state nor the trainer history is saved.

Parameters:
  • dir_path – Path to a directory.

  • overwrite – Whether to overwrite existing data. If False and a directory already exists at dir_path, an error is raised.

  • save_anndata – If True, also saves the AnnData object.

  • anndata_write_kwargs – Keyword arguments for the AnnData write function.

test_metrics_epoch(lambda_edge_recon, ...)

Evaluates the Garfield model at the end of each epoch on validation data.

train(n_epochs, n_epochs_no_edge_recon, ...)

Trains the Garfield model for a specified number of epochs with joint edge-level and node-level tasks.

validate_end()

Evaluates the Garfield model after training on the validation dataset.
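The is_early_stopping method, together with the early_stopping_kwargs and reload_best_model parameters, follows the common patience pattern. A minimal sketch, assuming early_stopping_kwargs carries patience and delta as the parameter description suggests; the EarlyStopping class is hypothetical, monitors a validation loss, and omits the learning-rate update that is_early_stopping also performs.

```python
import copy


class EarlyStopping:
    """Hypothetical patience-based early stopper (not Garfield's class).

    Stops when the monitored loss has not improved by more than `delta`
    for `patience` consecutive epochs, and keeps a copy of the best state
    so it can be reloaded afterwards (cf. `reload_best_model`).
    """

    def __init__(self, patience=5, delta=0.0):
        self.patience = patience
        self.delta = delta
        self.best_loss = float("inf")
        self.best_state = None
        self.epochs_without_improvement = 0

    def step(self, loss, state):
        """Record one epoch's validation loss; return True to stop."""
        if loss < self.best_loss - self.delta:
            # Improvement larger than delta: remember this state
            self.best_loss = loss
            self.best_state = copy.deepcopy(state)
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=2, delta=0.01)
losses = [1.0, 0.8, 0.795, 0.81, 0.9]
stop_epoch = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss, state={"epoch": epoch}):
        stop_epoch = epoch
        break
print(stop_epoch, stopper.best_state)  # 3 {'epoch': 1}
```

Here epoch 2 improves on 0.8 by less than delta, so it counts as a non-improving epoch; training stops after epoch 3 and the state from epoch 1 is the one a reload would restore.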