From f782e29a8404bd7451d04288d621d99be68b4d2e Mon Sep 17 00:00:00 2001 From: Marmaduke Woodman <mmwoodman@gmail.com> Date: Fri, 12 May 2017 16:37:16 +0200 Subject: [PATCH] more updates towards loopy generated kernel --- examples/hackathon.py | 40 ++++++++++++++++++++++++++++++++++------ tvb_hpc/model.py | 12 ++++++++++-- tvb_hpc/network.py | 4 ++-- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/examples/hackathon.py b/examples/hackathon.py index 9339ee6..4f9594d 100644 --- a/examples/hackathon.py +++ b/examples/hackathon.py @@ -28,17 +28,45 @@ scm = scheme.EulerStep(osc.dt) scm_knl = scm.kernel(target) scm_knl = lp.fix_parameters(scm_knl, nsvar=len(osc.state_sym)) +from tvb_hpc.base import BaseKernel + +class Tavg(BaseKernel): + + def __init__(self, model): + self.model = model + self.novar = min(len(model.input_sym), len(model.obsrv_sym)) + + def kernel_domains(self): + return "{ [i_node, i_ovar]: 0 <= i_node < nnode and 0 <= i_ovar < novar }" + + def kernel_isns(self): + return ["tavg[i_node, i_ovar] = tavg[i_node, i_ovar] + obsrv[i_time % ntime, i_node, i_ovar]"] + + def kernel_data(self): + return 'tavg obsrv'.split() + + def kernel_dtypes(self): + return {'tavg,obsrv': np.float32} + +tavg = Tavg(osc) +tavg_knl = tavg.kernel(target=target) +tavg_knl = lp.fix_parameters(tavg_knl, novar=tavg.novar) + # fuse kernels -knls = osc_knl, net_knl, scm_knl -data_flow = [ ('input', 1, 0), ('diffs', 0, 2), ('drift', 0, 2), ] +knls = osc_knl, net_knl, scm_knl, tavg_knl +data_flow = [ ('input', 1, 0), ('diffs', 0, 2), ('drift', 0, 2), ('obsrv', 2, 3) ] knl = lp.fuse_kernels(knls, data_flow=data_flow) # and time step knl = lp.to_batched(knl, 'nstep', [], 'i_step', sequential=True) -knl = lp.fix_parameters(knl, i_time=pm.parse('(i_step + i_step_0) % ntime')) +knl = lp.fix_parameters(knl, i_time=pm.parse('i_step + i_step_0')) knl.args.append(lp.ValueArg('i_step_0', np.uintc)) knl = lp.add_dtypes(knl, {'i_step_0': np.uintc}) +code, _ = lp.generate_code(knl) +with open('numba_kernel.py', 'w') as fd: + fd.write(code) + # TODO add outer time loop & prange over subjects? # load connectivity TODO util function / class @@ -55,7 +83,7 @@ col = sw.indices.astype(np.uintc) row = sw.indptr.astype(np.uintc) # choose param space -ng = 32 +ng = 8 couplings = np.logspace(1.6, 3.0, ng) speeds = np.logspace(0.0, 2.0, ng) @@ -64,12 +92,12 @@ LOG.info('trace.nbytes %.3f MB', trace.nbytes / 2**20) for j, (speed, coupling) in enumerate(itertools.product(speeds, couplings)): lnz = (lengths[nz] / speed / osc.dt).astype(np.uintc) state, input, param, drift, diffs, _ = osc.prep_arrays(nnode) - obsrv = np.zeros((lnz.max() + 3, nnode, 1), np.float32) + obsrv = np.zeros((lnz.max() + 3, nnode, 2), np.float32) for i in range(trace.shape[1]): knl(10, nnode, obsrv.shape[0], state, input, param, drift, diffs, obsrv, nnz, lnz, row, col, wnz, a=coupling, i_step_0=i*10) - trace[j, i] = state[:, 0] + trace[j, i] = obsrv[i*10 % obsrv.shape[0], :, 1] print(j) # check correctness diff --git a/tvb_hpc/model.py b/tvb_hpc/model.py index 2e63c44..ff93f1d 100644 --- a/tvb_hpc/model.py +++ b/tvb_hpc/model.py @@ -131,7 +131,7 @@ class BaseModel(BaseKernel): fmt = { 'drift': '{kind}[i_node, {i}] = {expr}', 'diffs': '{kind}[i_node, {i}] = {expr}', - 'obsrv': '{kind}[i_time, i_node, {i}] = {expr}', + 'obsrv': '{kind}[i_time % ntime, i_node, {i}] = {expr}', } for kind in 'drift diffs obsrv'.split(): exprs = getattr(self, kind + '_sym') @@ -161,7 +161,15 @@ class Kuramoto(BaseModel): drift = 'omega + I', diffs = 0, obsrv = 'theta', 'sin(theta)' - const = {'omega': 1.0} + const = {'omega': 1.0, 'pi': np.pi} + + def kernel_isns(self): + isns = list(super().kernel_isns()) + isns.append( + 'state[i_node, 0] = ' + ' if(state[i_node, 0]>pi, state[i_node, 0]-2*pi, ' + ' if(state[i_node, 0]<-pi, state[i_node, 0]+2*pi, state[i_node, 0]))') + return isns class HMJE(BaseModel): diff --git a/tvb_hpc/network.py b/tvb_hpc/network.py index 2e242de..0389815 100644 --- a/tvb_hpc/network.py +++ b/tvb_hpc/network.py @@ -62,8 +62,8 @@ class Network(BaseKernel): # substitute pre_syn and post_syn for obsrv data pre_expr = subst_vars( expr=pre, - pre_syn=pm.parse('obsrv[i_time - delays[j_node], col[j_node], k]'), # k is var idx - post_syn=pm.parse('obsrv[i_time, i_node, k]'), + pre_syn=pm.parse('obsrv[(i_time - delays[j_node]) % ntime, col[j_node], k]'), # k is var idx + post_syn=pm.parse('obsrv[i_time % ntime, i_node, k]'), ) # build weighted sum over nodes sum = subst_vars( -- GitLab