Garfield.modules.compute_edge_recon_loss
- Garfield.modules.compute_edge_recon_loss(edge_recon_logits: Tensor, edge_recon_labels: Tensor, edge_incl: Tensor | None = None) → Tensor
Compute the edge reconstruction loss as a weighted binary cross-entropy with logits between the ground truth edge labels and the predicted edge logits.
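Written out, a weighted binary cross-entropy with logits of this kind has the form below. Here $x_i$ are the entries of `edge_recon_logits`, $y_i \in \{0, 1\}$ the entries of `edge_recon_labels`, $\sigma$ the logistic sigmoid, and $\mathcal{I}$ the set of edges where `edge_incl` is true (all $2 \cdot$ `edge_batch_size` edges if it is `None`). This page does not say how the weight is chosen; the positive-class weight $w$ shown here (ratio of negative to positive labels) is an assumption, not confirmed Garfield behaviour.

```latex
\mathcal{L}_{\text{edge}}
  = -\frac{1}{|\mathcal{I}|} \sum_{i \in \mathcal{I}}
    \Big[ w \, y_i \log \sigma(x_i) + (1 - y_i) \log\big(1 - \sigma(x_i)\big) \Big],
\qquad
w = \frac{\big|\{ i \in \mathcal{I} : y_i = 0 \}\big|}{\big|\{ i \in \mathcal{I} : y_i = 1 \}\big|}
```

Evaluating $\log \sigma(x_i)$ and $\log(1 - \sigma(x_i)) = \log \sigma(-x_i)$ directly from the logits avoids the overflow and saturation that occur when $\sigma(x_i)$ is computed first for large $|x_i|$.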
- Parameters:
edge_recon_logits – Predicted edge reconstruction logits for both positive and negative sampled edges (dim: 2 * `edge_batch_size`).
edge_recon_labels – Edge ground truth labels for both positive and negative sampled edges (dim: 2 * `edge_batch_size`).
edge_incl – Boolean mask indicating which edges to include in the edge reconstruction loss (dim: 2 * `edge_batch_size`). If `None`, all edges are included.
- Returns:
edge_recon_loss – Weighted binary cross-entropy loss between the edge labels and the predicted edge probabilities, computed directly from the logits for numerical stability during backpropagation.
- Return type:
Tensor
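A minimal PyTorch sketch of such a loss follows. It assumes (this page does not confirm it) that the weight is applied to positive edges and equals the ratio of negative to positive labels among the included edges. `weighted_edge_recon_loss` is a hypothetical name for illustration, not part of Garfield's API.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def weighted_edge_recon_loss(
    edge_recon_logits: torch.Tensor,
    edge_recon_labels: torch.Tensor,
    edge_incl: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hypothetical sketch of a weighted edge reconstruction BCE loss.

    ASSUMPTION: positive edges are up-weighted by (#negatives / #positives);
    the weighting used by Garfield itself is not documented here.
    """
    # Drop entries whose boolean inclusion flag is False.
    if edge_incl is not None:
        edge_recon_logits = edge_recon_logits[edge_incl]
        edge_recon_labels = edge_recon_labels[edge_incl]

    # Count positive (1) and negative (0) labels among the included edges.
    n_pos = (edge_recon_labels == 1.0).sum()
    n_neg = (edge_recon_labels == 0.0).sum()
    # Clamp avoids division by zero when a batch has no positive edges.
    pos_weight = n_neg / n_pos.clamp(min=1)

    # Fuses sigmoid and BCE (log-sum-exp trick), which is more stable than
    # applying torch.sigmoid first; reduction defaults to the mean.
    return F.binary_cross_entropy_with_logits(
        edge_recon_logits, edge_recon_labels, pos_weight=pos_weight
    )


# Usage: 2 positive and 2 negative sampled edges (2 * edge_batch_size = 4).
logits = torch.tensor([2.0, 1.5, -1.0, -2.0])
labels = torch.tensor([1.0, 1.0, 0.0, 0.0])
loss = weighted_edge_recon_loss(logits, labels)

# Exclude the second (positive) edge: 1 positive vs 2 negatives -> weight 2.
mask = torch.tensor([True, False, True, True])
masked_loss = weighted_edge_recon_loss(logits, labels, edge_incl=mask)
```

With balanced positive and negative samples the assumed weight is 1 and the result reduces to plain binary cross-entropy. Passing the weight through `pos_weight` keeps the whole computation inside the numerically stable logit-based routine instead of weighting probabilities after a separate sigmoid.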