Commit 79ff3d4b authored by mova

reduce the number of open files, switch the torch.multiprocessing sharing strategy

parent 114a30fe
@@ -9,7 +9,7 @@ path:
   test: "${path.dataset_processed}/test.pt"
   training: "${path.dataset_processed}/training"
   training_glob: "*.pt"
-  geo_lup: "data/hgcal/DetIdLUT_full.root"
+  geo_lup: "data/geo_hgcal/DetIdLUT_full.root"
   run_path: "wd/${tag}/${hash}"
   train_config: "${path.run_path}/resulting_train_config.yaml"
   full_config: "${path.run_path}/full_config.yaml"
......
@@ -18,16 +18,16 @@ dataset_path = Path(conf.path.training)
 dataset_path.mkdir(parents=True, exist_ok=True)

 # reading from the filesystem
-def read_file(file: Path) -> GraphType:
+def read_file(file: Path) -> List[GraphType]:
     batch_list: List[GraphType] = torch.load(file)
     return batch_list

 # Collect the steps
 def preprocessed_seq():
     return (
-        qf.ProcessStep(read_file, 2, name="read_chunk"),
+        qf.ProcessStep(read_file, 1, name="read_chunk"),
         Queue(1),
         qf.pack.UnpackStep(),
         Queue(conf.loader.prefetch_batches),
     )
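
`torch.load` deserializes a whole chunk of pre-batched graphs in one call and closes the file again, so a single `read_chunk` worker keeps the one-slot queue behind it fed; dropping from two workers to one also halves the number of chunk files open at any moment, in line with the commit title. A minimal round-trip sketch of `read_file` (the file name and tensor shapes are illustrative; on PyTorch >= 2.6 `torch.load` would additionally need `weights_only=False` for `Data` objects):

    from pathlib import Path
    from typing import List

    import torch
    from torch_geometric.data import Data as GraphType

    def read_file(file: Path) -> List[GraphType]:
        # One call loads the entire list; no file handle stays open afterwards.
        batch_list: List[GraphType] = torch.load(file)
        return batch_list

    chunk = Path("chunk0.pt")  # hypothetical chunk file
    torch.save([GraphType(x=torch.rand(4, 1)) for _ in range(3)], chunk)
    assert len(read_file(chunk)) == 3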
@@ -23,8 +23,8 @@ pickling_support.install()

 # Make it work ()
-mp.set_sharing_strategy("file_descriptor")
-# mp.set_sharing_strategy("file_system")
+# mp.set_sharing_strategy("file_descriptor")
+mp.set_sharing_strategy("file_system")

 # Reworked according to the recommendations in
 # https://pytorch.org/docs/stable/multiprocessing.html
......
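
This hunk is the "switch sharing method" half of the commit. With the default "file_descriptor" strategy, torch.multiprocessing holds an open file descriptor for every tensor shared between processes, so a pipeline that passes many batches around can run into the per-process fd limit; "file_system" addresses the shared memory by file name instead and keeps no descriptors open. A standalone sketch of the switch (nothing beyond torch.multiprocessing is assumed):

    import torch.multiprocessing as mp

    # Trade fd pressure for reliance on shm file names; see the
    # PyTorch multiprocessing docs linked in the hunk above.
    mp.set_sharing_strategy("file_system")
    assert mp.get_sharing_strategy() == "file_system"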
@@ -105,12 +105,9 @@ must queue an epoch via `queue_epoch()` and iterate over the instance of the class
         self.qfseq = qf.Sequence(*process_seq())

         if conf.loader.preprocess_training:
-            self.preprocessed_files = [
-                str(e)
-                for e in sorted(
-                    Path(conf.path.training).glob(conf.path.training_glob)
-                )
-            ]
+            self.preprocessed_files = list(
+                sorted(Path(conf.path.training).glob(conf.path.training_glob))
+            )

         if conf.command != "preprocess":
             if (
......
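
The rewritten expression keeps the `pathlib.Path` objects that `glob()` yields instead of converting them to `str`; `torch.load` and `open` accept any `os.PathLike`, so the `str(e)` comprehension was redundant (strictly, `sorted()` already returns a list, so the outer `list()` is only belt and braces). A two-line illustration with a hypothetical directory:

    from pathlib import Path

    preprocessed_files = list(sorted(Path("wd/sparse/training").glob("*.pt")))
    # Every entry is a Path and can be passed anywhere a str path was used.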
"""Message pass model"""
import torch
from torch import nn
from torch_geometric.data import Data as GraphType
from torch_geometric.nn import GINConv, global_add_pool
from fgsim.config import conf, device
from fgsim.utils.cuda_clear import cuda_clear
n_dyn = conf.model.dyn_features
n_hlvs = len(conf.loader.hlvs)
n_node_features = len(conf.loader.cell_prop_keys)
n_all_features = n_dyn + n_hlvs + n_node_features
def get_hlv_dnn():
return nn.Sequential(
nn.Linear(n_hlvs + n_dyn, conf.model.deeplayer_nodes),
nn.ReLU(),
nn.Linear(conf.model.deeplayer_nodes, conf.model.deeplayer_nodes),
nn.ReLU(),
nn.Linear(conf.model.deeplayer_nodes, conf.model.deeplayer_nodes),
nn.ReLU(),
nn.Linear(conf.model.deeplayer_nodes, 1),
nn.ReLU(),
)
def get_node_dnn():
return nn.Sequential(
nn.Linear(n_all_features, conf.model.deeplayer_nodes),
nn.ReLU(),
nn.Linear(conf.model.deeplayer_nodes, conf.model.deeplayer_nodes),
nn.ReLU(),
nn.Linear(conf.model.deeplayer_nodes, conf.model.deeplayer_nodes),
nn.ReLU(),
nn.Linear(conf.model.deeplayer_nodes, conf.model.dyn_features),
nn.ReLU(),
)
def get_conv():
conv_dnn = get_node_dnn()
return GINConv(conv_dnn, train_eps=True)
class ModelClass(torch.nn.Module):
def __init__(self):
super(ModelClass, self).__init__()
self.conv = get_conv()
self.node_dnn = get_node_dnn()
self.hlv_dnn = get_hlv_dnn()
def forward(self, batch: GraphType):
def addstatic(
feature_mtx,
mask=torch.ones(len(batch.x), dtype=torch.bool, device=device),
):
return torch.hstack(
(
feature_mtx[mask],
batch.feature_mtx_static[mask],
batch.hlvs[batch.batch[mask]],
)
)
feature_mtx = torch.hstack(
(
batch.x,
torch.zeros(
(len(batch.x), conf.model.dyn_features - 1), device=device
),
)
)
for _ in range(conf.model.nprop):
feature_mtx = self.conv(addstatic(feature_mtx), batch.edge_index)
feature_mtx = self.node_dnn(addstatic(feature_mtx))
cuda_clear()
feature_mtx = global_add_pool(
feature_mtx, batch.batch, size=batch.num_graphs
)
feature_mtx = self.hlv_dnn(torch.hstack((batch.hlvs, feature_mtx)))
return feature_mtx
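
A hedged usage sketch for the model file above: it only runs inside the repo, since the module reads `conf` at import time, and the loader is expected to attach `feature_mtx_static` and `hlvs` to each graph. The sizes below are stand-ins, not repo values:

    import torch
    from torch_geometric.data import Batch, Data

    n_node_features, n_hlvs = 4, 6  # assumed stand-ins for the conf-derived widths

    graphs = []
    for _ in range(2):
        g = Data(x=torch.rand(5, 1), edge_index=torch.randint(0, 5, (2, 10)))
        g.feature_mtx_static = torch.rand(5, n_node_features)  # static per-node features
        g.hlvs = torch.rand(1, n_hlvs)  # one HLV row per graph
        graphs.append(g)

    batch = Batch.from_data_list(graphs)
    out = ModelClass()(batch)  # shape (2, 1): one scalar per graph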
@@ -2,6 +2,7 @@
 import torch
 from torch import nn
+from torch_geometric.data import Data as GraphType
 from torch_geometric.nn import GINConv, global_add_pool

 from fgsim.config import conf, device
@@ -52,7 +53,7 @@ class ModelClass(torch.nn.Module):
         self.node_dnn = get_node_dnn()
         self.hlv_dnn = get_hlv_dnn()

-    def forward(self, batch):
+    def forward(self, batch: GraphType):
         def addstatic(
             feature_mtx,
             mask=torch.ones(len(batch.x), dtype=torch.bool, device=device),
......
@@ -48,13 +48,13 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "t hgcal",
+            "name": "t sparse",
             "type": "python",
             "request": "launch",
             "module": "fgsim",
             "args": [
                 "--tag",
-                "hgcal",
+                "sparse",
                 // "--debug",
                 "train",
             ]
@@ -66,7 +66,7 @@
             "module": "fgsim",
             "args": [
                 "--tag",
-                "hgcal",
+                "sparse",
                 "preprocess",
             ]
         },
......
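
For reference, the renamed "t sparse" launch configuration corresponds to `python -m fgsim --tag sparse train` on the shell, and the preprocess configuration to `python -m fgsim --tag sparse preprocess` (assuming the package is importable).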