Commit 82733e7c authored by mova

fix deepconv

parent b92df5e6
import torch
from torch import nn
-from torch_geometric.nn import GCNConv, GINConv, global_add_pool
+from torch_geometric.nn import GINConv, global_add_pool
from ..config import conf, device
from ..utils.cuda_clear import cuda_clear
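With the GCNConv upscaling layer removed, all message passing in this model goes through the GINConv-based helpers. The get_conv() factory itself is not part of this diff; the snippet below is only a minimal sketch of what such a factory could look like, and n_features is a made-up placeholder for the feature width the model actually uses (dynamic features plus whatever addstatic appends).

# Hypothetical sketch of a GINConv-based factory; not the repository's actual get_conv().
from torch import nn
from torch_geometric.nn import GINConv

def get_conv(n_features: int = 4) -> GINConv:
    # GINConv applies this small MLP to the summed neighbour features of each node.
    mlp = nn.Sequential(
        nn.Linear(n_features, n_features),
        nn.ReLU(),
        nn.Linear(n_features, n_features),
    )
    return GINConv(mlp)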
@@ -54,11 +54,9 @@ def batch_to_hlvs(batch):
class ModelClass(torch.nn.Module):
def __init__(self):
super(ModelClass, self).__init__()
-self.upscale_conv = GCNConv(1, conf.model.dyn_features)
self.inlayer_conv = get_conv()
self.forward_conv = get_conv()
self.backward_conv = get_conv()
self.node_dnn = get_node_dnn()
self.hlv_dnn = get_hlv_dnn()
def forward(self, batch):
@@ -79,20 +77,22 @@ class ModelClass(torch.nn.Module):
),
)
)
-for _ in range(conf.model.nprop):
-# forwards
-# the last time is just inlayer MPL
+for ilayer in range(conf.nlayers):
+# forwards
+# the last time is just inlayer MPL
inner_inp_mask = batch.mask_inp_innerL[ilayer]
inner_outp_mask = batch.mask_outp_innerL[ilayer]
+sourcelayermask = batch.layers == ilayer
+targetlayermask = batch.layers == ilayer + 1
partial_inner = self.inlayer_conv(
addstatic(x, inner_inp_mask),
batch.inner_edges_per_layer[ilayer],
)
-x[batch.layers == ilayer] = partial_inner[inner_outp_mask]
+x[sourcelayermask] = partial_inner[inner_outp_mask]
del partial_inner, inner_inp_mask, inner_outp_mask
if ilayer == conf.nlayers - 1:
@@ -103,12 +103,10 @@ class ModelClass(torch.nn.Module):
addstatic(x, forward_inp_mask),
batch.forward_edges_per_layer[ilayer],
)
-x[batch.layers == ilayer + 1] = partial_forward[forward_outp_mask]
+x[targetlayermask] = partial_forward[forward_outp_mask]
del partial_forward, forward_inp_mask, forward_outp_mask
cuda_clear()
x = nn.functional.relu(x)
# backward
# ilayer goes from nlayers - 1 to nlayers - 2 to ... 1
for ilayer in range(conf.nlayers - 1, 0, -1):
@@ -116,11 +114,15 @@ class ModelClass(torch.nn.Module):
# the last time is just inlayer MPL
backward_inp_mask = batch.mask_inp_backwardL[ilayer]
backward_outp_mask = batch.mask_outp_backwardL[ilayer]
+sourcelayermask = batch.layers == ilayer - 2
+targetlayermask = batch.layers == ilayer - 1
partial_backward = self.backward_conv(
addstatic(x, backward_inp_mask),
batch.backward_edges_per_layer[ilayer],
)
-x[batch.layers == ilayer - 1] = partial_backward[backward_outp_mask]
+x[targetlayermask] = partial_backward[backward_outp_mask]
del partial_backward, backward_inp_mask, backward_outp_mask
inner_inp_mask = batch.mask_inp_innerL[ilayer - 1]
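The new sourcelayermask and targetlayermask variables precompute the comparison against batch.layers once per loop iteration, so the same boolean mask is reused for the backward-conv update above and for the in-layer update in the next hunk, instead of re-evaluating batch.layers == ilayer - 1 before every assignment. A standalone sketch of the masked-assignment pattern, with toy tensors in place of the model's data:

# Toy stand-ins for the per-node feature matrix and the per-node layer index.
import torch

x = torch.zeros(5, 3)
layers = torch.tensor([0, 1, 1, 2, 2])

# Compute the boolean mask once per iteration ...
targetlayermask = layers == 1

# ... and reuse it for several updates of the same layer's rows.
x[targetlayermask] = torch.ones(int(targetlayermask.sum()), 3)
x[targetlayermask] = 2 * x[targetlayermask]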
@@ -129,14 +131,10 @@ class ModelClass(torch.nn.Module):
addstatic(x, inner_inp_mask),
batch.inner_edges_per_layer[ilayer - 1],
)
-x[batch.layers == ilayer - 1] = partial_inner[inner_outp_mask]
+x[targetlayermask] = partial_inner[inner_outp_mask]
del partial_inner, inner_inp_mask, inner_outp_mask
cuda_clear()
x = nn.functional.relu(x)
# DNN on the feature matrix
x = self.node_dnn(addstatic(x))
cuda_clear()
x = global_add_pool(x, batch.batch, size=batch.num_graphs)
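global_add_pool then sums the node features of each graph in the batch into a single row, using batch.batch as the per-node graph index and batch.num_graphs as the output size. A minimal standalone example with made-up sizes (not the model's batch object):

# Sum-pool toy node features per graph with PyTorch Geometric.
import torch
from torch_geometric.nn import global_add_pool

x = torch.ones(5, 3)                          # 5 nodes with 3 features each
graph_index = torch.tensor([0, 0, 0, 1, 1])   # nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
pooled = global_add_pool(x, graph_index, size=2)
print(pooled)                                 # row i holds the feature sum of graph i -> shape (2, 3)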