Skip to content

Commit c266963

Browse files
authored
Fix accessing of third factor in modulatory spikes handler in synapse models (#1317)
1 parent 35177ce commit c266963

File tree

4 files changed

+276
-13
lines changed

4 files changed

+276
-13
lines changed

pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -475,15 +475,15 @@ inline void set_{{ variable.get_name() }}(const {{ declarations.print_variable_t
475475
* Update internal state (``S_``) of the synapse according to the dynamical equations defined in the model and the statements in the ``update`` block.
476476
**/
477477
inline void
478-
update_internal_state_(double t_start, double timestep, const {{synapseName}}CommonSynapseProperties& cp);
478+
update_internal_state_(double t_start, double timestep, const {{ synapseName }}CommonSynapseProperties& cp);
479479

480480
void recompute_internal_variables();
481481

482482
std::string get_name() const;
483483

484484
public:
485485
// this line determines which common properties to use
486-
typedef {{synapseName}}CommonSynapseProperties CommonPropertiesType;
486+
typedef {{ synapseName }}CommonSynapseProperties CommonPropertiesType;
487487

488488
typedef Connection< targetidentifierT > ConnectionBase;
489489

@@ -846,7 +846,7 @@ void get_entry_from_continuous_variable_history(double t,
846846
{#- post spike based: grab the entry from the post spiking history buffer #}
847847
{%- if continuous_state_buffering_method == "post_spike_based" %}
848848
{%- for var_name in paired_neuron.state_vars_that_need_continuous_buffering %}
849-
{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
849+
{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
850850
const double __{{ var_name }} = start->{{ var_name }}_;
851851
{%- endfor %}
852852
{%- endif %}
@@ -930,7 +930,7 @@ void get_entry_from_continuous_variable_history(double t,
930930
* update synapse internal state from `t_lastspike_` to `__t_spike`
931931
**/
932932
{%- if vt_ports is defined and vt_ports|length > 0 %}
933-
{%- set vt_port = vt_ports[0] %}
933+
{%- set vt_port = vt_ports[0] %}
934934
process_{{vt_port}}_spikes_( vt_spikes, t_lastspike_, __t_spike, cp );
935935
{%- endif %}
936936

@@ -1058,16 +1058,32 @@ constexpr ConnectionModelProperties {{synapseName}}< targetidentifierT >::proper
10581058
{%- endif %}
10591059
{%- if vt_ports is defined and vt_ports|length > 0 %}
10601060
{%- set vt_port = vt_ports[0] %}
1061+
1062+
/**
1063+
* Handler for volume transmitter spikes
1064+
**/
10611065
template < typename targetidentifierT >
10621066
void
1063-
{{synapseName}}< targetidentifierT >::process_{{vt_port}}_spikes_( const std::vector< spikecounter >& vt_spikes,
1067+
{{synapseName}}< targetidentifierT >::process_{{ vt_port }}_spikes_( const std::vector< spikecounter >& vt_spikes,
10641068
double t0,
10651069
double t1,
10661070
const {{synapseName}}CommonSynapseProperties& cp )
10671071
{
10681072
#ifdef DEBUG
10691073
std::cout << "[synapse " << this << "] {{ synapseName }}::process_{{ vt_port }}_spikes_(): t0 = " << t0 << ", t1 = " << t1 << "\n";
10701074
#endif
1075+
1076+
const size_t tid = kernel().vp_manager.get_thread_id();
1077+
{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 %}
1078+
typedef {{ paired_neuron_name }} post_neuron_t;
1079+
{{ paired_neuron_name }}* __target = static_cast< {{ paired_neuron_name }}* >(get_target(tid));
1080+
assert(__target);
1081+
{% for var_name in paired_neuron.state_vars_that_need_continuous_buffering %}
1082+
{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
1083+
const double __{{ var_name }} = ((post_neuron_t*)(__target))->get_{{ var_name_post }}();
1084+
{%- endfor %}
1085+
{%- endif %}
1086+
10711087
// process dopa spikes in (t0, t1]
10721088
// propagate weight from t0 to t1
10731089
if ( ( vt_spikes.size() > vt_spikes_idx_ + 1 )
@@ -1529,8 +1545,9 @@ inline void
15291545
}
15301546

15311547
{%- if vt_ports is defined and vt_ports|length > 0 %}
1548+
15321549
/**
1533-
* Update to end of timestep ``t_trig``, while processing vt spikes and post spikes
1550+
* Update synapse internal state to end of timestep ``t_trig``, while processing vt spikes and post spikes
15341551
**/
15351552
template < typename targetidentifierT >
15361553
{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %}
@@ -1554,8 +1571,15 @@ inline void
15541571

15551572
const size_t tid = kernel().vp_manager.get_thread_id();
15561573
{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 %}
1574+
typedef {{ paired_neuron_name }} post_neuron_t;
15571575
{{ paired_neuron_name }}* __target = static_cast< {{ paired_neuron_name }}* >(get_target(tid));
15581576
assert(__target);
1577+
1578+
{% for var_name in paired_neuron.state_vars_that_need_continuous_buffering %}
1579+
{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
1580+
const double __{{ var_name }} = ((post_neuron_t*)(__target))->get_{{ var_name_post }}();
1581+
{%- endfor %}
1582+
15591583
{%- else %}
15601584
Node* __target = get_target( tid );
15611585
{%- endif %}
@@ -1602,7 +1626,7 @@ inline void
16021626
{#- post spike based: grab the entry from the post spiking history buffer #}
16031627
{%- if continuous_state_buffering_method == "post_spike_based" %}
16041628
{%- for var_name in paired_neuron.state_vars_that_need_continuous_buffering %}
1605-
{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
1629+
{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
16061630
const double __{{ var_name }} = start->{{ var_name }}_;
16071631
{%- endfor %}
16081632
{%- endif %}
@@ -1616,14 +1640,16 @@ inline void
16161640
{%- filter indent(6, True) %}
16171641
{%- if post_ports is defined %}
16181642
{%- for post_port in spiking_post_ports %}
1643+
{%- set dynamics = synapse.get_on_receive_block(post_port) %}
1644+
{%- if dynamics is not none %}
16191645
/**
1620-
* NESTML generated onReceive code block for postsynaptic port "{{post_port}}" begins here!
1646+
* NESTML generated onReceive code block for postsynaptic port "{{ post_port }}" begins here!
16211647
**/
16221648

1623-
{%- set dynamics = synapse.get_on_receive_block(post_port) %}
1624-
{%- with ast = dynamics.get_stmts_body() %}
1625-
{%- include "directives_cpp/StmtsBody.jinja2" %}
1626-
{%- endwith %}
1649+
{%- with ast = dynamics.get_stmts_body() %}
1650+
{%- include "directives_cpp/StmtsBody.jinja2" %}
1651+
{%- endwith %}
1652+
{%- endif %}
16271653
{%- endfor %}
16281654
{%- endif %}
16291655
{%- endfilter %}
@@ -1659,7 +1685,6 @@ inline void
16591685
t_lastspike_ = t_trig;
16601686
}
16611687

1662-
16631688
{%- endif %}
16641689

16651690
} // namespace {{ synapseName }};
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# test_continuous_third_factor_in_modulatory_spikes_handler_neuron
2+
# ################################################################
3+
#
4+
# Description
5+
# +++++++++++
6+
#
7+
# Like ``iaf_psc_exp``, but with a dummy postsynaptic third-factor variable.
8+
#
9+
# Copyright statement
10+
# +++++++++++++++++++
11+
#
12+
# This file is part of NEST.
13+
#
14+
# Copyright (C) 2004 The NEST Initiative
15+
#
16+
# NEST is free software: you can redistribute it and/or modify
17+
# it under the terms of the GNU General Public License as published by
18+
# the Free Software Foundation, either version 2 of the License, or
19+
# (at your option) any later version.
20+
#
21+
# NEST is distributed in the hope that it will be useful,
22+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
23+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24+
# GNU General Public License for more details.
25+
#
26+
# You should have received a copy of the GNU General Public License
27+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
28+
#
29+
model test_continuous_third_factor_in_modulatory_spikes_handler_neuron:
30+
31+
state:
32+
V_m mV = E_L # Membrane potential
33+
refr_t ms = 0 ms # Refractory period timer
34+
I_syn_exc pA = 0 pA
35+
I_syn_inh pA = 0 pA
36+
e_trace_dend real = 42
37+
38+
equations:
39+
I_syn_exc' = -I_syn_exc / tau_syn_exc
40+
I_syn_inh' = -I_syn_inh / tau_syn_inh
41+
V_m' = -(V_m - E_L) / tau_m + (I_syn_exc - I_syn_inh + I_e + I_stim) / C_m
42+
refr_t' = -1e3 * ms/s # refractoriness is implemented as an ODE, representing a timer counting back down to zero. XXX: TODO: This should simply read ``refr_t' = -1 / s`` (see https://github.com/nest/nestml/issues/984)
43+
44+
parameters:
45+
C_m pF = 250 pF # Capacitance of the membrane
46+
tau_m ms = 10 ms # Membrane time constant
47+
tau_syn_inh ms = 2 ms # Time constant of inhibitory synaptic current
48+
tau_syn_exc ms = 2 ms # Time constant of excitatory synaptic current
49+
refr_T ms = 2 ms # Duration of refractory period
50+
E_L mV = -70 mV # Resting potential
51+
V_reset mV = -70 mV # Reset value of the membrane potential
52+
V_th mV = -55 mV # Spike threshold potential
53+
54+
# constant external input current
55+
I_e pA = 0 pA
56+
57+
input:
58+
exc_spikes <- excitatory spike
59+
inh_spikes <- inhibitory spike
60+
I_stim pA <- continuous
61+
62+
output:
63+
spike
64+
65+
update:
66+
if refr_t > 0 ms:
67+
# neuron is absolute refractory, do not evolve V_m
68+
integrate_odes(I_syn_exc, I_syn_inh, refr_t)
69+
else:
70+
# neuron not refractory
71+
integrate_odes(I_syn_exc, I_syn_inh, V_m)
72+
73+
onReceive(exc_spikes):
74+
I_syn_exc += exc_spikes * pA * s
75+
76+
onReceive(inh_spikes):
77+
I_syn_inh += inh_spikes * pA * s
78+
79+
onCondition(refr_t <= 0 ms and V_m >= V_th):
80+
# threshold crossing
81+
refr_t = refr_T # start of the refractory period
82+
V_m = V_reset
83+
emit_spike()
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# test_continuous_third_factor_in_modulatory_spikes_handler_synapse
2+
# #################################################################
3+
#
4+
# Copyright statement
5+
# +++++++++++++++++++
6+
#
7+
# This file is part of NEST.
8+
#
9+
# Copyright (C) 2004 The NEST Initiative
10+
#
11+
# NEST is free software: you can redistribute it and/or modify
12+
# it under the terms of the GNU General Public License as published by
13+
# the Free Software Foundation, either version 2 of the License, or
14+
# (at your option) any later version.
15+
#
16+
# NEST is distributed in the hope that it will be useful,
17+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
18+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19+
# GNU General Public License for more details.
20+
#
21+
# You should have received a copy of the GNU General Public License
22+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
23+
#
24+
model test_continuous_third_factor_in_modulatory_spikes_handler_synapse:
25+
state:
26+
w real = 1.0
27+
e_trace_d_read real = 0.0
28+
e_trace_d_mod real = 0.0
29+
30+
parameters:
31+
d ms = 1 ms # Synaptic transmission delay
32+
33+
input:
34+
pre_spikes <- spike
35+
post_spikes <- spike
36+
mod_spikes <- spike
37+
e_trace_d real <- continuous
38+
39+
output:
40+
spike(w real, d ms)
41+
42+
onReceive(pre_spikes):
43+
e_trace_d_read = e_trace_d
44+
println("e_trace_d_read: {e_trace_d_read}")
45+
emit_spike(w, d)
46+
47+
onReceive(mod_spikes):
48+
e_trace_d_mod = e_trace_d
49+
println("e_trace_d_mod: {e_trace_d_mod}")
50+
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# test_continuous_third_factor_in_modulatory_spikes_handler.py
4+
#
5+
# This file is part of NEST.
6+
#
7+
# Copyright (C) 2004 The NEST Initiative
8+
#
9+
# NEST is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 2 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# NEST is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
22+
import numpy as np
23+
import pytest
24+
import re
25+
26+
import nest
27+
28+
from pynestml.codegeneration.nest_tools import NESTTools
29+
from pynestml.frontend.pynestml_frontend import generate_nest_target
30+
31+
32+
@pytest.mark.skipif(NESTTools.detect_nest_version().startswith("v2"),
33+
reason="This test does not support NEST 2")
34+
class TestContinuousThirdFactorInModulatorySpikesHandler:
35+
r"""Test that continuous third factors can be accessed in the modulatory spikes handler in a synapse model."""
36+
37+
@pytest.fixture(scope="module", autouse=True)
38+
def generate_all_models(self):
39+
codegen_opts = {"continuous_state_buffering_method": "post_spike_based",
40+
"neuron_synapse_pairs": [{"neuron": "test_continuous_third_factor_in_modulatory_spikes_handler_neuron",
41+
"synapse": "test_continuous_third_factor_in_modulatory_spikes_handler_synapse",
42+
"vt_ports": ["mod_spikes"],
43+
"post_ports": ["post_spikes",
44+
["e_trace_d", "e_trace_dend"]]}],
45+
"delay_variable": {"test_continuous_third_factor_in_modulatory_spikes_handler_synapse": "d"},
46+
"weight_variable": {"test_continuous_third_factor_in_modulatory_spikes_handler_synapse": "w"}}
47+
48+
generate_nest_target(input_path=["tests/nest_tests/resources/test_continuous_third_factor_in_modulatory_spikes_handler_neuron.nestml",
49+
"tests/nest_tests/resources/test_continuous_third_factor_in_modulatory_spikes_handler_synapse.nestml"],
50+
logging_level="DEBUG",
51+
suffix="_nestml",
52+
codegen_opts=codegen_opts)
53+
54+
def test_continuous_third_factor_in_modulatory_spikes_handler(self):
55+
t_stop = 100. # [ms]
56+
57+
nest.ResetKernel()
58+
nest.Install("nestmlmodule")
59+
60+
# create spike_generator
61+
vt_sg = nest.Create("poisson_generator",
62+
params={"rate": 20.})
63+
64+
# create volume transmitter
65+
vt = nest.Create("volume_transmitter")
66+
vt_parrot = nest.Create("parrot_neuron")
67+
nest.Connect(vt_sg, vt_parrot)
68+
sr = nest.Create("spike_recorder")
69+
nest.Connect(vt_parrot, sr)
70+
nest.Connect(vt_parrot, vt, syn_spec={"synapse_model": "static_synapse",
71+
"weight": 1.,
72+
"delay": 1.}) # delay is ignored!
73+
74+
nestml_model_name = "test_continuous_third_factor_in_modulatory_spikes_handler_neuron_nestml__with_test_continuous_third_factor_in_modulatory_spikes_handler_synapse_nestml"
75+
synapse_model_name = "test_continuous_third_factor_in_modulatory_spikes_handler_synapse_nestml__with_test_continuous_third_factor_in_modulatory_spikes_handler_neuron_nestml"
76+
77+
neuron1 = nest.Create("parrot_neuron")
78+
neuron2 = nest.Create(nestml_model_name)
79+
nest.CopyModel(synapse_model_name, "my_nestml_synapse", {"volume_transmitter": vt})
80+
nest.Connect(neuron1, neuron2, syn_spec={"synapse_model": "my_nestml_synapse", "weight": 1., "delay": 1.})
81+
82+
sg = nest.Create("spike_generator",
83+
params={"spike_times": [20., 50., 80.]})
84+
nest.Connect(sg, neuron1, syn_spec={"synapse_model": "static_synapse", "weight": 1., "delay": 1.})
85+
86+
multimeter2 = nest.Create("multimeter")
87+
88+
V_m_specifier = "V_m" # "delta_V_m"
89+
nest.SetStatus(multimeter2, {"record_from": [V_m_specifier]})
90+
91+
nest.Connect(multimeter2, neuron2)
92+
93+
sd_pre_neuron = nest.Create("spike_recorder")
94+
sd_post_neuron = nest.Create("spike_recorder")
95+
96+
nest.Connect(neuron1, sd_pre_neuron)
97+
nest.Connect(neuron2, sd_post_neuron)
98+
99+
nest.Simulate(t_stop)
100+
101+
syn = nest.GetConnections(neuron1, neuron2)[0]
102+
e_trace_d_read = syn.e_trace_d_read
103+
e_trace_d_mod = syn.e_trace_d_mod
104+
np.testing.assert_allclose(e_trace_d_read, 42)
105+
np.testing.assert_allclose(e_trace_d_mod, 42)

0 commit comments

Comments
 (0)