"""Type aliases for batches (``BatchType``) and lists of batches (``DataSetType``)."""
from typing import Dict, List, Union

import torch
import torch_geometric

# Plain-mapping form of a batch: string keys mapped to tensors.
_TensorDict = Dict[str, torch.Tensor]

# A batch is either a torch_geometric ``Batch`` object or a dict of tensors.
# ``torch_geometric.data.Batch`` is the public re-export of
# ``torch_geometric.data.batch.Batch`` (same class object).
BatchType = Union[torch_geometric.data.Batch, _TensorDict]

# A dataset is represented as a list of such batches.
DataSetType = List[BatchType]