Commit d8426374 authored by mova's avatar mova
Browse files

add dnn training for hgcal as comparison

parent b37c97ea
......@@ -9,7 +9,7 @@ path:
test: "${path.dataset_processed}/test.pt"
training: "${path.dataset_processed}/training"
training_glob: "*.pt"
geo_lup: "data/geo_hgcal/DetIdLUT_full.root"
geo_lup: "data/geo_hgcal/DetIdLUT.root"
run_path: "wd/${tag}/${hash}"
train_config: "${path.run_path}/resulting_train_config.yaml"
full_config: "${path.run_path}/full_config.yaml"
......
import torch
from torch import nn
from fgsim.config import conf
n_dyn = conf.model.dyn_features
n_hlvs = len(conf.loader.hlvs)
n_node_features = len(conf.loader.cell_prop_keys)
n_all_features = n_dyn + n_hlvs + n_node_features
def hlv():
return nn.Sequential(
nn.Linear(n_hlvs, conf.model.deeplayer_nodes),
nn.ReLU(),
nn.Linear(conf.model.deeplayer_nodes, conf.model.deeplayer_nodes),
nn.ReLU(),
nn.Linear(conf.model.deeplayer_nodes, conf.model.dyn_features),
nn.ReLU(),
)
class ModelClass(torch.nn.Module):
def __init__(self):
super(ModelClass, self).__init__()
self.hlv = hlv()
self.end_lin = nn.Linear(conf.model.dyn_features, 1)
def forward(self, batch):
X = batch.hlvs
X = self.hlv(X)
X = self.end_lin(X)
return X
......@@ -48,14 +48,13 @@
"version": "0.2.0",
"configurations": [
{
"name": "t sparse",
"name": "t dnn",
"type": "python",
"request": "launch",
"module": "fgsim",
"args": [
"--tag",
"sparse",
// "--debug",
"dnn",
"train",
]
},
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment