Commit b095755b authored by mova

training on sparse matrices

parent 1d50d1a5
@@ -5,6 +5,7 @@ Conversion from list of hit to graph
# import sys
from os.path import isfile, splitext
from typing import Dict
import awkward as ak
import numpy as np
@@ -31,21 +32,12 @@ geo_lup = pd.read_pickle(pickle_lup_path)
def event_to_graph(event: ak.highlevel.Record) -> GraphType:
# yappi.start()
# plt.close()
# plt.cla()
# plt.clf()
# event_hitenergies = ak.to_numpy(event.rechit_energy)
# event_hitenergies = event_hitenergies[event_hitenergies < 0.04]
# plt.hist(event_hitenergies, bins=100)
# plt.savefig("hist.png")
# print("foo")
key_id = conf.loader.braches.id
key_hit_energy = conf.loader.braches.hit_energy
# Sum up the sim hits
id_to_energy_dict = {}
id_to_energy_dict: Dict[int, float] = {}
for hit_energy, detid in zip(event[key_hit_energy], event[key_id]):
# TODO fix the detids
if detid not in geo_lup.index:
@@ -94,9 +86,9 @@ def event_to_graph(event: ak.highlevel.Record) -> GraphType:
static_feature_matrix = geo_lup_filtered[conf.loader.cell_prop_keys].values
# check for NaNs if the detid is not present in the geolut
df = geo_lup_filtered[conf.loader.cell_prop_keys]
props_df = geo_lup_filtered[conf.loader.cell_prop_keys]
# problemid= df[df.isnull().any(axis =1)].index[0] # 2160231891
invalids = df[df.isnull().any(axis=1)].index
invalids = props_df[props_df.isnull().any(axis=1)].index
if len(invalids) != 0:
logger.error(f"No match in geo lut for detids {invalids}.")
raise ValueError
@@ -136,7 +128,7 @@ def event_to_graph(event: ak.highlevel.Record) -> GraphType:
[hit_energies[detid_lut[detid]] for detid in detid_isolated]
)
hlvs["isolated_E_fraction"] = hlvs["isolated_energy"] / hlvs["sum_energy"]
for ivar, var in enumerate(["x", "y", "z"]):
for var in ["x", "y", "z"]:
var_weighted = static_feature_matrix[:, 1] * hit_energies
mean = np.mean(var_weighted)
hlvs[var + "_mean"] = mean
......
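For reference, a minimal pandas sketch of the NaN check that this hunk renames df to props_df for. The tiny geo_lup table, the reindex call, and the column names are hypothetical stand-ins for the real geometry lookup and conf.loader.cell_prop_keys:

import pandas as pd

# Hypothetical geometry lookup table indexed by detid.
geo_lup = pd.DataFrame(
    {"x": [0.0, 1.0], "y": [0.0, 2.0]}, index=[2160231890, 2160231891]
)
# Selecting a detid that is missing from the lookup yields an all-NaN row.
geo_lup_filtered = geo_lup.reindex([2160231891, 9999])
props_df = geo_lup_filtered[["x", "y"]]
# Same idiom as in event_to_graph: index labels of rows containing any NaN.
invalids = props_df[props_df.isnull().any(axis=1)].index
assert list(invalids) == [9999]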
@@ -16,8 +16,7 @@ from torch_geometric.data import Data as GraphType
from fgsim.config import conf
from fgsim.geo.detid_to_graph import event_to_graph
from . import qf
from fgsim.io import qf
# Load files
ds_path = Path(conf.path.dataset)
@@ -68,8 +67,18 @@ def geo_batch(list_of_graphs: List[GraphType]) -> GraphType:
return batch
def magic_do_nothing(elem: GraphType) -> GraphType:
return elem
ToSparseTranformer = torch_geometric.transforms.ToSparseTensor(
remove_edge_index=True, fill_cache=True
)
def add_sparse_adj_mtx(batch: GraphType) -> GraphType:
batch = ToSparseTranformer(batch)
return batch
def magic_do_nothing(batch: GraphType) -> GraphType:
return batch
# Collect the steps
@@ -93,6 +102,7 @@ def process_seq():
# conf.loader.num_workers_stack,
# name="split_layer_subgraphs",
# ),
qf.ProcessStep(add_sparse_adj_mtx, 1, name="add_sparse_adj_mtx"),
# Needed for outputs to stay in order.
qf.ProcessStep(
magic_do_nothing,
......
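A minimal sketch of what the new add_sparse_adj_mtx step does to a single graph; the two-node Data object is a hypothetical stand-in for a real event graph:

import torch
import torch_geometric
from torch_geometric.data import Data

# Hypothetical graph with two nodes and one directed edge 0 -> 1.
graph = Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0], [1]]))
transform = torch_geometric.transforms.ToSparseTensor(
    remove_edge_index=True, fill_cache=True
)
graph = transform(graph)
# edge_index has been removed; the transposed adjacency is now available
# as a torch_sparse.SparseTensor under graph.adj_t.
print(graph.adj_t)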
"""Message pass model"""
import torch
from torch import nn
from torch_geometric.nn import GINConv, global_add_pool
@@ -52,17 +54,18 @@ class ModelClass(torch.nn.Module):
def forward(self, batch):
def addstatic(
x, mask=torch.ones(len(batch.x), dtype=torch.bool, device=device)
feature_mtx,
mask=torch.ones(len(batch.x), dtype=torch.bool, device=device),
):
return torch.hstack(
(
x[mask],
feature_mtx[mask],
batch.feature_mtx_static[mask],
batch.hlvs[batch.batch[mask]],
)
)
x = torch.hstack(
feature_mtx = torch.hstack(
(
batch.x,
torch.zeros(
@@ -72,11 +75,13 @@ )
)
for _ in range(conf.model.nprop):
x = self.conv(addstatic(x), batch.edge_index)
x = self.node_dnn(addstatic(x))
feature_mtx = self.conv(addstatic(feature_mtx), batch.adj_t)
feature_mtx = self.node_dnn(addstatic(feature_mtx))
cuda_clear()
x = global_add_pool(x, batch.batch, size=batch.num_graphs)
feature_mtx = global_add_pool(
feature_mtx, batch.batch, size=batch.num_graphs
)
x = self.hlv_dnn(torch.hstack((batch.hlvs, x)))
return x
feature_mtx = self.hlv_dnn(torch.hstack((batch.hlvs, feature_mtx)))
return feature_mtx
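For context, a self-contained sketch of a PyG message-passing layer consuming a SparseTensor adjacency (as batch.adj_t above) instead of edge_index. The ring graph, feature sizes, and MLP are hypothetical, and the real forward additionally stacks static features and HLVs via addstatic:

import torch
from torch import nn
from torch_geometric.nn import GINConv
from torch_sparse import SparseTensor

num_nodes, num_features = 5, 8
x = torch.randn(num_nodes, num_features)
# Directed ring 0 -> 1 -> 2 -> 3 -> 4 -> 0.
src = torch.tensor([0, 1, 2, 3, 4])
dst = torch.tensor([1, 2, 3, 4, 0])
# adj_t is the transposed adjacency: rows index the destination nodes.
adj_t = SparseTensor(row=dst, col=src, sparse_sizes=(num_nodes, num_nodes))
conv = GINConv(nn.Sequential(nn.Linear(num_features, num_features), nn.ReLU()))
# Message-passing layers accept a SparseTensor in place of edge_index.
out = conv(x, adj_t)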
@@ -3,6 +3,7 @@ from typing import Dict
import numpy as np
import torch
import torch_geometric
import torch_sparse
from fgsim.config import device
from fgsim.utils.typecheck import istype
@@ -15,6 +16,8 @@ def move_batch_to_device(batch, device):
def move(element):
if torch.is_tensor(element):
return element.to(device)
if isinstance(element, torch_sparse.SparseTensor):
return element.to(device)
elif isinstance(element, (list, set, tuple)):
return type(element)((move(ee) for ee in element))
elif isinstance(element, dict):
@@ -50,18 +53,20 @@ return batch_new
return batch_new
def clone_or_copy(e):
if torch.is_tensor(e):
return e.clone()
elif isinstance(e, (list, set, tuple)):
return type(e)((clone_or_copy(ee) for ee in e))
elif isinstance(e, dict):
return {k: clone_or_copy(ee) for k, ee in e.items()}
elif isinstance(e, (int, str, float)):
return e
elif type(e).__module__ == np.__name__:
return e
elif e is None:
def clone_or_copy(element):
if torch.is_tensor(element):
return element.clone()
if isinstance(element, torch_sparse.SparseTensor):
return element.clone()
elif isinstance(element, (list, set, tuple)):
return type(element)((clone_or_copy(ee) for ee in element))
elif isinstance(element, dict):
return {k: clone_or_copy(ee) for k, ee in element.items()}
elif isinstance(element, (int, str, float)):
return element
elif type(element).__module__ == np.__name__:
return element
elif element is None:
return None
else:
raise ValueError
......
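A small sketch of why move() and clone_or_copy() need the extra branches: torch_sparse.SparseTensor is not a torch.Tensor, so it falls through the torch.is_tensor checks. The 2x2 adjacency is hypothetical:

import torch
from torch_sparse import SparseTensor

# Hypothetical 2x2 sparse adjacency with a single nonzero entry.
adj = SparseTensor(
    row=torch.tensor([0]), col=torch.tensor([1]), sparse_sizes=(2, 2)
)
# torch.is_tensor(adj) is False, hence the dedicated isinstance branches.
adj_copy = adj.clone()
if torch.cuda.is_available():
    adj_cuda = adj.to(torch.device("cuda"))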