Garfield.data.node_level_split_mask

Garfield.data.node_level_split_mask(data: Data, val_ratio: float = 0.1, test_ratio: float = 0.0, split_key: str = 'x') -> Data

Split the data at the node level into training, validation and test sets by adding node-level masks (train_mask, val_mask, test_mask) to the PyG Data object.

Parameters:
  • data – PyG Data object to be split.

  • val_ratio – Ratio of nodes to be included in the validation split.

  • test_ratio – Ratio of nodes to be included in the test split.

  • split_key – The attribute key of the PyG Data object that holds the ground truth labels. Only nodes for which this key is present are split.

Returns:

PyG Data object with train_mask, val_mask and test_mask attributes added.

Return type:

Data
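
Example:

The idea of node-level split masks can be illustrated with a minimal, library-free sketch. This is not Garfield's implementation. The helper name sketch_node_split_masks, the seed argument, the use of None to mark nodes without a label, and the rounding of split sizes with int() are all assumptions made for illustration. The real function operates on a PyG Data object and attaches the masks to it as attributes.

```python
import random


def sketch_node_split_masks(labels, val_ratio=0.1, test_ratio=0.0, seed=0):
    """Illustrative only: build boolean train/val/test masks over nodes.

    `labels` holds one entry per node; None marks a node without a label
    (an assumption standing in for "the key is not present").
    """
    n_nodes = len(labels)

    # Only labeled nodes take part in the split.
    eligible = [i for i, y in enumerate(labels) if y is not None]
    random.Random(seed).shuffle(eligible)

    # Split sizes are rounded down here; the library may round differently.
    n_val = int(len(eligible) * val_ratio)
    n_test = int(len(eligible) * test_ratio)

    val_idx = set(eligible[:n_val])
    test_idx = set(eligible[n_val:n_val + n_test])
    train_idx = set(eligible[n_val + n_test:])

    # One boolean per node; unlabeled nodes are False in every mask.
    train_mask = [i in train_idx for i in range(n_nodes)]
    val_mask = [i in val_idx for i in range(n_nodes)]
    test_mask = [i in test_idx for i in range(n_nodes)]
    return train_mask, val_mask, test_mask


labels = [0, 1, None, 1, 0, 0, 1, None, 1, 0, 1, 0]
train_mask, val_mask, test_mask = sketch_node_split_masks(
    labels, val_ratio=0.2, test_ratio=0.1
)
print(sum(train_mask), sum(val_mask), sum(test_mask))
```

Because the masks are disjoint and each has one entry per node, a training loop can select the training nodes with train_mask and evaluate on the nodes selected by val_mask or test_mask without copying the graph.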