Commit b90e81c6 authored by mova

implement hgcal dataloader and training

parent 24434b27
......@@ -11,11 +11,11 @@ from .utils.cli import args
# Add a custom resolver to OmegaConf allowing for division
# Return an int if the division is exact:
-def divide(a, b):
-    if a // b == a / b:
-        return a // b
+def divide(numerator, denominator):
+    if numerator // denominator == numerator / denominator:
+        return numerator // denominator
    else:
-        return a / b
+        return numerator / denominator
OmegaConf.register_new_resolver("div", divide, replace=True)
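For reference, a minimal usage sketch of the resolver registered above (the config values are made up for illustration):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"batches": "${div:10,2}", "ratio": "${div:10,4}"})
    print(cfg.batches)  # 5   -> exact division yields an int
    print(cfg.ratio)    # 2.5 -> otherwise a float is returned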
......@@ -61,9 +61,8 @@ hyperparameters = OmegaConf.masked_copy(
OmegaConf.resolve(hyperparameters)
# Compute the hash
-conf_hash = str(hashlib.sha1(str(hyperparameters).encode()).hexdigest()[:7])
-conf["hash"] = conf_hash
-hyperparameters["hash"] = conf_hash
+conf["hash"] = str(hashlib.sha1(str(hyperparameters).encode()).hexdigest()[:7])
+hyperparameters["hash"] = conf["hash"]
os.makedirs(conf.path.run_path, exist_ok=True)
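The hash ties the run directory to the exact hyperparameter set: it is simply the first seven hex digits of a SHA-1 over the stringified, resolved config. A self-contained sketch of the same computation (toy config, not the project's):

    import hashlib
    from omegaconf import OmegaConf

    hp = OmegaConf.create({"model": {"dyn_features": 16}, "loader": {"batch_size": 32}})
    OmegaConf.resolve(hp)
    print(hashlib.sha1(str(hp).encode()).hexdigest()[:7])  # stable for identical configs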
......
......@@ -4,24 +4,30 @@ Conversion from list of hit to graph
# import sys
from os.path import isfile, splitext
import awkward as ak
import numpy as np
import pandas as pd
import torch
import torch_geometric
import uproot
# import yappi
from torch_geometric.data import Data as GraphType
from fgsim.config import conf
from fgsim.utils.logger import logger
# with uproot.open(conf.path.geo_lup) as rf:
# geo_lup = rf['analyzer/tree;1'].arrays(library='ak')
# geo_lup = ak.to_pandas(geo_lup)
# geo_lup.set_index('globalid', inplace=True)
# geo_lup.to_pickle('data/hgcal/DetIdLUT_full.pd')
+pickle_lup_path = splitext(conf.path.geo_lup)[0] + ".pd"
+if not isfile(pickle_lup_path):
+    with uproot.open(conf.path.geo_lup) as rf:
+        geo_lup = rf["analyzer/tree;1"].arrays(library="ak")
+        geo_lup = ak.to_pandas(geo_lup)
+        geo_lup.set_index("globalid", inplace=True)
+        geo_lup.to_pickle(pickle_lup_path)
-geo_lup = pd.read_pickle("data/hgcal/DetIdLUT_full.pd")
+geo_lup = pd.read_pickle(pickle_lup_path)
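Once cached, geo_lup is a pandas DataFrame indexed by globalid, and the rest of the module only reads rows from it. A sketch of the access pattern (the detid value is hypothetical):

    row = geo_lup.loc[2156920832]             # hypothetical detid
    neighbor_ids = row[["next", "previous"]]  # adjacent-cell detids used for edges
    cell_props = row[["x", "y", "z"]]         # static cell properties used as node features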
def event_to_graph(event: ak.highlevel.Record) -> GraphType:
......@@ -35,21 +41,36 @@ def event_to_graph(event: ak.highlevel.Record) -> GraphType:
    # plt.hist(event_hitenergies, bins=100)
    # plt.savefig("hist.png")
    # print("foo")
+    key_id = conf.loader.braches.id
+    key_hit_energy = conf.loader.braches.hit_energy
+    # Sum up the sim hits
+    id_to_energy_dict = {}
+    for hit_energy, detid in zip(event[key_hit_energy], event[key_id]):
+        # TODO fix the detids
+        if detid not in geo_lup.index:
+            continue
+        if detid in id_to_energy_dict:
+            id_to_energy_dict[detid] += hit_energy
+        else:
+            id_to_energy_dict[detid] = hit_energy
+    detids = np.array(list(id_to_energy_dict.keys()), dtype=np.uint)
+    hit_energies = np.array(list(id_to_energy_dict.values()), dtype=np.float32)
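    # A vectorized equivalent of the accumulation above (a sketch; assumes
    # `ids` and `energies` are the raw per-hit event arrays as numpy):
    #     valid = np.isin(ids, geo_lup.index.to_numpy())
    #     detids, inverse = np.unique(ids[valid], return_inverse=True)
    #     hit_energies = np.bincount(inverse, weights=energies[valid]).astype(np.float32)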
+    # assert np.all(hit_energies == [id_to_energy_dict[x] for x in detids])
    # cut the hits: 81% of the energy in 56% of the hits @ 57 GeV
-    energyfilter = event["rechit_energy"] > 0.0048
+    energyfilter = hit_energies > 0  # 0.0048
    num_nodes = sum(energyfilter)
-    event_hitenergies = ak.to_numpy(event["rechit_energy"][energyfilter])
-    event_detids = ak.to_numpy(event[conf.loader.id_key][energyfilter])
-    # print(
-    #     (
-    #         sum(event_hitenergies) / sum(event["rechit_energy"]),
-    #         len(event_hitenergies) / len(event["rechit_energy"]),
-    #     )
-    # )
+    hit_energies = hit_energies[energyfilter]
+    detids = detids[energyfilter]
+    # _, counts = np.unique(ak.to_numpy(detids), return_counts=True)
+    # assert all(counts == 1)
    # Filter out the rows/detids that are not in the event
-    geo_lup_filtered = geo_lup.loc[event_detids]
+    geo_lup_filtered = geo_lup.reindex(index=detids)
    neighbor_keys = [
        "next",
        "previous",
......@@ -68,18 +89,19 @@ def event_to_graph(event: ak.highlevel.Record) -> GraphType:
    ]
    neighbor_df = geo_lup_filtered[neighbor_keys]
-    static_feature_keys = [
-        "x",
-        "y",
-        "z",
-        "celltype",
-        "issilicon",
-    ]
    # compute static features
-    static_feature_matrix = geo_lup_filtered[static_feature_keys].values
+    static_feature_matrix = geo_lup_filtered[conf.loader.cell_prop_keys].values
+    # check for NaNs in case a detid is not present in the geo LUT
+    df = geo_lup_filtered[conf.loader.cell_prop_keys]
+    # problemid = df[df.isnull().any(axis=1)].index[0]  # 2160231891
+    invalids = df[df.isnull().any(axis=1)].index
+    if len(invalids) != 0:
+        logger.error(f"No match in geo lut for detids {invalids}.")
+        raise ValueError
    # number the detids in the event
-    detid_lut = {det_id: i for i, det_id in enumerate(event_detids)}
+    detid_lut = {det_id: i for i, det_id in enumerate(detids)}
    # ## construct adj mtx from detids
    # Select the detids from the current event
......@@ -93,14 +115,35 @@ def event_to_graph(event: ak.highlevel.Record) -> GraphType:
    # Filter out the zero entries:
    edge_index_detid = edge_index_detid[:, edge_index_detid[1] != 0]
    # Filter out neighbors that are not within the array
-    eventid_set = set(event_detids)
+    eventid_set = set(detids)
    target_detid_in_event_mask = np.array(
        [x in eventid_set for x in edge_index_detid[1]]
    )
    edge_index_detid = edge_index_detid[:, target_detid_in_event_mask]
    # shift from detids to node numbers
-    edge_index = np.vectorize(detid_lut.get)(edge_index_detid)
+    if edge_index_detid.size != 0:
+        edge_index = np.vectorize(detid_lut.get)(edge_index_detid)
+    else:
+        edge_index = edge_index_detid
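    # The guard above is needed because np.vectorize cannot infer an output
    # dtype from size-0 input; passing `otypes` explicitly would be an
    # equivalent fix (sketch):
    #     edge_index = np.vectorize(detid_lut.get, otypes=[np.int64])(edge_index_detid)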
+    # Collect the hlvs now:
+    hlvs = {}
+    hlvs["sum_energy"] = sum(hit_energies)
+    detid_isolated = set(detids) - set(np.unique(edge_index))
+    hlvs["num_isolated"] = len(detid_isolated)
+    hlvs["isolated_energy"] = sum(
+        [hit_energies[detid_lut[detid]] for detid in detid_isolated]
+    )
+    hlvs["isolated_E_fraction"] = hlvs["isolated_energy"] / hlvs["sum_energy"]
+    for ivar, var in enumerate(["x", "y", "z"]):
+        # the first three cell_prop_keys are assumed to be x, y, z
+        var_weighted = static_feature_matrix[:, ivar] * hit_energies
+        mean = np.mean(var_weighted)
+        hlvs[var + "_mean"] = mean
+        hlvs[var + "_std"] = np.std(var_weighted)
+        hlvs[var + "_mom3"] = np.power(np.sum(np.power(var_weighted - mean, 3)), 1 / 3)
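    # Note: np.mean(var_weighted) is the mean of the energy-weighted coordinate,
    # not an energy-weighted mean; np.average(static_feature_matrix[:, ivar],
    # weights=hit_energies) would give the latter if that is the intent, and
    # np.cbrt would preserve the sign of the third moment where
    # np.power(..., 1 / 3) returns NaN for a negative sum.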
    # yappi.stop()
    # current_module = sys.modules[__name__]
......@@ -108,15 +151,19 @@ def event_to_graph(event: ak.highlevel.Record) -> GraphType:
    #     filter_callback=lambda x: yappi.module_matches(x, [current_module])
    # ).sort("name", "desc").print_all()
    # Build the graph
    graph = torch_geometric.data.Data(
-        x=torch.tensor(event_hitenergies, dtype=torch.float32).reshape(
-            (num_nodes, 1)
-        ),
-        y=event["gen_energy"],
+        x=torch.tensor(hit_energies, dtype=torch.float).reshape((num_nodes, 1)),
+        y=torch.tensor(event[conf.loader.braches.energy][0], dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long),
    )
-    graph.static_features = torch.tensor(
-        np.asarray(static_feature_matrix, dtype=np.float32), dtype=torch.float32
+    graph.feature_mtx_static = torch.tensor(
+        np.asarray(static_feature_matrix, dtype=np.float32), dtype=torch.float
    ).reshape((num_nodes, -1))
+    graph.hlvs = torch.tensor(
+        [hlvs[k] for k in conf.loader.hlvs], dtype=torch.float
+    ).reshape((1, -1))
    return graph
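Downstream, graphs produced this way can be collated with the standard torch_geometric batching; a minimal sketch (the events iterable is hypothetical, e.g. the output of read_chunk):

    from torch_geometric.data import Batch

    graphs = [event_to_graph(event) for event in events]
    batch = Batch.from_data_list(graphs)  # also collates feature_mtx_static and hlvs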
......@@ -32,7 +32,7 @@ if not os.path.isfile(conf.path.ds_lenghts):
    len_dict = {}
    for fn in files:
        with uproot.open(fn) as rfile:
-            len_dict[fn] = rfile[conf.loader.rootprefix][conf.yvar].num_entries
+            len_dict[fn] = rfile[conf.loader.rootprefix].num_entries
    with open(conf.path.ds_lenghts, "w") as f:
        yaml.dump(len_dict, f, Dumper=yaml.SafeDumper)
else:
......@@ -60,7 +60,6 @@ def read_chunk(chunks: List[Tuple[str, int, int]]) -> ak.highlevel.Array:
    output = ak.concatenate(chunks_list)
    # remove the double gen energy
    output["gen_energy"] = output["gen_energy"][:, 0]
    return output
......
......@@ -96,8 +96,8 @@ must queue an epoch via `queue_epoch()` and iterate over the instance of the cla
        # Check that there is a reasonable amount of data
        assert (
            len(self.validation_chunks) + len(self.testing_chunks)
-            > len(self.training_chunks) / 2
-        )
+            < len(self.training_chunks) / 2
+        ), "Dataset too small"
        # Assign the sequence with the specific steps needed to process the dataset.
        self.qfseq = qf.Sequence(*process_seq())
......
......@@ -43,7 +43,7 @@ def prediction_procedure():
    with torch.no_grad():
        prediction = torch.squeeze(train_state.holder.model(batch).T)
    yhat = prediction.to("cpu").numpy()
-    y = batch[conf.yvar].to("cpu").numpy()
+    y = batch.y.to("cpu").numpy()
    yhats.append(yhat)
    ys.append(y)
......
......@@ -6,6 +6,7 @@ from comet_ml.experiment import BaseExperiment
from torch.utils.tensorboard import SummaryWriter
from fgsim.config import conf
+from fgsim.geo.batchtype import BatchType
from fgsim.io.queued_dataset import QueuedDataLoader
from .holder import ModelHolder
......@@ -22,7 +23,7 @@ class TrainState:
    writer: Optional[SummaryWriter]
    experiment: Optional[BaseExperiment]
-    def write_trainstep_logs(self) -> None:
+    def write_trainstep_logs(self, batch: BatchType) -> None:
        if conf.debug:
            return
        traintime = self.state.time_training_done - self.state.batch_start_time
......@@ -73,3 +74,7 @@ class TrainState:
            step=self.state["grad_step"],
            epoch=self.state["epoch"],
        )
+        # self.experiment.log_histogram(
+        #     experiment, gradmap, epoch * steps_per_epoch, prefix="gradient"
+        # )
......@@ -28,9 +28,9 @@ def training_step(
    prediction = torch.squeeze(output.T)
    # Check for the global_add_pool bug in pytorch_geometric
    # https://github.com/rusty1s/pytorch_geometric/issues/2895
-    if len(prediction) != len(batch[conf.yvar]):
+    if len(prediction) != len(batch.y):
        return
-    loss = train_state.holder.lossf(y=batch[conf.yvar], yhat=prediction)
+    loss = train_state.holder.lossf(y=batch.y, yhat=prediction)
    loss.backward()
    train_state.holder.optim.step()
......@@ -85,7 +85,7 @@ def training_procedure() -> None:
        train_state.state.time_io_done = time.time()
        training_step(batch, train_state)
        train_state.state.time_training_done = time.time()
-        train_state.write_trainstep_logs()
+        train_state.write_trainstep_logs(batch)
        train_state.state.ibatch += 1
        train_state.state.processed_events += conf.loader.batch_size
        train_state.state["grad_step"] += 1
......
......@@ -21,7 +21,7 @@ def validate(train_state: TrainState) -> None:
        batch_gpu = move_batch_to_device(batch, device)
        with torch.no_grad():
            prediction = torch.squeeze(train_state.holder.model(batch_gpu).T)
-            loss = train_state.holder.lossf(y=batch_gpu[conf.yvar], yhat=prediction)
+            loss = train_state.holder.lossf(y=batch_gpu.y, yhat=prediction)
        losses.append(loss)
        del batch_gpu
......
import torch
from torch import nn
from torch_geometric.nn import GINConv, global_add_pool
-from fgsim.config import conf
+from fgsim.config import conf, device
from fgsim.utils.cuda_clear import cuda_clear
-nfeatures = conf.model.dyn_features + conf.model.static_features
+n_dyn = conf.model.dyn_features
+n_hlvs = len(conf.loader.hlvs)
+n_node_features = len(conf.loader.cell_prop_keys)
+n_all_features = n_dyn + n_hlvs + n_node_features
def get_hlv_dnn():
    return nn.Sequential(
        nn.Linear(n_hlvs + n_dyn, conf.model.deeplayer_nodes),
        nn.ReLU(),
        nn.Linear(conf.model.deeplayer_nodes, conf.model.deeplayer_nodes),
        nn.ReLU(),
        nn.Linear(conf.model.deeplayer_nodes, conf.model.deeplayer_nodes),
        nn.ReLU(),
        nn.Linear(conf.model.deeplayer_nodes, 1),
        nn.ReLU(),
    )


def get_node_dnn():
    return nn.Sequential(
        nn.Linear(n_all_features, conf.model.deeplayer_nodes),
        nn.ReLU(),
        nn.Linear(conf.model.deeplayer_nodes, conf.model.deeplayer_nodes),
        nn.ReLU(),
        nn.Linear(conf.model.deeplayer_nodes, conf.model.deeplayer_nodes),
        nn.ReLU(),
        nn.Linear(conf.model.deeplayer_nodes, conf.model.dyn_features),
        nn.ReLU(),
    )


def get_conv():
    conv_dnn = get_node_dnn()
    return GINConv(conv_dnn, train_eps=True)
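# GINConv uses the supplied MLP as the GIN update function,
#     h_i' = MLP((1 + eps) * h_i + sum_{j in N(i)} h_j),
# with eps trainable here since train_eps=True.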
class ModelClass(torch.nn.Module):
    def __init__(self):
        super(ModelClass, self).__init__()
-        self.end_lin = nn.Linear(2, 1)
+        self.conv = get_conv()
+        self.node_dnn = get_node_dnn()
+        self.hlv_dnn = get_hlv_dnn()

    def forward(self, batch):
-        X = torch.vstack([batch["ECAL_E"], batch["HCAL_E"]]).float().T
-        X = self.end_lin(X)
-        return X
+        def addstatic(
+            x, mask=torch.ones(len(batch.x), dtype=torch.bool, device=device)
+        ):
+            return torch.hstack(
+                (
+                    x[mask],
+                    batch.feature_mtx_static[mask],
+                    batch.hlvs[batch.batch[mask]],
+                )
+            )
+
+        x = torch.hstack(
+            (
+                batch.x,
+                torch.zeros(
+                    (len(batch.x), conf.model.dyn_features - 1), device=device
+                ),
+            )
+        )
+        for _ in range(conf.model.nprop):
+            x = self.conv(addstatic(x), batch.edge_index)
+            x = self.node_dnn(addstatic(x))
+            cuda_clear()
+        x = global_add_pool(x, batch.batch, size=batch.num_graphs)
+        x = self.hlv_dnn(torch.hstack((batch.hlvs, x)))
+        return x
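A shape-level smoke test for the forward pass (a sketch: the attribute shapes follow what event_to_graph builds, the concrete widths depend on conf.model and conf.loader, and device is assumed to be the CPU):

    from torch_geometric.data import Batch, Data

    g = Data(
        x=torch.rand(5, 1),  # one dynamic feature (the hit energy) per node
        edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]], dtype=torch.long),
        y=torch.tensor(57.0),
    )
    g.feature_mtx_static = torch.rand(5, n_node_features)
    g.hlvs = torch.rand(1, n_hlvs)
    out = ModelClass()(Batch.from_data_list([g]))  # -> shape (1, 1)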
......@@ -35,11 +35,17 @@ def get_experiment():
    -- if unsuccessful -- generates a new one."""
    comet_conf = OmegaConf.load("fgsim/comet.yaml")
    api = comet_ml.API(comet_conf.api_key)
+    project_name = (
+        conf.comet_project_name
+        if "comet_project_name" in conf
+        else comet_conf.project_name
+    )
    experiments = [
        exp
        for exp in api.get(
-            workspace=comet_conf.workspace, project_name=comet_conf.project_name
+            workspace=comet_conf.workspace,
+            project_name=project_name,
        )
        if exp.get_parameters_summary("hash") != []
    ]
......@@ -54,7 +60,7 @@ def get_experiment():
raise ValueError("Experiment does not exist in comet.ml!")
logger.warning("Creating new experiment.")
new_api_exp = api._create_experiment(
workspace=comet_conf.workspace, project_name=comet_conf.project_name
workspace=comet_conf.workspace, project_name=project_name
)
exp_key = new_api_exp.id
elif len(qres) == 1:
......
from typing import Dict
+import numpy as np
import torch
import torch_geometric
......@@ -8,17 +9,22 @@ from fgsim.utils.typecheck import istype
def move_batch_to_device(batch, device):
    """This function moves batches (e.g. from torch_geometric) to a specified
    device and also takes into account manually assigned properties."""

    def move(element):
        if torch.is_tensor(element):
            return element.to(device)
-        elif isinstance(element, list):
-            return [move(ee) for ee in element]
+        elif isinstance(element, (list, set, tuple)):
+            return type(element)(move(ee) for ee in element)
        elif isinstance(element, dict):
            return {k: move(ee) for k, ee in element.items()}
        elif element is None:
            return None
+        elif isinstance(element, (int, str, float)):
+            return element
+        elif type(element).__module__ == np.__name__:
+            return element
        else:
            raise ValueError
......@@ -38,19 +44,23 @@ def move_batch_to_device(batch, device):
    elif istype(batch, Dict[str, torch.Tensor]):
        batch_new = {k: move(v) for k, v in batch.items()}
    else:
-        raise RuntimeError("Cannot move this object to the torch device.")
+        raise RuntimeError(
+            "Cannot move this object to the torch device: invalid type."
+        )
    return batch_new
def clone_or_copy(e):
    if torch.is_tensor(e):
        return e.clone()
-    elif isinstance(e, list):
-        return [clone_or_copy(ee) for ee in e]
+    elif isinstance(e, (list, set, tuple)):
+        return type(e)(clone_or_copy(ee) for ee in e)
    elif isinstance(e, dict):
        return {k: clone_or_copy(ee) for k, ee in e.items()}
    elif isinstance(e, (int, str, float)):
        return e
+    elif type(e).__module__ == np.__name__:
+        return e
    elif e is None:
        return None
    else:
......@@ -58,6 +68,9 @@ def clone_or_copy(e):
def clone_batch(batch):
    """This function clones batches (e.g. from torch_geometric) and also takes
    into account manually assigned properties. This is needed when using
    torch_geometric with torch.multiprocessing."""
    batch_cloned = torch_geometric.data.Batch().from_dict(
        {k: clone_or_copy(v) for k, v in batch.to_dict().items()}
    )
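A quick illustration of the recursive handling (a sketch; the plain-dict input exercises the Dict[str, torch.Tensor] branch above):

    batch = {"x": torch.zeros(2, 3), "y": torch.ones(2)}
    moved = move_batch_to_device(batch, torch.device("cpu"))
    assert all(t.device.type == "cpu" for t in moved.values())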
......