Commit 24434b27 authored by mova

split the dataset processing sequences depending on the configuration; add typing checks

parent 9d2efe72
......@@ -12,6 +12,10 @@ from .utils.logger import logger
# Add the project to the path so that `import fgsim.x` works
sys.path.append(os.path.dirname(os.path.realpath(".")))
from typeguard.importhook import install_import_hook
install_import_hook("fgsim")
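# The import hook makes typeguard instrument every module under the fgsim
# package, so the type annotations added in this commit are checked at
# runtime. Note (assumed requirement): the hook has to be installed before the
# fgsim submodules are imported, hence its placement before main().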
def main():
# always reload the local modules
......
......@@ -60,19 +60,19 @@ def split_layer_subgraphs(batch):
)
# if not 0 in edge_index_inner.shape:
# logger.warn(
# logger.warning(
# (torch.min(edge_index_inner),
# torch.max(edge_index_inner),
# torch.sum(mask_input_inner),)
# )
# if not 0 in edge_index_forward.shape:
# logger.warn(
# logger.warning(
# (torch.min(edge_index_forward),
# torch.max(edge_index_forward),
# torch.sum(mask_input_forward),)
# )
# if not 0 in edge_index_backward.shape:
# logger.warn(
# logger.warning(
# (torch.min(edge_index_backward),
# torch.max(edge_index_backward),
# torch.sum(mask_input_backward),)
......
from typing import Dict, List, Union
import torch
import torch_geometric
BatchType = Union[torch_geometric.data.batch.Batch, Dict[str, torch.Tensor]]
DataSetType = List[BatchType]
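# A batch is either a torch_geometric Batch or a plain dict of tensors,
# depending on the loader; a dataset is simply a list of such batches.
# Minimal illustrative sketch (not part of the module):
#   batch: BatchType = {"x": torch.zeros(4, 1), "y": torch.zeros(4)}
#   dataset: DataSetType = [batch]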
"""
Conversion from a list of hits to a graph
"""
# import sys
import awkward as ak
import numpy as np
import pandas as pd
import torch
import torch_geometric
# import yappi
from torch_geometric.data import Data as GraphType
from fgsim.config import conf
# with uproot.open(conf.path.geo_lup) as rf:
# geo_lup = rf['analyzer/tree;1'].arrays(library='ak')
# geo_lup = ak.to_pandas(geo_lup)
# geo_lup.set_index('globalid', inplace=True)
# geo_lup.to_pickle('data/hgcal/DetIdLUT_full.pd')
geo_lup = pd.read_pickle("data/hgcal/DetIdLUT_full.pd")
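# The lookup table is indexed by the detector id ("globalid", see the
# commented-out construction above). Per cell it holds the static geometry
# features used below (x, y, z, celltype, issilicon) and the detids of the
# neighboring cells (presumably next/previous layer plus the in-plane
# neighbors n0..n11), from which the adjacency is built.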
def event_to_graph(event: ak.highlevel.Record) -> GraphType:
# yappi.start()
# plt.close()
# plt.cla()
# plt.clf()
# event_hitenergies = ak.to_numpy(event.rechit_energy)
# event_hitenergies = event_hitenergies[event_hitenergies < 0.04]
# plt.hist(event_hitenergies, bins=100)
# plt.savefig("hist.png")
# print("foo")
# Cut on the hit energy: keeps ~81% of the energy with ~56% of the hits @ 57 GeV
energyfilter = event["rechit_energy"] > 0.0048
num_nodes = sum(energyfilter)
event_hitenergies = ak.to_numpy(event["rechit_energy"][energyfilter])
event_detids = ak.to_numpy(event[conf.loader.id_key][energyfilter])
# print(
# (
# sum(event_hitenergies) / sum(event["rechit_energy"]),
# len(event_hitenergies) / len(event["rechit_energy"]),
# )
# )
# Filter out the rows/detids that are not in the event
geo_lup_filtered = geo_lup.loc[event_detids]
neighbor_keys = [
"next",
"previous",
"n0",
"n1",
"n2",
"n3",
"n4",
"n5",
"n6",
"n7",
"n8",
"n9",
"n10",
"n11",
]
neighbor_df = geo_lup_filtered[neighbor_keys]
static_feature_keys = [
"x",
"y",
"z",
"celltype",
"issilicon",
]
# compute static features
static_feature_matrix = geo_lup_filtered[static_feature_keys].values
# number the detids in the event
detid_lut = {det_id: i for i, det_id in enumerate(event_detids)}
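# detid_lut maps each detid surviving the energy cut to its node index
# (0 .. num_nodes - 1); it is applied below via np.vectorize to translate the
# detid-based edge list into node indices.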
# ## construct adj mtx from detids
# Select the detids from the current event
node_neighbors_mtx_list = [
np.vstack((np.repeat(detid, len(neighbor_keys)), neighbors_list.values))
for detid, neighbors_list in neighbor_df.iterrows()
]
edge_index_detid = np.hstack(node_neighbors_mtx_list)
# Filter out the zero entries:
edge_index_detid = edge_index_detid[:, edge_index_detid[1] != 0]
# Filter out neighbors that are not within the event
eventid_set = set(event_detids)
target_detid_in_event_mask = np.array(
[x in eventid_set for x in edge_index_detid[1]]
)
edge_index_detid = edge_index_detid[:, target_detid_in_event_mask]
# shift from detids to node numbers
edge_index = np.vectorize(detid_lut.get)(edge_index_detid)
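# edge_index now has shape (2, num_edges): row 0 is the source node of a hit
# and row 1 the index of its neighbor, the layout expected by
# torch_geometric.data.Data below.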
# yappi.stop()
# current_module = sys.modules[__name__]
# yappi.get_func_stats(
# filter_callback=lambda x: yappi.module_matches(x, [current_module])
# ).sort("name", "desc").print_all()
graph = torch_geometric.data.Data(
x=torch.tensor(event_hitenergies, dtype=torch.float32).reshape(
(num_nodes, 1)
),
y=event["gen_energy"],
edge_index=torch.tensor(edge_index, dtype=torch.long),
)
graph.static_features = torch.tensor(
np.asarray(static_feature_matrix, dtype=np.float32), dtype=torch.float32
).reshape((num_nodes, -1))
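# Sketch of the resulting graph: graph.x has shape (num_nodes, 1) with the hit
# energies, graph.y is the generated energy of the event, graph.edge_index has
# shape (2, num_edges), and graph.static_features has shape (num_nodes, 5)
# with the geometry columns selected above.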
return graph
......@@ -3,9 +3,13 @@ Here steps for reading the h5 files and processing the calorimeter \
images into graphs are defined. `process_seq` is the function that \
provides the steps passed to the qfseq.
"""
import os
from pathlib import Path
import h5py as h5
import numpy as np
import torch_geometric
import yaml
from torch.multiprocessing import Queue
from fgsim.config import conf
......@@ -14,6 +18,26 @@ from fgsim.geo.transform import transform
from . import qf
# Load files
ds_path = Path(conf.path.dataset)
assert ds_path.is_dir()
files = [str(e) for e in sorted(ds_path.glob("**/*.h5"))]
if len(files) < 1:
raise RuntimeError("No hdf5 datasets found")
# load lengths
if not os.path.isfile(conf.path.ds_lenghts):
len_dict = {}
for fn in files:
with h5.File(fn) as h5_file:
len_dict[fn] = len(h5_file[conf.yvar])
with open(conf.path.ds_lenghts, "w") as f:
yaml.dump(len_dict, f, Dumper=yaml.SafeDumper)
else:
with open(conf.path.ds_lenghts, "r") as f:
len_dict = yaml.load(f, Loader=yaml.SafeLoader)
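# The cached length file is a plain YAML mapping from file path to the number
# of events in that file, e.g. (illustrative paths and counts):
#   data/.../events_0.h5: 10000
#   data/.../events_1.h5: 10000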
# reading from the filesystem
def read_chunk(chunks):
......
......@@ -2,16 +2,38 @@
`process_seq` is the function that provides the steps passed to the qfseq."""
import math
import os
from pathlib import Path
import h5py as h5
import numpy as np
import torch
import yaml
from torch.multiprocessing import Queue
from fgsim.config import conf
from . import qf
# Load files
ds_path = Path(conf.path.dataset)
assert ds_path.is_dir()
files = [str(e) for e in sorted(ds_path.glob("**/*.h5"))]
if len(files) < 1:
raise RuntimeError("No hdf5 datasets found")
# load lengths
if not os.path.isfile(conf.path.ds_lenghts):
len_dict = {}
for fn in files:
with h5.File(fn) as h5_file:
len_dict[fn] = len(h5_file[conf.yvar])
with open(conf.path.ds_lenghts, "w") as f:
yaml.dump(len_dict, f, Dumper=yaml.SafeDumper)
else:
with open(conf.path.ds_lenghts, "r") as f:
len_dict = yaml.load(f, Loader=yaml.SafeLoader)
# reading from the filesystem
def read_chunk(chunks):
......
"""
Here steps for reading the root files and processing the hit list \
into graphs are defined. `process_seq` is the function that \
provides the steps passed to the qfseq.
"""
import os
from pathlib import Path
from typing import List, Tuple
import awkward as ak
import torch_geometric
import uproot
import yaml
from torch.multiprocessing import Queue
from torch_geometric.data import Data as GraphType
from fgsim.config import conf
from fgsim.geo.detid_to_graph import event_to_graph
from . import qf
# Load files
ds_path = Path(conf.path.dataset)
assert ds_path.is_dir()
files = [str(e) for e in sorted(ds_path.glob(conf.path.dataset_glob))]
if len(files) < 1:
raise RuntimeError("No root files found")
# load lengths
if not os.path.isfile(conf.path.ds_lenghts):
len_dict = {}
for fn in files:
with uproot.open(fn) as rfile:
len_dict[fn] = rfile[conf.loader.rootprefix][conf.yvar].num_entries
with open(conf.path.ds_lenghts, "w") as f:
yaml.dump(len_dict, f, Dumper=yaml.SafeDumper)
else:
with open(conf.path.ds_lenghts, "r") as f:
len_dict = yaml.load(f, Loader=yaml.SafeLoader)
# reading from the filesystem
def read_chunk(chunks: List[Tuple[str, int, int]]) -> ak.highlevel.Array:
chunks_list = []
for chunk in chunks:
file_path, start, end = chunk
with uproot.open(file_path) as rfile:
roottree = rfile[conf.loader.rootprefix]
chunks_list.append(
roottree.arrays(
conf.loader.keylist,
entry_start=start,
entry_stop=end,
library="ak",
)
)
# Concatenate the chunks into a single awkward array
output = ak.concatenate(chunks_list)
# remove the double gen energy
output["gen_energy"] = output["gen_energy"][:, 0]
return output
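# read_chunk thus returns a single awkward record array with the fields from
# conf.loader.keylist, concatenated over all requested (file, start, stop)
# ranges, with gen_energy reduced to one scalar per event.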
def geo_batch(list_of_graphs: List[GraphType]) -> GraphType:
batch = torch_geometric.data.Batch().from_data_list(list_of_graphs)
return batch
def magic_do_nothing(elem: GraphType) -> GraphType:
return elem
# Collect the steps
def process_seq():
return (
qf.ProcessStep(read_chunk, 2, name="read_chunk"),
# Queue(1),
# # The input is now [(x,y), ... (x [300 * 51 * 51 * 25], y [300,1] ), (x,y)]
# # For these elements to be processed by each of the workers in the following
# # transform, they need to be (x [51 * 51 * 25], y [1] ):
qf.PoolStep(
event_to_graph,
nworkers=conf.loader.num_workers_transform,
name="transform",
),
# Queue(1),
# qf.RepackStep(conf.loader.batch_size),
qf.ProcessStep(geo_batch, 1, name="geo_batch"),
# qf.ProcessStep(
# split_layer_subgraphs,
# conf.loader.num_workers_stack,
# name="split_layer_subgraphs",
# ),
# Needed for outputs to stay in order.
qf.ProcessStep(
magic_do_nothing,
1,
name="magic_do_nothing",
),
Queue(conf.loader.prefetch_batches),
)
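# Sketch of how the sequence is consumed (names taken from the QueuedDataLoader
# further below): the steps are unpacked into a qf.Sequence, chunks of
# (file, start, stop) tuples are queued, and the processed batches are read
# back by iterating over the sequence.
#   qfseq = qf.Sequence(*process_seq())
#   qfseq.queue_iterable(chunks)
#   batches = [batch for batch in qfseq]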
......@@ -23,7 +23,8 @@ pickling_support.install()
# Make it work ()
mp.set_sharing_strategy("file_descriptor")
# mp.set_sharing_strategy("file_descriptor")
mp.set_sharing_strategy("file_system")
# Reworked according to the recommendations in
# https://pytorch.org/docs/stable/multiprocessing.html
......
......@@ -139,7 +139,7 @@ class RepackStep(StepBase):
{self.workername} terminal element into output queue {id(self.outq)}."""
)
self.safe_put(self.outq, TerminateQueue())
logger.warn(
logger.warning(
f"""\
{self.workername} finished with iterable (in {self.count_in}/out {self.count_out})"""
)
......
......@@ -70,7 +70,7 @@ class PoolStep(StepBase):
if isinstance(wkin, TerminateQueue):
logger.info(f"{self.workername} terminating")
self.safe_put(self.outq, TerminateQueue())
logger.warn(
logger.warning(
f"""\
{self.workername} finished with iterable (in {self.count_in}/out {self.count_out})"""
)
......@@ -102,12 +102,12 @@ class PoolStep(StepBase):
break
except Exception as error:
logger.warn(f"""{self.workername} got error""")
logger.warning(f"""{self.workername} got error""")
self.handle_error(error, wkin)
break
# if wkin_iter.count > 200:
# logger.warn(
# logger.warning(
# f"""\
# Giving large iterables ({wkin_iter.count})\
# to a worker can lead to crashes.
......
......@@ -144,7 +144,7 @@ class Sequence:
def stop(self):
logger.info("Before Sequence Stop\n" + str(self.flowstatus()))
logger.warn("Setting shutdown event!")
logger.warning("Setting shutdown event!")
self.shutdown_event.set()
......
......@@ -5,42 +5,22 @@ loaded depending on `conf.loader.name`.
import importlib
import os
from pathlib import Path
import h5py as h5
import numpy as np
import torch
import yaml
from fgsim.config import conf
from fgsim.geo.batchtype import DataSetType
from fgsim.io import qf
from fgsim.io.qf.sequence import Sequence as qfseq
from fgsim.utils.logger import logger
from . import qf
# Import the specified processing sequence
process_seq = importlib.import_module(
f"fgsim.io.{conf.loader.name}", "fgsim.models"
).process_seq
ds_path = Path(conf.path.dataset)
assert ds_path.is_dir()
files = [str(e) for e in sorted(ds_path.glob("**/*.h5"))]
if len(files) < 1:
raise RuntimeError("No hdf5 datasets found")
sel_seq = importlib.import_module(f"fgsim.io.{conf.loader.name}", "fgsim.models")
# load lengths
if not os.path.isfile(conf.path.ds_lenghts):
len_dict = {}
for fn in files:
with h5.File(fn) as h5_file:
len_dict[fn] = len(h5_file[conf.yvar])
with open(conf.path.ds_lenghts, "w") as f:
yaml.dump(len_dict, f, Dumper=yaml.SafeDumper)
else:
with open(conf.path.ds_lenghts, "r") as f:
len_dict = yaml.load(f, Loader=yaml.SafeLoader)
process_seq = sel_seq.process_seq
files = sel_seq.files
len_dict = sel_seq.len_dict
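# With this commit the loader module selected via conf.loader.name owns the
# file discovery, the cached event lengths and the processing sequence; this
# module only re-exports them, so each dataset format can define its own
# read and transform steps.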
class QueuedDataLoader:
......@@ -63,13 +43,13 @@ must queue an epoch via `queue_epoch()` and iterate over the instance of the cla
elem_to_add = chunksize - current_chunck_elements
if elem_left_in_cur_file > elem_to_add:
chunk_coords[-1].append(
[files[ifile], ielement, ielement + elem_to_add]
(files[ifile], ielement, ielement + elem_to_add)
)
ielement += elem_to_add
current_chunck_elements += elem_to_add
else:
chunk_coords[-1].append(
[files[ifile], ielement, ielement + elem_left_in_cur_file]
(files[ifile], ielement, ielement + elem_left_in_cur_file)
)
ielement = 0
current_chunck_elements += elem_left_in_cur_file
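# Each chunk is a list of (file, start, stop) tuples; tuples (rather than the
# previous lists) match the List[Tuple[str, int, int]] signature that
# read_chunk now expects.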
......@@ -113,40 +93,49 @@ must queue an epoch via `queue_epoch()` and iterate over the instance of the cla
n_validation_chunks + n_testing_chunks :
]
# Check that there is a reasonable amount of data
assert (
len(self.validation_chunks) + len(self.testing_chunks)
> len(self.training_chunks) / 2
)
# Assign the sequence with the specific steps needed to process the dataset.
self.qfseq = qf.Sequence(*process_seq())
if not os.path.isfile(conf.path.validation):
logger.warn(
logger.warning(
f"""\
Processing validation batches, queuing {len(self.validation_chunks)} chunks."""
)
self.qfseq.queue_iterable(self.validation_chunks)
self._validation_batches = [batch for batch in self.qfseq]
torch.save(self._validation_batches, conf.path.validation)
logger.warn("Validation batches pickled.")
logger.warning("Validation batches pickled.")
if not os.path.isfile(conf.path.test):
logger.warn(
logger.warning(
f"""\
Processing testing batches, queuing {len(self.testing_chunks)} chunks."""
)
self.qfseq.queue_iterable(self.testing_chunks)
self._testing_batches = [batch for batch in self.qfseq]
torch.save(self._testing_batches, conf.path.test)
logger.warn("Testing batches pickled.")
logger.warning("Testing batches pickled.")
@property
def validation_batches(self):
def validation_batches(self) -> DataSetType:
if not hasattr(self, "_validation_batches"):
logger.warning("Validation batches not loaded, loading from disk.")
self._validation_batches = torch.load(
conf.path.validation, map_location=torch.device("cpu")
)
logger.warning("Finished loading.")
logger.warning(
f"Finished loading. Type is{type(self._validation_batches)}"
)
return self._validation_batches
@property
def testing_batches(self):
def testing_batches(self) -> DataSetType:
if not hasattr(self, "_testing_batches"):
logger.warning("Testing batches not loaded, loading from disk.")
self._testing_batches = torch.load(
......@@ -155,7 +144,7 @@ Processing testing batches, queuing {len(self.validation_chunks)} chunks."""
logger.warning("Finished loading.")
return self._testing_batches
def queue_epoch(self, n_skip_events=0):
def queue_epoch(self, n_skip_events=0) -> None:
n_skip_chunks = n_skip_events // conf.loader.chunksize
# Cycle Epochs
n_skip_chunks = n_skip_chunks % len(self.training_chunks)
......@@ -177,5 +166,5 @@ Skipping {n_skip_events} events => {n_skip_chunks} chunks and {n_skip_batches} b
_ = next(self.qfseq)
logger.info(f"Skipped {n_skip_batches} batches.")
def __iter__(self):
def __iter__(self) -> qfseq:
return iter(self.qfseq)
......@@ -28,7 +28,7 @@ validation steps: {relative_improvement*100}%"""
train_state.holder.save_checkpoint()
train_state.writer.flush()
train_state.writer.close()
logger.warn("Early Stopping criteria fulfilled")
logger.warning("Early Stopping criteria fulfilled")
OmegaConf.save(train_state.state, conf.path.complete_state)
if not conf.debug:
train_state.experiment.log_other("ended", True)
......
......@@ -57,7 +57,7 @@ class ModelHolder:
def __load_checkpoint(self):
if not os.path.isfile(conf.path.checkpoint):
logger.warn("Proceeding without loading checkpoint.")
logger.warning("Proceeding without loading checkpoint.")
return
checkpoint = torch.load(conf.path.checkpoint, map_location=device)
......@@ -70,7 +70,7 @@ class ModelHolder:
self.optim.load_state_dict(checkpoint["optim"])
self.best_model_state = checkpoint["best_model"]
logger.warn(
logger.warning(
"Loading model from checkpoint at"
+ f" epoch {self.state['epoch']}"
+ f" batch {self.state['ibatch']}"
......
......@@ -11,7 +11,7 @@ from .holder import model_holder
def profile_procedure() -> None:
logger.warn(
logger.warning(
"Starting profiling with state\n" + OmegaConf.to_yaml(model_holder.state)
)
model_holder.writer = setup_writer()
......
import time
from typing import Dict, Union
import torch
import torch_geometric
......@@ -18,7 +19,8 @@ from .validate import validate
def training_step(
batch: torch_geometric.data.Batch, train_state: TrainState
batch: Union[torch_geometric.data.Batch, Dict[str, torch.Tensor]],
train_state: TrainState,
) -> None:
train_state.holder.optim.zero_grad()
output = train_state.holder.model(batch)
......@@ -36,7 +38,7 @@ def training_step(
def training_procedure() -> None:
logger.warn(
logger.warning(
"Starting training with state\n" + OmegaConf.to_yaml(model_holder.state)
)
train_state = TrainState(
......
import torch
from torch import nn
from fgsim.config import conf