Commit 41b4f9ab authored by mova

keep the edge_index when generating the adj matrix

parent 15faa8e9
@@ -21,17 +21,18 @@ from fgsim.io import qf
 # Load files
 ds_path = Path(conf.path.dataset)
 assert ds_path.is_dir()
-files = [str(e) for e in sorted(ds_path.glob(conf.path.dataset_glob))]
+files = sorted(ds_path.glob(conf.path.dataset_glob))
 if len(files) < 1:
     raise RuntimeError("No hdf5 datasets found")
+ChunkType = List[Tuple[Path, int, int]]
 # load lengths
 if not os.path.isfile(conf.path.ds_lenghts):
     len_dict = {}
     for fn in files:
         with uproot.open(fn) as rfile:
-            len_dict[fn] = rfile[conf.loader.rootprefix].num_entries
+            len_dict[str(fn)] = rfile[conf.loader.rootprefix].num_entries
     with open(conf.path.ds_lenghts, "w") as f:
         yaml.dump(len_dict, f, Dumper=yaml.SafeDumper)
 else:
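
Note on the `str(fn)` change above: `files` now holds `Path` objects rather than strings, and yaml's `SafeDumper` cannot represent `pathlib.Path` keys, so the keys are stringified before dumping. A minimal sketch of the failure mode (the file name is made up):

```python
from pathlib import Path

import yaml

lengths = {Path("run1.root"): 5000}  # Path key, as produced by ds_path.glob
try:
    yaml.dump(lengths, Dumper=yaml.SafeDumper)
except yaml.representer.RepresenterError:
    # SafeDumper has no representer for PosixPath keys.
    pass

# Stringifying the keys makes the mapping dumpable (and human-readable):
print(yaml.dump({str(k): v for k, v in lengths.items()}, Dumper=yaml.SafeDumper))
```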
@@ -40,7 +41,7 @@ else:
 # reading from the filesystem
-def read_chunk(chunks: List[Tuple[str, int, int]]) -> ak.highlevel.Array:
+def read_chunk(chunks: ChunkType) -> ak.highlevel.Array:
     chunks_list = []
     for chunk in chunks:
         file_path, start, end = chunk
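
For context, a chunk is a `(file, start, end)` triple pointing into one of the ROOT files, which is why the new `ChunkType` alias carries `Path` rather than `str`. A hedged sketch of how such chunks could be derived from the recorded lengths (`make_chunks` and the chunk size are illustrative, not fgsim's actual code):

```python
from pathlib import Path
from typing import Dict, List, Tuple

ChunkType = List[Tuple[Path, int, int]]

def make_chunks(len_dict: Dict[str, int], chunk_size: int = 1000) -> ChunkType:
    # Split every file into [start, end) entry ranges of at most chunk_size.
    chunks: ChunkType = []
    for fn, n_entries in len_dict.items():
        for start in range(0, n_entries, chunk_size):
            chunks.append((Path(fn), start, min(start + chunk_size, n_entries)))
    return chunks
```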
@@ -68,7 +69,7 @@ def geo_batch(list_of_graphs: List[GraphType]) -> GraphType:
 ToSparseTranformer = torch_geometric.transforms.ToSparseTensor(
-    remove_edge_index=True, fill_cache=True
+    remove_edge_index=False, fill_cache=True
 )
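
This hunk is the change the commit title refers to: PyTorch Geometric's `ToSparseTensor` builds the (transposed) sparse adjacency matrix `adj_t` from `edge_index` and, by default, deletes `edge_index` afterwards; `remove_edge_index=False` keeps both on the graph. A minimal sketch with a toy graph (assumes `torch_sparse` is installed, which the transform requires):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import ToSparseTensor

# Toy undirected graph: 3 nodes, 2 edges given in both directions.
data = Data(
    edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
    num_nodes=3,
)

transform = ToSparseTensor(remove_edge_index=False, fill_cache=True)
data = transform(data)

assert data.adj_t is not None        # sparse adjacency is available ...
assert data.edge_index is not None   # ... and edge_index is still there
```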
......
@@ -27,7 +27,6 @@ def read_file(file: Path) -> List[GraphType]:
 def preprocessed_seq():
     return (
         qf.ProcessStep(read_file, 1, name="read_chunk"),
-        Queue(1),
         qf.pack.UnpackStep(),
         Queue(conf.loader.prefetch_batches),
     )
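
The hunk above drops the one-slot buffer between the read and unpack steps, leaving only the prefetch queue at the end of the sequence. `qf` is fgsim's internal queued-pipeline package, so its real semantics are not shown here; below is only a generic sketch of the staged producer/consumer pattern such a sequence implements, with bounded queues between steps (all names and sizes illustrative):

```python
import multiprocessing as mp

def reader(out_q: mp.Queue) -> None:
    # Stand-in for the read_file step: emit packed lists of items.
    for chunk in range(4):
        out_q.put([chunk, chunk])
    out_q.put(None)  # sentinel: no more input

def unpacker(in_q: mp.Queue, out_q: mp.Queue) -> None:
    # Stand-in for the unpack step: flatten packed lists into single items.
    while (packed := in_q.get()) is not None:
        for item in packed:
            out_q.put(item)
    out_q.put(None)

if __name__ == "__main__":
    inter_q: mp.Queue = mp.Queue(maxsize=1)     # small buffer between steps
    prefetch_q: mp.Queue = mp.Queue(maxsize=8)  # cf. conf.loader.prefetch_batches
    mp.Process(target=reader, args=(inter_q,)).start()
    mp.Process(target=unpacker, args=(inter_q, prefetch_q)).start()
    while (batch := prefetch_q.get()) is not None:
        print(batch)  # consume prefetched items
```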