Garfield.modules.compute_edge_recon_loss

Garfield.modules.compute_edge_recon_loss(edge_recon_logits: Tensor, edge_recon_labels: Tensor, edge_incl: Tensor | None = None) → Tensor

Compute the weighted binary cross-entropy loss for edge reconstruction from the predicted edge logits and the ground-truth edge labels.

Parameters:
  • edge_recon_logits – Predicted edge reconstruction logits for both positive and negative sampled edges (dim: 2 * `edge_batch_size`).

  • edge_recon_labels – Ground-truth edge labels for both positive and negative sampled edges (dim: 2 * `edge_batch_size`).

  • edge_incl – Boolean mask indicating which edges to include in the edge reconstruction loss (dim: 2 * `edge_batch_size`). If `None`, all edges are included.

Returns:

edge_recon_loss – Weighted binary cross-entropy loss between the edge labels and the predicted edge probabilities. The loss is computed directly from the logits rather than from explicit probabilities, for numerical stability.

Return type:

Tensor
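
The arithmetic behind a weighted binary cross-entropy loss with logits can be sketched in plain Python. This is an illustration only, not Garfield's implementation, which operates on tensors. The function name `edge_recon_loss_sketch` is hypothetical, and two details are assumptions that this page does not state: that positive edges are weighted by the ratio of negative to positive labels among the included edges, and that the loss is averaged over the included edges.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def edge_recon_loss_sketch(logits, labels, incl=None):
    """Weighted BCE-with-logits over plain Python lists (hypothetical helper)."""
    # Keep only the edges selected by the optional boolean mask.
    if incl is None:
        pairs = list(zip(logits, labels))
    else:
        pairs = [(z, y) for z, y, keep in zip(logits, labels, incl) if keep]

    n_pos = sum(1 for _, y in pairs if y == 1.0)
    n_neg = sum(1 for _, y in pairs if y == 0.0)
    if n_pos == 0:
        raise ValueError("this sketch needs at least one positive edge")

    # ASSUMPTION: positive edges are up-weighted by n_neg / n_pos.
    pos_weight = n_neg / n_pos

    total = 0.0
    for z, y in pairs:
        # -log(sigmoid(z)) == softplus(-z) and -log(1 - sigmoid(z)) == softplus(z),
        # so no explicit probability is ever formed.
        total += pos_weight * y * softplus(-z) + (1.0 - y) * softplus(z)

    # ASSUMPTION: mean reduction over the included edges.
    return total / len(pairs)


logits = [2.0, -1.0, 0.5, -3.0]
labels = [1.0, 1.0, 0.0, 0.0]
loss_all = edge_recon_loss_sketch(logits, labels)
loss_masked = edge_recon_loss_sketch(logits, labels, incl=[True, False, True, True])
```

Working with `softplus` of the logit instead of taking the logarithm of a sigmoid output avoids `log(0)` when a logit is very large in magnitude, which is the kind of numerical stability the return description refers to.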