From f782e29a8404bd7451d04288d621d99be68b4d2e Mon Sep 17 00:00:00 2001
From: Marmaduke Woodman <mmwoodman@gmail.com>
Date: Fri, 12 May 2017 16:37:16 +0200
Subject: [PATCH] more updates towards loopy generated kernel

---
 examples/hackathon.py | 40 ++++++++++++++++++++++++++++++++++------
 tvb_hpc/model.py      | 12 ++++++++++--
 tvb_hpc/network.py    |  4 ++--
 3 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/examples/hackathon.py b/examples/hackathon.py
index 9339ee6..4f9594d 100644
--- a/examples/hackathon.py
+++ b/examples/hackathon.py
@@ -28,17 +28,45 @@ scm = scheme.EulerStep(osc.dt)
 scm_knl = scm.kernel(target)
 scm_knl = lp.fix_parameters(scm_knl, nsvar=len(osc.state_sym))
 
+from tvb_hpc.base import BaseKernel
+
+class Tavg(BaseKernel):
+
+    def __init__(self, model):
+        self.model = model
+        self.novar = min(len(model.input_sym), len(model.obsrv_sym))
+
+    def kernel_domains(self):
+        return "{ [i_node, i_ovar]: 0 <= i_node < nnode and 0 <= i_ovar < novar }"
+
+    def kernel_isns(self):
+        return ["tavg[i_node, i_ovar] = tavg[i_node, i_ovar] + obsrv[i_time % ntime, i_node, i_ovar]"]
+
+    def kernel_data(self):
+        return 'tavg obsrv'.split()
+
+    def kernel_dtypes(self):
+        return {'tavg,obsrv': np.float32}
+
+tavg = Tavg(osc)
+tavg_knl = tavg.kernel(target=target)
+tavg_knl = lp.fix_parameters(tavg_knl, novar=tavg.novar)
+
 # fuse kernels
-knls = osc_knl, net_knl, scm_knl
-data_flow = [ ('input', 1, 0), ('diffs', 0, 2), ('drift', 0, 2), ]
+knls = osc_knl, net_knl, scm_knl, tavg_knl
+data_flow = [ ('input', 1, 0), ('diffs', 0, 2), ('drift', 0, 2), ('obsrv', 2, 3) ]
 knl = lp.fuse_kernels(knls, data_flow=data_flow)
 
 # and time step
 knl = lp.to_batched(knl, 'nstep', [], 'i_step', sequential=True)
-knl = lp.fix_parameters(knl, i_time=pm.parse('(i_step + i_step_0) % ntime'))
+knl = lp.fix_parameters(knl, i_time=pm.parse('i_step + i_step_0'))
 knl.args.append(lp.ValueArg('i_step_0', np.uintc))
 knl = lp.add_dtypes(knl, {'i_step_0': np.uintc})
 
+code, _ = lp.generate_code(knl)
+with open('numba_kernel.py', 'w') as fd:
+    fd.write(code)
+
 # TODO add outer time loop & prange over subjects?
 
 # load connectivity TODO util function / class
@@ -55,7 +83,7 @@ col = sw.indices.astype(np.uintc)
 row = sw.indptr.astype(np.uintc)
 
 # choose param space
-ng = 32
+ng = 8
 couplings = np.logspace(1.6, 3.0, ng)
 speeds = np.logspace(0.0, 2.0, ng)
 
@@ -64,12 +92,12 @@ LOG.info('trace.nbytes %.3f MB', trace.nbytes / 2**20)
 for j, (speed, coupling) in enumerate(itertools.product(speeds, couplings)):
     lnz = (lengths[nz] / speed / osc.dt).astype(np.uintc)
     state, input, param, drift, diffs, _ = osc.prep_arrays(nnode)
-    obsrv = np.zeros((lnz.max() + 3, nnode, 1), np.float32)
+    obsrv = np.zeros((lnz.max() + 3, nnode, 2), np.float32)
     for i in range(trace.shape[1]):
         knl(10, nnode, obsrv.shape[0], state, input, param,
                 drift, diffs, obsrv, nnz, lnz, row, col, wnz,
                 a=coupling, i_step_0=i*10)
-        trace[j, i] = state[:, 0]
+        trace[j, i] = obsrv[i*10 % obsrv.shape[0], :, 1]
     print(j)
 
 # check correctness
diff --git a/tvb_hpc/model.py b/tvb_hpc/model.py
index 2e63c44..ff93f1d 100644
--- a/tvb_hpc/model.py
+++ b/tvb_hpc/model.py
@@ -131,7 +131,7 @@ class BaseModel(BaseKernel):
         fmt = {
             'drift': '{kind}[i_node, {i}] = {expr}',
             'diffs': '{kind}[i_node, {i}] = {expr}',
-            'obsrv': '{kind}[i_time, i_node, {i}] = {expr}',
+            'obsrv': '{kind}[i_time % ntime, i_node, {i}] = {expr}',
         }
         for kind in 'drift diffs obsrv'.split():
             exprs = getattr(self, kind + '_sym')
@@ -161,7 +161,15 @@ class Kuramoto(BaseModel):
     drift = 'omega + I',
     diffs = 0,
     obsrv = 'theta', 'sin(theta)'
-    const = {'omega': 1.0}
+    const = {'omega': 1.0, 'pi': np.pi}
+
+    def kernel_isns(self):
+        isns = list(super().kernel_isns())
+        isns.append(
+            'state[i_node, 0] = '
+            ' if(state[i_node, 0]>pi, state[i_node, 0]-2*pi, '
+            '  if(state[i_node, 0]<-pi, state[i_node, 0]+2*pi, state[i_node, 0]))')
+        return isns
 
 
 class HMJE(BaseModel):
diff --git a/tvb_hpc/network.py b/tvb_hpc/network.py
index 2e242de..0389815 100644
--- a/tvb_hpc/network.py
+++ b/tvb_hpc/network.py
@@ -62,8 +62,8 @@ class Network(BaseKernel):
         # substitute pre_syn and post_syn for obsrv data
         pre_expr = subst_vars(
             expr=pre,
-            pre_syn=pm.parse('obsrv[i_time - delays[j_node], col[j_node], k]'),  # k is var idx
-            post_syn=pm.parse('obsrv[i_time, i_node, k]'),
+            pre_syn=pm.parse('obsrv[(i_time - delays[j_node]) % ntime, col[j_node], k]'),  # k is var idx
+            post_syn=pm.parse('obsrv[i_time % ntime, i_node, k]'),
         )
         # build weighted sum over nodes
         sum = subst_vars(
-- 
GitLab